"""
Translation Cache for SQL Tier Aggregation.
This module provides caching for translated pipeline-to-SQL queries,
with O(1) LRU (Least Recently Used) eviction using OrderedDict.
The cache stores SQL templates (not results), allowing the same translated
SQL to be reused across multiple query executions with different parameters.
"""
from __future__ import annotations
from collections import OrderedDict
from operator import itemgetter
from typing import Any
[docs]
class CacheEntry:
    """One cached SQL translation together with its hit counter.

    Attributes:
        sql_template: The translated SQL text.
        param_names: Names of the parameters associated with the template.
        hit_count: Number of hits recorded for this entry; starts at 0.
    """

    # Fixed attribute set, so instances carry no per-instance ``__dict__``.
    __slots__ = ("sql_template", "param_names", "hit_count")

    def __init__(self, sql_template: str, param_names: tuple[str, ...]):
        """Store the template and parameter names with a zero hit count."""
        self.sql_template = sql_template
        self.param_names = param_names
        self.hit_count = 0

    def __repr__(self) -> str:
        """Return a constructor-like description for debugging output."""
        return (
            f"{type(self).__name__}(sql_template={self.sql_template!r}, "
            f"param_names={self.param_names!r}, hit_count={self.hit_count})"
        )
[docs]
class TranslationCache:
"""
LRU cache for SQL translation templates with O(1) get/put operations.
Uses OrderedDict for efficient LRU eviction: most recently used entries
are moved to the end, least recently used are evicted from the front.
"""
DEFAULT_MAX_SIZE = 100
[docs]
def __init__(self, max_size: int = DEFAULT_MAX_SIZE):
    """Create an empty cache that keeps at most *max_size* entries.

    Args:
        max_size: Upper bound on the number of stored entries. ``0``
            disables caching. A negative value can never be satisfied
            and is therefore treated as ``0``.
    """
    self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
    # Normalize negatives to 0 (disabled) so that every size check in the
    # class sees a limit that is either zero or a real capacity.
    self._max_size = max(max_size, 0)
    self._hit_count = 0
    self._miss_count = 0
[docs]
def get(self, key: str) -> tuple[str, tuple[str, ...]] | None:
    """Look up a cached translation and record the access.

    Args:
        key: Cache key to look up.

    Returns:
        ``(sql_template, param_names)`` on a hit, otherwise ``None``.
        A lookup against a cache whose size limit is 0 always counts as
        a miss.
    """
    # A zero size limit means caching is off: skip the lookup entirely.
    cached = None if self._max_size == 0 else self._cache.get(key)
    if cached is None:
        self._miss_count += 1
        return None

    cached.hit_count += 1
    self._hit_count += 1
    # The tail of the OrderedDict is the most-recently-used position.
    self._cache.move_to_end(key)
    return cached.sql_template, cached.param_names
[docs]
def put(
    self, key: str, sql_template: str, param_names: tuple[str, ...]
) -> None:
    """Insert or refresh the translation stored under *key*.

    Args:
        key: Cache key for the translation.
        sql_template: Translated SQL text to store.
        param_names: Parameter names that belong to the template.

    An existing entry is updated in place (its hit count is kept) and
    becomes the most recently used one. A new entry evicts the least
    recently used entries first if the cache is full. Nothing is stored
    when the size limit is zero or negative.
    """
    # ``<= 0`` rather than ``== 0``: with a negative limit the eviction
    # below would pop from an empty OrderedDict and raise KeyError.
    if self._max_size <= 0:
        return

    existing = self._cache.get(key)
    if existing is not None:
        existing.sql_template = sql_template
        existing.param_names = param_names
        self._cache.move_to_end(key)
        return

    # Evict from the front (least recently used) until there is room;
    # the loop also restores the bound if the cache was ever over capacity.
    while len(self._cache) >= self._max_size:
        self._cache.popitem(last=False)

    # New entries go to the end, i.e. the most-recently-used position.
    self._cache[key] = CacheEntry(sql_template, param_names)
[docs]
def make_key(self, pipeline: list[dict[str, Any]]) -> str:
"""
Create a cache key from pipeline structure.
Key includes: operator names + field names + nested operator names.
Values like $sample.size, $limit, $skip are NOT included in the key because
we now parameterize them in SQL (using ?) - the same cached SQL template can
be reused with different parameter values.
"""
key_parts = []
for stage in pipeline:
stage_name = next(iter(stage.keys()))
spec = stage[stage_name]
match spec:
case str() as s:
key_parts.append(f"{stage_name}:{s}")
case list() as lst:
key_parts.append(f"{stage_name}:{tuple(sorted(lst))}")
case dict() as d:
nested_struct = self._extract_structure(d)
key_parts.append(f"{stage_name}:{nested_struct}")
case _:
key_parts.append(stage_name)
return "|".join(key_parts)
[docs]
def get_stats(self) -> dict[str, Any]:
    """Return hit/miss counters and a per-entry summary.

    Returns:
        A dict with ``size``, ``max_size``, ``hits``, ``misses``,
        ``hit_rate`` (``0.0`` when there were no accesses),
        ``total_accesses`` and ``entries``. Each element of ``entries``
        holds the key (cut to 50 characters plus ``"..."`` when longer),
        the entry's ``hit_count`` and a single-line ``sql_preview`` (cut
        to 60 characters plus ``"..."`` when longer). Entries are sorted
        by ``hit_count``, highest first.
    """
    total = self._hit_count + self._miss_count
    hit_rate = self._hit_count / total if total > 0 else 0.0

    entries: list[dict[str, Any]] = []
    for key, entry in self._cache.items():
        sql = entry.sql_template
        preview = sql[:60] + "..." if len(sql) > 60 else sql
        entries.append(
            {
                "key": key[:50] + "..." if len(key) > 50 else key,
                "hit_count": entry.hit_count,
                # Flatten newlines for short templates too, so every
                # preview is a single line regardless of its length.
                "sql_preview": preview.replace("\n", " "),
            }
        )
    entries.sort(key=itemgetter("hit_count"), reverse=True)

    return {
        "size": len(self._cache),
        "max_size": self._max_size,
        "hits": self._hit_count,
        "misses": self._miss_count,
        "hit_rate": hit_rate,
        "total_accesses": total,
        "entries": entries,
    }
[docs]
def clear(self) -> None:
    """Drop every cached entry and zero the hit and miss counters."""
    self._cache.clear()
    self._hit_count = self._miss_count = 0
[docs]
def resize(self, new_size: int) -> None:
    """Change the size limit, evicting entries that no longer fit.

    Args:
        new_size: New maximum number of entries. ``0`` empties the cache.
            A negative value is treated as ``0``.

    Entries are evicted from the least-recently-used end until the cache
    fits the new limit.
    """
    # Clamp negatives: ``len(cache) > new_size`` would otherwise still be
    # true for an empty cache and popitem() would raise KeyError.
    self._max_size = max(new_size, 0)
    while len(self._cache) > self._max_size:
        self._cache.popitem(last=False)
[docs]
def evict(self, key: str) -> bool:
    """Remove the entry stored under *key*, if there is one.

    Returns:
        ``True`` when an entry was removed, ``False`` when *key* was not
        cached.
    """
    try:
        del self._cache[key]
    except KeyError:
        return False
    return True
[docs]
def contains(self, key: str) -> bool:
    """Report whether *key* is cached, without touching LRU order or counters."""
    is_cached = key in self._cache
    return is_cached
[docs]
def get_entry(self, key: str) -> dict | None:
    """Describe the entry stored under *key* without recording an access.

    Returns:
        ``None`` when *key* is not cached; otherwise a dict with ``key``,
        the full ``sql_template``, ``param_names`` and ``hit_count``.
        Neither the LRU order nor any counter is modified.
    """
    found = self._cache.get(key)
    if found is None:
        return None

    details = {"key": key}
    details["sql_template"] = found.sql_template
    details["param_names"] = found.param_names
    details["hit_count"] = found.hit_count
    return details
[docs]
def _get_entry_hit_count(self, item: tuple[str, CacheEntry]) -> int:
"""Helper to extract hit_count from cache entry for sorting."""
return item[1].hit_count
[docs]
def dump(self) -> list[dict]:
    """List every cached entry for debugging, most-hit entries first.

    Returns:
        One dict per entry with the full ``key``, a ``sql_preview`` (the
        first 100 characters of the template with newlines replaced by
        spaces), ``param_names`` and ``hit_count``.
    """
    by_hits = sorted(
        self._cache.items(), key=self._get_entry_hit_count, reverse=True
    )
    rows = []
    for key, entry in by_hits:
        preview = entry.sql_template[:100].replace("\n", " ")
        rows.append(
            {
                "key": key,
                "sql_preview": preview,
                "param_names": entry.param_names,
                "hit_count": entry.hit_count,
            }
        )
    return rows
[docs]
def is_enabled(self) -> bool:
    """Return ``True`` when the size limit is positive, i.e. caching is on."""
    limit = self._max_size
    return limit > 0
def __len__(self) -> int:
"""Return number of entries in cache."""
return len(self._cache)