"""
Tier 2: Temporary table-based expression evaluation for complex $expr queries.
This module implements the second tier of the 3-tier $expr evaluation architecture:
- Uses temporary tables for complex expressions that can't be evaluated in a single query
- Leverages SQLite JSON/JSONB native functions for optimal performance
- Ensures data is converted to JSON format when moving to Python space
Tier 2 is used when:
- Expressions are too complex for Tier 1 (single SQL WHERE clause)
- Multiple intermediate calculations are needed
- Aggregation or grouping is required within the expression
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
from .._sqlite import sqlite3
from .json_path_utils import parse_json_path
from .jsonb_support import supports_jsonb
from .query_helper.translation_cache import TranslationCache
[docs]
class TempTableExprEvaluator:
"""
Tier 2 evaluator using temporary tables for complex $expr expressions.
This evaluator:
1. Creates temporary tables to store intermediate results
2. Uses SQLite JSON/JSONB functions for efficient processing
3. Converts to JSON format when exporting to Python space
"""
[docs]
def __init__(
    self,
    db_connection,
    data_column: str = "data",
    translation_cache_size: int | None = 100,
):
    """
    Set up a Tier 2 evaluator bound to one SQLite connection.

    Args:
        db_connection: SQLite database connection used for all statements
        data_column: Name of the column containing JSON data (default: "data")
        translation_cache_size: Maximum size handed to the translation
            cache; ``None`` is replaced by 100. (The original documentation
            states that 0 disables the cache; that behavior lives in
            ``TranslationCache`` and is not visible here.)
    """
    # Resolve the cache size up front; this has no side effects.
    effective_cache_size = (
        100 if translation_cache_size is None else translation_cache_size
    )

    self.db = db_connection
    self.data_column = data_column
    # Probe the connection once so later SQL generation can pick json/jsonb.
    self._jsonb_supported = supports_jsonb(db_connection)
    # Names of every temp table created by this evaluator, for later cleanup.
    self._temp_tables: list[str] = []
    self._translation_cache = TranslationCache(max_size=effective_cache_size)
@property
def json_function_prefix(self) -> str:
    """Return ``"jsonb"`` when the connection supports JSONB, else ``"json"``."""
    if self._jsonb_supported:
        return "jsonb"
    return "json"
[docs]
def is_cache_enabled(self) -> bool:
    """Report whether the underlying translation cache is enabled."""
    cache = self._translation_cache
    return cache.is_enabled()
[docs]
def get_cache_stats(self) -> dict:
    """Return the statistics reported by the translation cache."""
    cache = self._translation_cache
    return cache.get_stats()
[docs]
def clear_cache(self) -> None:
    """Drop every entry held by the translation cache."""
    cache = self._translation_cache
    cache.clear()
[docs]
def cache_size(self) -> int:
    """Return the number of entries currently in the translation cache."""
    cache = self._translation_cache
    return len(cache)
[docs]
def cache_contains(self, expr: dict) -> bool:
    """Tell whether a translation for ``expr`` is currently cached."""
    # The cache is keyed by the structural key, not by the expression itself.
    return self._translation_cache.contains(self._make_expr_key(expr))
[docs]
def evict_from_cache(self, expr: dict) -> bool:
    """Remove the cached translation for ``expr`` and return the cache's result."""
    # The cache is keyed by the structural key, not by the expression itself.
    return self._translation_cache.evict(self._make_expr_key(expr))
[docs]
def resize_cache(self, new_size: int) -> None:
    """Ask the translation cache to change its capacity to ``new_size``."""
    cache = self._translation_cache
    cache.resize(new_size)
[docs]
def dump_cache(self) -> list:
    """Return the translation cache's contents, intended for debugging."""
    cache = self._translation_cache
    return cache.dump()
[docs]
def evaluate(
    self,
    expr: dict[str, Any],
    collection_name: str,
    filter_expr: dict[str, Any] | None = None,
) -> tuple[str | None, list[Any], list[str]]:
    """
    Translate a $expr expression into a temp-table backed SQL query.

    Expressions scoring below 2 or above 10 in ``_analyze_complexity`` are
    not handled here. A cached translation is reused when one exists;
    otherwise the query is built from scratch and its WHERE clause is stored
    in the cache with the temp table name replaced by ``{temp_table}``.

    Temp tables are not dropped by this method; the original code notes
    that cleanup is handled by the Cursor.

    Args:
        expr: The $expr expression dictionary
        collection_name: Name of the collection table
        filter_expr: Optional additional filter, forwarded to the query
            builder on a cache miss

    Returns:
        Tuple of (SQL query, parameters, table_names), or (None, [], [])
        when the caller should use another tier / the Python fallback
    """
    try:
        score = self._analyze_complexity(expr)
        # Outside the Tier 2 window: too simple (Tier 1) or too complex (Tier 3).
        if score < 2 or score > 10:
            return None, [], []

        key = self._make_expr_key(expr)
        hit = self._translation_cache.get(key)
        if hit is not None:
            return self._build_from_cache(expr, collection_name, hit)

        built = self._build_tier2_query(expr, collection_name, filter_expr)
        if built[0] is None:
            return built

        full_sql, _unused_params, created_tables = built
        referenced_fields = self._extract_field_references(expr)
        where_sql = self._extract_where_clause(full_sql)
        # The table name is unique per call, so store a placeholder that a
        # later cache hit can substitute with its own fresh table name.
        template = where_sql.replace(created_tables[-1], "{temp_table}")
        self._translation_cache.put(key, template, tuple(referenced_fields))
        return built
    except (NotImplementedError, ValueError) as e:
        logger.debug(
            f"Tier 2 evaluation failed, falling back to Python: {e}"
        )
        # Signal the caller to fall back to Python evaluation.
        return None, [], []
[docs]
def _analyze_complexity(self, expr: dict[str, Any]) -> int:
"""
Analyze expression complexity to determine if Tier 2 is appropriate.
Complexity scoring:
- Base expression: 1 point
- Each nested operator: +1 point
- Each arithmetic operator: +1 point
- Each conditional operator: +2 points
- Each array operator: +2 points
- Each subquery/correlated expression: +3 points
Returns:
int: Complexity score (1-2: Tier 1, 3-10: Tier 2, 11+: Tier 3)
"""
if not isinstance(expr, dict):
return 0
score = 1 # Base score
for operator, operands in expr.items():
match operator:
case "$add" | "$subtract" | "$multiply" | "$divide" | "$mod":
score += 1
if isinstance(operands, list):
for op in operands:
if isinstance(op, dict):
score += self._analyze_complexity(op)
case "$cond" | "$switch":
score += 2
if isinstance(operands, dict):
if "if" in operands:
score += self._analyze_complexity(operands["if"])
if "then" in operands:
if isinstance(operands["then"], dict):
score += self._analyze_complexity(
operands["then"]
)
case "$and" | "$or" | "$nor":
if isinstance(operands, list):
for op in operands:
if isinstance(op, dict):
score += self._analyze_complexity(op)
case "$not":
if isinstance(operands, list) and len(operands) > 0:
if isinstance(operands[0], dict):
score += self._analyze_complexity(operands[0])
case "$size" | "$in" | "$arrayElemAt":
score += 2
case "$concat" | "$substr" | "$trim":
score += 1
return score
[docs]
def _make_expr_key(self, expr: dict[str, Any]) -> str:
"""
Create a cache key from expression structure.
Uses TranslationCache._extract_structure to create a hashable key
that preserves field references ($field) but parameterizes literal values.
"""
structure = self._translation_cache._extract_structure(expr)
return str(structure)
[docs]
def _build_from_cache(
self,
expr: dict[str, Any],
collection_name: str,
cached: tuple[str, tuple[str, ...]],
) -> tuple[str, list[Any], list[str]]:
"""
Build query from cached translation.
Args:
expr: The original expression (for extracting actual parameter values)
collection_name: Collection table name
cached: Tuple of (where_clause_template, field_list)
Returns:
Tuple of (SQL query, parameters, table_names)
"""
where_clause_template, field_list = cached
# Generate unique temp table name
temp_table = f"temp_expr_{uuid.uuid4().hex[:8]}"
self._temp_tables.append(temp_table)
# Create temp table with cached fields
self._create_temp_table(temp_table, collection_name, list(field_list))
# Build WHERE clause by substituting temp table name into template
where_clause = where_clause_template.replace("{temp_table}", temp_table)
# Extract parameter values from expression
params = self._extract_param_values_from_expr(expr)
# Build SELECT with json() conversion for Python-space data
if self._jsonb_supported:
select_data = f"json({collection_name}.{self.data_column}) as data"
else:
select_data = f"{collection_name}.data as data"
# Build the final query
query = f"""
SELECT {temp_table}.id, {temp_table}._id, {select_data}
FROM {collection_name}
JOIN {temp_table} ON {collection_name}.id = {temp_table}.id
WHERE {where_clause}
"""
return query.strip(), params, [temp_table]
[docs]
def _build_tier2_query(
self,
expr: dict[str, Any],
collection_name: str,
filter_expr: dict[str, Any] | None = None,
) -> tuple[str, list[Any], list[str]]:
"""
Build a Tier 2 query using temporary tables.
Strategy:
1. Create temp table with document IDs and intermediate calculations
2. Populate temp table with JSON-extracted values
3. Query main table joined with temp table using calculated values
Args:
expr: The $expr expression
collection_name: Collection table name
filter_expr: Optional additional filter
Returns:
Tuple of (SQL query, parameters, table_names)
"""
# Generate unique temp table name
temp_table = f"temp_expr_{uuid.uuid4().hex[:8]}"
self._temp_tables.append(temp_table)
# Extract all field references from expression
fields = self._extract_field_references(expr)
# Create temp table to store intermediate results
self._create_temp_table(temp_table, collection_name, fields)
# Build the main query
sql, params = self._build_main_query(
expr, collection_name, temp_table, filter_expr
)
return sql, params, [temp_table]
[docs]
def _create_temp_table(
self, temp_table: str, collection_name: str, fields: list[str]
) -> None:
"""
Create a temporary table with extracted field values.
Args:
temp_table: Name of the temporary table
collection_name: Source collection name
fields: List of field paths to extract
"""
# Build column definitions for extracted fields
select_exprs = ["id"]
# Check if _id column exists
try:
self.db.execute(f"SELECT _id FROM {collection_name} LIMIT 0")
has_id = True
except sqlite3.OperationalError as e:
logger.debug(
f"Collection '{collection_name}' does not have _id column: {e}"
)
has_id = False
if has_id:
select_exprs.append("_id")
for field in fields:
# Sanitize field name for SQL column
col_name = self._sanitize_field_name(field)
# Build JSON extract expression
if self._jsonb_supported:
extract_expr = f"jsonb_extract({self.data_column}, '{parse_json_path(field)}') as {col_name}"
else:
extract_expr = f"json_extract({self.data_column}, '{parse_json_path(field)}') as {col_name}"
select_exprs.append(extract_expr)
# Create temp table with SELECT (faster than CREATE + INSERT)
select_sql = f"SELECT {', '.join(select_exprs)} FROM {collection_name}"
create_sql = (
f"CREATE TEMPORARY TABLE IF NOT EXISTS {temp_table} AS {select_sql}"
)
self.db.execute(create_sql)
# Create index on temp table for faster joins
self.db.execute(
f"CREATE INDEX IF NOT EXISTS idx_{temp_table}_id ON {temp_table}(id)"
)
[docs]
def _build_main_query(
self,
expr: dict[str, Any],
collection_name: str,
temp_table: str,
filter_expr: dict[str, Any] | None = None,
) -> tuple[str, list[Any]]:
"""
Build the main query that uses the temporary table.
Args:
expr: The $expr expression
collection_name: Collection table name
temp_table: Temporary table name
filter_expr: Optional additional filter
Returns:
Tuple of (SQL query, parameters)
"""
# Build WHERE clause from expression using temp table columns
where_clause, params = self._build_expr_where_from_temp(
expr, temp_table
)
# Build SELECT with json() conversion for Python-space data
if self._jsonb_supported:
# Use json() to convert jsonb to regular JSON for Python space
select_data = f"json({collection_name}.{self.data_column}) as data"
else:
select_data = f"{collection_name}.data as data"
# Build the final query
query = f"""
SELECT {temp_table}.id, {temp_table}._id, {select_data}
FROM {collection_name}
JOIN {temp_table} ON {collection_name}.id = {temp_table}.id
WHERE {where_clause}
"""
return query.strip(), params
[docs]
def _build_expr_where_from_temp(
self, expr: dict[str, Any], temp_table: str
) -> tuple[str, list[Any]]:
"""
Build WHERE clause using temporary table columns.
Args:
expr: The $expr expression
temp_table: Temporary table name
Returns:
Tuple of (WHERE clause, parameters)
"""
sql, params = self._convert_expr_to_temp_sql(expr, temp_table)
# MongoDB $expr truthiness: NOT (null, 0, false, undefined).
# In SQLite, we use COALESCE and != 0 to return 1 for truthy and 0 for falsy.
truthy_sql = f"COALESCE(({sql}), 0) != 0"
return truthy_sql, params
[docs]
def _convert_expr_to_temp_sql(
self, expr: dict[str, Any], temp_table: str
) -> tuple[str, list[Any]]:
"""
Convert expression to SQL using temporary table columns.
Args:
expr: The $expr expression
temp_table: Temporary table name
Returns:
Tuple of (SQL expression, parameters)
"""
if not isinstance(expr, dict) or len(expr) != 1:
raise ValueError("Invalid $expr expression structure")
operator, operands = next(iter(expr.items()))
# Handle different operator types
match operator:
case "$and" | "$or" | "$not" | "$nor":
return self._convert_logical_to_temp_sql(
operator, operands, temp_table
)
case "$gt" | "$gte" | "$lt" | "$lte" | "$eq" | "$ne":
return self._convert_comparison_to_temp_sql(
operator, operands, temp_table
)
case "$add" | "$subtract" | "$multiply" | "$divide" | "$mod":
return self._convert_arithmetic_to_temp_sql(
operator, operands, temp_table
)
case "$cond":
return self._convert_cond_to_temp_sql(operands, temp_table)
case "$cmp":
return self._convert_cmp_to_temp_sql(operands, temp_table)
case "$abs" | "$ceil" | "$floor" | "$round":
return self._convert_math_to_temp_sql(
operator, operands, temp_table
)
case _:
raise NotImplementedError(
f"Operator {operator} not supported in Tier 2"
)
[docs]
def _convert_logical_to_temp_sql(
self, operator: str, operands: list[Any], temp_table: str
) -> tuple[str, list[Any]]:
"""Convert logical operators to SQL using temp table."""
if not isinstance(operands, list):
raise ValueError(f"{operator} requires a list of expressions")
if operator == "$not":
if len(operands) != 1:
raise ValueError("$not requires exactly one operand")
inner_sql, inner_params = self._convert_expr_to_temp_sql(
operands[0], temp_table
)
return f"NOT ({inner_sql})", inner_params
if len(operands) < 2:
raise ValueError(f"{operator} requires at least 2 operands")
sql_parts = []
all_params = []
for operand in operands:
operand_sql, operand_params = self._convert_expr_to_temp_sql(
operand, temp_table
)
sql_parts.append(f"({operand_sql})")
all_params.extend(operand_params)
match operator:
case "$and":
sql = " AND ".join(sql_parts)
case "$or":
sql = " OR ".join(sql_parts)
case "$nor":
sql = f"NOT ({' OR '.join(sql_parts)})"
case _:
raise ValueError(f"Unknown logical operator: {operator}")
return sql, all_params
[docs]
def _convert_comparison_to_temp_sql(
self, operator: str, operands: list[Any], temp_table: str
) -> tuple[str, list[Any]]:
"""Convert comparison operators to SQL using temp table."""
if len(operands) != 2:
raise ValueError(f"{operator} requires exactly 2 operands")
left_sql, left_params = self._convert_operand_to_temp_sql(
operands[0], temp_table
)
right_sql, right_params = self._convert_operand_to_temp_sql(
operands[1], temp_table
)
op_mapping = {
"$eq": "=",
"$gt": ">",
"$gte": ">=",
"$lt": "<",
"$lte": "<=",
"$ne": "!=",
}
sql_operator = op_mapping.get(operator, operator)
return (
f"{left_sql} {sql_operator} {right_sql}",
left_params + right_params,
)
[docs]
def _convert_arithmetic_to_temp_sql(
self, operator: str, operands: list[Any], temp_table: str
) -> tuple[str, list[Any]]:
"""Convert arithmetic operators to SQL using temp table."""
if len(operands) < 2:
raise ValueError(f"{operator} requires at least 2 operands")
sql_parts = []
all_params = []
for operand in operands:
operand_sql, operand_params = self._convert_operand_to_temp_sql(
operand, temp_table
)
sql_parts.append(operand_sql)
all_params.extend(operand_params)
op_mapping = {
"$add": "+",
"$subtract": "-",
"$multiply": "*",
"$divide": "/",
"$mod": "%",
}
sql_operator = op_mapping.get(operator, operator)
sql = f"({f' {sql_operator} '.join(sql_parts)})"
return sql, all_params
[docs]
def _convert_cond_to_temp_sql(
self, operands: dict[str, Any], temp_table: str
) -> tuple[str, list[Any]]:
"""Convert $cond operator to SQL CASE statement using temp table."""
if not isinstance(operands, dict):
raise ValueError("$cond requires a dictionary")
if "if" not in operands or "then" not in operands:
raise ValueError("$cond requires 'if' and 'then' fields")
condition_sql, condition_params = self._convert_expr_to_temp_sql(
operands["if"], temp_table
)
then_sql, then_params = self._convert_operand_to_temp_sql(
operands["then"], temp_table
)
if "else" in operands:
else_sql, else_params = self._convert_operand_to_temp_sql(
operands["else"], temp_table
)
else:
else_sql, else_params = "NULL", []
sql = f"CASE WHEN {condition_sql} THEN {then_sql} ELSE {else_sql} END"
return sql, condition_params + then_params + else_params
[docs]
def _convert_cmp_to_temp_sql(
self, operands: list[Any], temp_table: str
) -> tuple[str, list[Any]]:
"""Convert $cmp operator to SQL CASE statement using temp table."""
if len(operands) != 2:
raise ValueError("$cmp requires exactly 2 operands")
left_sql, left_params = self._convert_operand_to_temp_sql(
operands[0], temp_table
)
right_sql, right_params = self._convert_operand_to_temp_sql(
operands[1], temp_table
)
sql = f"(CASE WHEN {left_sql} < {right_sql} THEN -1 WHEN {left_sql} > {right_sql} THEN 1 ELSE 0 END)"
return sql, left_params + right_params
[docs]
def _convert_math_to_temp_sql(
self, operator: str, operands: list[Any], temp_table: str
) -> tuple[str, list[Any]]:
"""Convert math operators to SQL using temp table."""
if len(operands) != 1:
raise ValueError(f"{operator} requires exactly 1 operand")
value_sql, value_params = self._convert_operand_to_temp_sql(
operands[0], temp_table
)
op_mapping = {
"$abs": "abs",
"$ceil": "ceil",
"$floor": "floor",
"$round": "round",
}
sql_func = op_mapping.get(operator)
if sql_func is None:
raise NotImplementedError(
f"Math operator {operator} not supported in Tier 2"
)
return f"{sql_func}({value_sql})", value_params
[docs]
def _convert_operand_to_temp_sql(
self, operand: Any, temp_table: str
) -> tuple[str, list[Any]]:
"""
Convert operand to SQL using temporary table columns.
Args:
operand: Operand to convert
temp_table: Temporary table name
Returns:
Tuple of (SQL expression, parameters)
"""
if isinstance(operand, str) and operand.startswith("$"):
# Field reference - use temp table column
field_path = operand[1:] # Remove $
col_name = self._sanitize_field_name(field_path)
return f"{temp_table}.{col_name}", []
elif isinstance(operand, dict):
# Nested expression
return self._convert_expr_to_temp_sql(operand, temp_table)
else:
# Literal value
return "?", [operand]
[docs]
def _extract_field_references_from_operand(self, operand: Any) -> list[str]:
"""Extract field references from an operand."""
fields = []
if isinstance(operand, str) and operand.startswith("$"):
fields.append(operand[1:])
elif isinstance(operand, list):
for item in operand:
fields.extend(self._extract_field_references_from_operand(item))
elif isinstance(operand, dict):
for op, val in operand.items():
fields.extend(self._extract_field_references_from_operand(val))
return fields
[docs]
def _sanitize_field_name(self, field_path: str) -> str:
"""
Sanitize a field path for use as a SQL column name.
Args:
field_path: Field path (e.g., "stats.wins")
Returns:
Sanitized column name (e.g., "stats_wins")
"""
# Replace dots with underscores
col_name = field_path.replace(".", "_")
# Replace brackets (for array indices)
col_name = col_name.replace("[", "_").replace("]", "")
# Ensure valid SQL identifier
if col_name and not col_name[0].isalpha():
col_name = "f_" + col_name
return col_name
[docs]
def cleanup_temp_tables(self) -> None:
    """
    Drop every temporary table this evaluator has recorded.

    Cleanup is best effort: a failure to drop one table is logged at debug
    level and does not stop the remaining drops. The record of created
    tables is emptied afterwards regardless of failures.
    """
    for table in self._temp_tables:
        drop_statement = f"DROP TABLE IF EXISTS {table}"
        try:
            self.db.execute(drop_statement)
        except Exception as e:
            # Deliberately swallowed: cleanup must never raise.
            logger.debug(
                f"Failed to drop temporary table {table} during cleanup: {e}"
            )
    self._temp_tables.clear()