Caching and Performance

Cache management, optimization strategies, and recommended workflows

MRRA includes comprehensive caching mechanisms to optimize performance and reduce redundant computations. This is especially important when working with expensive LLM calls and complex graph operations.

Cache Architecture

Cache Location and Structure

MRRA uses a hierarchical caching system based on trajectory batch hashes:

.mrra_cache/
├── <tb_hash>/
│   ├── activities_<key>.json
│   ├── graph_<key>.gpickle
│   ├── chains_<key>.json
│   ├── patterns_<user>.json
│   └── reflection_<query_hash>.json
└── global/
    ├── models/
    └── configs/

Cache Directory: The .mrra_cache/ directory is automatically added to .gitignore to prevent version control of cache files.

Trajectory Batch Hashing

The cache system uses content-based hashing to ensure cache validity:

from mrra.persist.cache import CacheManager, compute_tb_hash

# Generate cache key from trajectory data
tb_hash = compute_tb_hash(tb)
print(f"Trajectory batch hash: {tb_hash}")

# Cache is automatically organized by this hash
cm = CacheManager()
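
Conceptually, a content-based key is what you would get by serializing the trajectory records in a canonical order and hashing the bytes. The sketch below is purely illustrative and is not compute_tb_hash's actual implementation:

import hashlib
import pandas as pd

def illustrative_tb_hash(df: pd.DataFrame) -> str:
    """Illustrative content hash: canonicalize rows/columns, then SHA-256.

    Not the real compute_tb_hash -- just the general idea: identical
    trajectory content always maps to the same cache key.
    """
    # Sort columns and rows so logically identical data hashes identically
    canonical = df.sort_index(axis=1)
    canonical = canonical.sort_values(by=list(canonical.columns))
    payload = canonical.to_csv(index=False).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:16]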

Core Caching Operations

Activity Caching

Activities with assigned purposes are cached to avoid expensive LLM re-computation:

from mrra.data.activity import ActivityExtractor
from mrra.analysis.activity_purpose import ActivityPurposeAssigner
from mrra.persist.cache import CacheManager, compute_tb_hash

# Extract activities and assign purposes (llm built earlier, e.g. llm = make_llm(**llm_cfg))
ext_cfg = dict(method="radius", radius_m=300, min_dwell_minutes=30)
acts = ActivityExtractor(tb, **ext_cfg).extract()
acts = ActivityPurposeAssigner(tb, llm=llm, concurrency=8).assign(acts)

# Cache activities with purposes
cm = CacheManager()
tb_hash = compute_tb_hash(tb)
cm.save_activities(tb_hash, "default", acts)

print(f"Cached {len(acts)} activities")
# Load cached activities
cm = CacheManager()
tb_hash = compute_tb_hash(tb)

cached_acts = cm.load_activities(tb_hash, "default")
if cached_acts:
    print(f"Loaded {len(cached_acts)} cached activities")
    acts = cached_acts
else:
    print("No cached activities found, computing...")
    # Compute activities as needed

# Use descriptive cache keys for different configurations
def create_activity_cache_key(ext_cfg, llm_config):
    """Create cache key that includes configuration parameters"""
    key_parts = [
        ext_cfg.get('method', 'radius'),
        str(ext_cfg.get('radius_m', 300)),
        str(ext_cfg.get('min_dwell_minutes', 30)),
        llm_config.get('model', 'default'),
        str(ext_cfg.get('max_gap_minutes', 90))
    ]
    return "_".join(key_parts)

# Use configuration-specific cache key (llm_cfg is the dict later passed to make_llm)
cache_key = create_activity_cache_key(ext_cfg, llm_cfg)
cm.save_activities(tb_hash, cache_key, acts)

# Later retrieval with same key
cached_acts = cm.load_activities(tb_hash, cache_key)

Graph Caching

Mobility graphs are cached using pickle serialization for fast loading:

from mrra.graph.mobility_graph import MobilityGraph, GraphConfig

# Build and cache graph
cfg = GraphConfig(grid_size_m=200, min_dwell_minutes=5, use_activities=True)
mg = MobilityGraph(tb, cfg, activities=acts, assume_purposes_assigned=True)

# Cache graph
graph_key = f"mobility_grid{cfg.grid_size_m}_dwell{cfg.min_dwell_minutes}"
cm.save_graph(tb_hash, graph_key, mg.G)

# Load cached graph
cached_graph = cm.load_graph(tb_hash, graph_key)
if cached_graph:
    mg.G = cached_graph
    print("Loaded cached mobility graph")
else:
    print("Building new mobility graph...")

Chain and Pattern Caching

Cache activity chains and user patterns for quick retrieval:

# Cache activity chains
chain_records = []  # Generated from activities
cm.save_json(tb_hash, "chains_default", {
    "count": len(chain_records), 
    "records": chain_records[:1000]  # Limit size
}, kind="chains")

# Cache user patterns  
from mrra.graph.pattern import PatternGenerate

user_id = tb.users()[0]  # first user in the batch
pat = PatternGenerate(tb)
patterns = pat.long_short_patterns(user_id)
cm.save_json(tb_hash, f"patterns_{user_id}", patterns, kind="patterns")

# Load cached data
cached_chains = cm.load_json(tb_hash, "chains_default", kind="chains")
cached_patterns = cm.load_json(tb_hash, f"patterns_{user_id}", kind="patterns")

Advanced Caching Strategies

Layered Caching

Implement multiple cache layers for different data types:

class LayeredCacheManager:
    def __init__(self, base_dir=None):
        self.cache = CacheManager(base_dir)
        self.memory_cache = {}  # In-memory cache for frequent access
        self.cache_stats = {'hits': 0, 'misses': 0}
    
    def get_activities(self, tb_hash, key):
        """Get activities with memory + disk caching"""
        
        # Check memory cache first
        memory_key = f"activities_{tb_hash}_{key}"
        if memory_key in self.memory_cache:
            self.cache_stats['hits'] += 1
            return self.memory_cache[memory_key]
        
        # Check disk cache
        disk_result = self.cache.load_activities(tb_hash, key)
        if disk_result:
            # Store in memory for future access
            self.memory_cache[memory_key] = disk_result
            self.cache_stats['hits'] += 1
            return disk_result
        
        # Cache miss
        self.cache_stats['misses'] += 1
        return None
    
    def save_activities(self, tb_hash, key, activities):
        """Save to both memory and disk"""
        
        # Save to disk
        self.cache.save_activities(tb_hash, key, activities)
        
        # Save to memory
        memory_key = f"activities_{tb_hash}_{key}"
        self.memory_cache[memory_key] = activities
    
    def clear_memory_cache(self):
        """Clear memory cache to free up RAM"""
        self.memory_cache.clear()
    
    def get_cache_stats(self):
        """Get cache performance statistics"""
        total = self.cache_stats['hits'] + self.cache_stats['misses']
        hit_rate = self.cache_stats['hits'] / total if total > 0 else 0
        
        return {
            'hits': self.cache_stats['hits'],
            'misses': self.cache_stats['misses'],
            'hit_rate': hit_rate,
            'memory_entries': len(self.memory_cache)
        }
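
Typical usage checks the memory layer first and falls back to disk, recomputing only on a full miss (tb, tb_hash, and llm come from the earlier snippets):

lcm = LayeredCacheManager()

acts = lcm.get_activities(tb_hash, "default")
if acts is None:
    # Full miss: recompute, then populate both layers
    acts = ActivityExtractor(tb, method="radius", radius_m=300, min_dwell_minutes=30).extract()
    acts = ActivityPurposeAssigner(tb, llm=llm, concurrency=8).assign(acts)
    lcm.save_activities(tb_hash, "default", acts)

print("Layered cache stats:", lcm.get_cache_stats())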

Time-Based Cache Invalidation

Implement time-based cache expiration for dynamic data:

from datetime import datetime, timedelta
import json

class TimeCacheManager(CacheManager):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_ttl = timedelta(hours=24)  # 24-hour default TTL
    
    def save_with_ttl(self, tb_hash, key, data, ttl=None, kind="general"):
        """Save data with time-to-live"""
        ttl = ttl or self.default_ttl
        expires_at = datetime.now() + ttl
        
        cache_data = {
            'data': data,
            'created_at': datetime.now().isoformat(),
            'expires_at': expires_at.isoformat(),
            'ttl_hours': ttl.total_seconds() / 3600
        }
        
        self.save_json(tb_hash, key, cache_data, kind=kind)
    
    def load_with_ttl_check(self, tb_hash, key, kind="general"):
        """Load data with TTL validation"""
        cached = self.load_json(tb_hash, key, kind=kind)
        
        if not cached:
            return None
        
        # Check expiration
        expires_at = datetime.fromisoformat(cached['expires_at'])
        if datetime.now() > expires_at:
            print(f"Cache expired for {key}, removing...")
            self._remove_cache_file(tb_hash, key, kind)
            return None
        
        return cached['data']
    
    def _remove_cache_file(self, tb_hash, key, kind):
        """Remove an expired cache file.
        
        Assumes the flat <base_dir>/<tb_hash>/<key>.json layout used by
        cleanup_expired below; adjust if your CacheManager stores files differently.
        """
        import os
        path = os.path.join(self.base_dir, tb_hash, f"{key}.json")
        if os.path.exists(path):
            os.remove(path)
    
    def cleanup_expired(self, tb_hash):
        """Remove all expired cache entries"""
        import os
        import glob
        
        cache_dir = os.path.join(self.base_dir, tb_hash)
        if not os.path.exists(cache_dir):
            return
        
        expired_count = 0
        for cache_file in glob.glob(os.path.join(cache_dir, "*.json")):
            try:
                with open(cache_file, 'r') as f:
                    data = json.load(f)
                
                if 'expires_at' in data:
                    expires_at = datetime.fromisoformat(data['expires_at'])
                    if datetime.now() > expires_at:
                        os.remove(cache_file)
                        expired_count += 1
            except Exception as e:
                print(f"Error checking cache file {cache_file}: {e}")
        
        print(f"Cleaned up {expired_count} expired cache entries")

Optimal Caching Strategy

Follow this recommended workflow for maximum performance:

# Assumes the MRRA imports used earlier on this page (TrajectoryBatch, GraphConfig,
# MobilityGraph, GraphRAGGenerate, build_mrra_agent, CacheManager, compute_tb_hash)
def optimized_mrra_workflow(df, llm_cfg, force_refresh=False):
    """Optimized MRRA workflow with comprehensive caching"""
    
    # 1. Initialize trajectory batch and cache
    tb = TrajectoryBatch(df)
    cm = CacheManager()
    tb_hash = compute_tb_hash(tb)
    
    print(f"Working with trajectory batch: {tb_hash}")
    
    # 2. Try to load cached activities (most expensive step)
    activities_key = "activities_with_purposes"
    
    if not force_refresh:
        acts = cm.load_activities(tb_hash, activities_key)
        if acts:
            print(f"✅ Loaded {len(acts)} cached activities with purposes")
        else:
            print("📊 Computing activities and purposes...")
            acts = compute_and_cache_activities(tb, llm_cfg, cm, tb_hash, activities_key)
    else:
        print("🔄 Force refresh: recomputing activities...")
        acts = compute_and_cache_activities(tb, llm_cfg, cm, tb_hash, activities_key)
    
    # 3. Try to load cached graph
    graph_key = "mobility_graph_default"
    cached_graph = cm.load_graph(tb_hash, graph_key)
    
    if cached_graph and not force_refresh:
        print("✅ Loaded cached mobility graph")
        cfg = GraphConfig(grid_size_m=200, min_dwell_minutes=5, use_activities=True)
        mg = MobilityGraph(tb, cfg)
        mg.G = cached_graph  # replace the freshly built graph with the cached one
    else:
        print("📊 Building mobility graph...")
        cfg = GraphConfig(grid_size_m=200, min_dwell_minutes=5, use_activities=True)
        mg = MobilityGraph(tb, cfg, activities=acts, assume_purposes_assigned=True)
        cm.save_graph(tb_hash, graph_key, mg.G)
        print("💾 Cached mobility graph")
    
    # 4. Generate and cache patterns if not exists
    user_id = tb.users()[0]
    patterns_key = f"patterns_{user_id}"
    patterns = cm.load_json(tb_hash, patterns_key, kind="patterns")
    
    if not patterns or force_refresh:
        print("📊 Generating user patterns...")
        from mrra.graph.pattern import PatternGenerate
        pat = PatternGenerate(tb)
        patterns = pat.long_short_patterns(user_id)
        cm.save_json(tb_hash, patterns_key, patterns, kind="patterns")
        print("💾 Cached user patterns")
    else:
        print("✅ Loaded cached user patterns")
    
    # 5. Build retriever and agent (these are fast)
    retriever = GraphRAGGenerate(tb=tb, mobility_graph=mg)
    
    reflection_cfg = dict(
        max_round=1,
        subAgents=[
            {"name": "temporal", "prompt": "Select the most likely location id from Options."},
            {"name": "spatial",  "prompt": "Select the most likely location id from Options."},
        ],
        aggregator="confidence_weighted_voting",
    )
    
    agent = build_mrra_agent(llm=llm_cfg, retriever=retriever, reflection=reflection_cfg)
    
    return agent, {
        'trajectory_batch': tb,
        'activities': acts,
        'mobility_graph': mg,
        'patterns': patterns,
        'cache_hash': tb_hash
    }

def compute_and_cache_activities(tb, llm_cfg, cm, tb_hash, key):
    """Compute activities with purposes and cache results"""
    
    # Extract activities
    ext_cfg = dict(method="radius", radius_m=300, min_dwell_minutes=30, max_gap_minutes=90)
    acts = ActivityExtractor(tb, **ext_cfg).extract()
    
    # Assign purposes with LLM
    llm = make_llm(**llm_cfg)
    acts = ActivityPurposeAssigner(tb, llm=llm, concurrency=8).assign(acts)
    
    # Cache results
    cm.save_activities(tb_hash, key, acts)
    print(f"💾 Cached {len(acts)} activities with purposes")
    
    return acts

# Usage
agent, components = optimized_mrra_workflow(df, llm_cfg, force_refresh=False)

# Subsequent runs will be much faster due to caching
agent2, components2 = optimized_mrra_workflow(df, llm_cfg, force_refresh=False)

Performance Monitoring

Cache Performance Metrics

class CacheMonitor:
    def __init__(self, cache_manager):
        self.cache = cache_manager
        self.metrics = {
            'hits': 0,
            'misses': 0,
            'saves': 0,
            'load_times': [],
            'save_times': []
        }
    
    def time_operation(self, operation, *args, **kwargs):
        """Time cache operations"""
        import time
        
        start_time = time.time()
        result = operation(*args, **kwargs)
        elapsed = time.time() - start_time
        
        if 'load' in operation.__name__:
            self.metrics['load_times'].append(elapsed)
            if result:
                self.metrics['hits'] += 1
            else:
                self.metrics['misses'] += 1
        elif 'save' in operation.__name__:
            self.metrics['save_times'].append(elapsed)
            self.metrics['saves'] += 1
        
        return result
    
    def get_performance_report(self):
        """Generate performance report"""
        total_ops = self.metrics['hits'] + self.metrics['misses']
        hit_rate = self.metrics['hits'] / total_ops if total_ops > 0 else 0
        
        avg_load_time = sum(self.metrics['load_times']) / len(self.metrics['load_times']) if self.metrics['load_times'] else 0
        avg_save_time = sum(self.metrics['save_times']) / len(self.metrics['save_times']) if self.metrics['save_times'] else 0
        
        return {
            'cache_hit_rate': hit_rate,
            'total_operations': total_ops,
            'average_load_time': avg_load_time,
            'average_save_time': avg_save_time,
            'total_saves': self.metrics['saves']
        }

# Usage
monitor = CacheMonitor(cm)

# Monitor cache operations
acts = monitor.time_operation(cm.load_activities, tb_hash, "default")
monitor.time_operation(cm.save_activities, tb_hash, "default", acts)

print("Cache Performance:", monitor.get_performance_report())

Memory Usage Optimization

def optimize_memory_usage(cache_manager, max_memory_mb=500):
    """Optimize memory usage by limiting cache size"""
    import psutil
    import gc
    
    process = psutil.Process()
    memory_mb = process.memory_info().rss / 1024 / 1024
    
    if memory_mb > max_memory_mb:
        print(f"Memory usage ({memory_mb:.1f}MB) exceeds limit ({max_memory_mb}MB)")
        
        # Clear in-memory caches if they exist
        if hasattr(cache_manager, 'memory_cache'):
            cache_manager.clear_memory_cache()
        
        # Force garbage collection
        gc.collect()
        
        new_memory_mb = process.memory_info().rss / 1024 / 1024
        print(f"Memory reduced to {new_memory_mb:.1f}MB (saved {memory_mb - new_memory_mb:.1f}MB)")

Production Considerations:

  • Monitor cache directory size to prevent disk space issues (see the sketch below)
  • Implement cache cleanup strategies for long-running services
  • Consider cache warming strategies for frequently accessed data
  • Use cache versioning when data schemas change (see the sketch below)
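
A minimal sketch of two of these practices, directory-size monitoring and versioned cache keys; the 1GB threshold and version string are illustrative choices, not MRRA defaults:

import os

def cache_dir_size_mb(base_dir=".mrra_cache"):
    """Total size of the cache directory in megabytes."""
    total = 0
    for root, _dirs, files in os.walk(base_dir):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / (1024 * 1024)

# Versioned keys: bump the version whenever the cached data shape changes,
# so stale entries are simply never looked up again
CACHE_SCHEMA_VERSION = "v2"

def versioned_key(key):
    return f"{CACHE_SCHEMA_VERSION}_{key}"

if cache_dir_size_mb() > 1024:
    print("Warning: .mrra_cache/ exceeds 1GB; consider running cleanup")

cm.save_activities(tb_hash, versioned_key("default"), acts)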

Advanced Cache Management

Custom Cache Backends

class RedisCacheManager(CacheManager):
    """Redis-based cache backend for distributed caching"""
    
    def __init__(self, redis_client, base_dir=None):
        super().__init__(base_dir)
        self.redis = redis_client
        self.default_ttl = 86400  # 24 hours
    
    def save_activities(self, tb_hash, key, activities, ttl=None):
        """Save activities to Redis"""
        import pickle
        
        redis_key = f"mrra:activities:{tb_hash}:{key}"
        serialized = pickle.dumps(activities)
        
        self.redis.setex(redis_key, ttl or self.default_ttl, serialized)
        
        # Also save to disk as backup
        super().save_activities(tb_hash, key, activities)
    
    def load_activities(self, tb_hash, key):
        """Load activities from Redis with disk fallback"""
        import pickle
        
        redis_key = f"mrra:activities:{tb_hash}:{key}"
        
        try:
            serialized = self.redis.get(redis_key)
            if serialized:
                return pickle.loads(serialized)
        except Exception as e:
            print(f"Redis error: {e}, falling back to disk")
        
        # Fallback to disk cache
        return super().load_activities(tb_hash, key)

# Usage with Redis
import redis
redis_client = redis.Redis(host='localhost', port=6379, db=0)
redis_cache = RedisCacheManager(redis_client)
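
This write-through design keeps the local pickle files as a durable fallback: Redis provides low-latency shared access across workers, while the disk copy survives Redis restarts and evictions.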

High-Performance Tips:

  • Use SSD storage for cache directories when possible
  • Implement cache warming for critical data (see the sketch below)
  • Consider distributed caching for multi-node deployments
  • Monitor cache hit rates and adjust strategies accordingly
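
Cache warming can be as simple as precomputing and saving pattern data for every user before serving queries. A sketch reusing the PatternGenerate API shown earlier (tb, cm, and tb_hash as above):

from mrra.graph.pattern import PatternGenerate

def warm_pattern_cache(tb, cm, tb_hash):
    """Precompute and cache long/short patterns for all users in the batch."""
    pat = PatternGenerate(tb)
    for user_id in tb.users():
        key = f"patterns_{user_id}"
        if cm.load_json(tb_hash, key, kind="patterns") is None:
            cm.save_json(tb_hash, key, pat.long_short_patterns(user_id), kind="patterns")

warm_pattern_cache(tb, cm, tb_hash)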
