Source code for neosqlite.collection.query_engine

from __future__ import annotations

import importlib.util
import logging
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any

logger = logging.getLogger(__name__)

from neosqlite.collection.jsonb_support import (
    _get_json_each_function,
    supports_jsonb,
    supports_jsonb_each,
)

from ...bulk_operations import BulkOperationExecutor
from ...exceptions import MalformedQueryException
from ...requests import DeleteOne, InsertOne, UpdateOne
from ...results import BulkWriteResult
from ..cursor import DESCENDING
from ..expr_evaluator import AggregationContext, ExprEvaluator, _is_expression
from ..query_helper import QueryHelper
from ..raw_batch_cursor import RawBatchCursor
from ..sql_tier_aggregator import SQLTierAggregator
from ..sql_translator_unified import SQLTranslator
from ..type_utils import validate_session
from .crud_operations import CRUDOperationsMixin
from .find_operations import FindOperationsMixin
from .query_methods import QueryMethodsMixin

if TYPE_CHECKING:
    from quez import CompressedQueue

    from ..client_session import ClientSession

TierChangeCallback = Callable[[str | None, str, list], None]

# Check if quez is available
_HAS_QUEZ = importlib.util.find_spec("quez") is not None


[docs] class QueryEngine(CRUDOperationsMixin, FindOperationsMixin, QueryMethodsMixin): """ A class that provides methods for querying and manipulating documents in a collection. The QueryEngine handles all database operations including inserting, updating, deleting, and finding documents. It also supports aggregation pipelines, bulk operations, and various utility methods for counting and retrieving distinct values. """
[docs] def __init__(self, collection): """ Initialize the QueryEngine with a collection. Args: collection: The collection instance this QueryEngine will operate on. """ self.collection = collection self.helpers = QueryHelper(collection) # Check if JSONB is supported for this connection self._jsonb_supported = supports_jsonb(collection.db) self._jsonb_each_supported = supports_jsonb_each(collection.db) json_each_function = _get_json_each_function( self._jsonb_supported, self._jsonb_each_supported ) self.sql_translator = SQLTranslator( collection.name, "data", "id", self._jsonb_supported, json_each_function, ) # Get translation cache size from connection (default: 100, 0 to disable) # collection._database is the NeoSQLite Connection, collection.db is sqlite3 neosqlite_conn = collection._database cache_size = getattr(neosqlite_conn, "_translation_cache_size", 100) # Initialize SQL tier aggregator for optimized aggregation pipelines self.sql_tier_aggregator = SQLTierAggregator( collection, expr_evaluator=ExprEvaluator( data_column="data", db_connection=collection.db ), translation_cache_size=cache_size, ) self._tier_callbacks: list = [] # type: ignore[annotation-unchecked] self._last_tier: str | None = None # type: ignore[annotation-unchecked]
[docs] def add_tier_change_callback(self, callback: "TierChangeCallback") -> None: """Add a callback to be notified when query tier changes. Callback receives: (previous_tier: str | None, new_tier: str, pipeline: list) where tier is one of: - "tier1" (SQL CTE - new aggregation optimizer) - "tier1_standard" (non-CTE SQL aggregation) - "tier2" (temp table for complex $expr) - "tier3" (Python fallback) - None (before any query) """ self._tier_callbacks.append(callback)
[docs] def remove_tier_change_callback( self, callback: "TierChangeCallback" ) -> bool: """Remove a tier change callback. Returns True if found.""" try: self._tier_callbacks.remove(callback) return True except ValueError: return False
[docs] def get_last_tier(self) -> str | None: """Get the last tier that was used for query execution.""" return self._last_tier
[docs] def clear_tier_callbacks(self) -> None: """Clear all tier change callbacks.""" self._tier_callbacks.clear()
[docs] def _notify_tier_change(self, new_tier: str, pipeline: list) -> None: """Notify all callbacks of a tier change.""" if self._last_tier != new_tier: for callback in self._tier_callbacks: try: callback(self._last_tier, new_tier, pipeline) except Exception as e: logger.debug(f"Query tier callback error: {e}") pass # Don't let callback errors affect query execution self._last_tier = new_tier
[docs] def cleanup(self) -> None: """Clean up resources used by the QueryEngine.""" if hasattr(self, "helpers"): self.helpers.cleanup()
[docs] def aggregate( self, pipeline: list[dict[str, Any]], batch_size: int = 101, session: ClientSession | None = None, ) -> list[dict[str, Any]]: """ Applies a list of aggregation pipeline stages to the collection. This method handles both simple and complex queries. For simpler queries, it leverages the database's native indexing capabilities to optimize performance. For more complex queries, it falls back to a Python-based processing mechanism. Args: pipeline (list[dict[str, Any]]): A list of aggregation pipeline stages to apply. batch_size (int): The batch size for fetching results from database. session (ClientSession, optional): A ClientSession for transactions. Returns: list[dict[str, Any]]: The list of documents after applying the aggregation pipeline. """ validate_session(session, self.collection._database) return self.aggregate_with_constraints( pipeline, batch_size=batch_size, session=session )
[docs] def aggregate_with_constraints( self, pipeline: list[dict[str, Any]], batch_size: int = 101, memory_constrained: bool = False, session: ClientSession | None = None, ) -> list[dict[str, Any]] | "CompressedQueue": """ Applies a list of aggregation pipeline stages with memory constraints. Args: pipeline (list[dict[str, Any]]): A list of aggregation pipeline stages to apply. batch_size (int): The batch size for processing large result sets. memory_constrained (bool): Whether to use memory-constrained processing. session (ClientSession, optional): A ClientSession for transactions. Returns: list[dict[str, Any]] | CompressedQueue: The results as either a list or compressed queue. """ validate_session(session, self.collection._database) # If memory_constrained is True and quez is available, use quez for processing if memory_constrained and _HAS_QUEZ: # Use quez for memory-constrained processing return self._aggregate_with_quez(pipeline, batch_size) # Try SQL Tier 1 optimization first (new CTE-based approach) try: if self.sql_tier_aggregator.can_optimize_pipeline(pipeline): sql, params = self.sql_tier_aggregator.build_pipeline_sql( pipeline ) if sql is not None: db_cursor = self.collection.db.execute(sql, params) results = [] # Use fetchmany to avoid loading all results into memory at once while True: rows = db_cursor.fetchmany(batch_size) if not rows: break for row in rows: # Load document from data column # Row structure: # If root_data preserved: (id, _id, root_data, data) - len 4 # Normal: (id, _id, data) - len 3 # GROUP BY results might have id=NULL and data as a custom object doc_data = row[-1] doc_id = row[0] stored_id = row[1] if doc_data is None: continue if doc_data.startswith("{") and doc_data.endswith( "}" ): # It's a JSON object (standard or GROUP BY result) from neosqlite.collection.json_helpers import ( neosqlite_json_loads, ) document = neosqlite_json_loads(doc_data) if ( "_id" not in document and stored_id is not None ): document["_id"] = ( self.collection._parse_stored_id( stored_id ) ) results.append(document) else: # Normal loading via _load results.append( self.collection._load( doc_id, doc_data, stored_id=stored_id ) ) self._notify_tier_change("tier1", pipeline) return results except NotImplementedError as e: # Operator not yet translated to SQL — log at WARNING for visibility # during development/comparison runs, then fall back to next tier logger.warning("SQL tier 1 aggregation fallback: %s", e) except Exception as e: # If SQL tier optimization fails, continue to next approach logger.debug("SQL tier 1 aggregation optimization failed: %s", e) # Try existing SQL optimization (legacy CTE-based approach) try: query_result = self.helpers._build_aggregation_query(pipeline) if query_result is not None: cmd, params, output_fields = query_result db_cursor = self.collection.db.execute(cmd, params) if output_fields: # Handle results from a GROUP BY query from neosqlite.collection.json_helpers import ( neosqlite_json_loads, ) results = [] # Use fetchmany to avoid loading all results into memory at once while True: rows = db_cursor.fetchmany(batch_size) if not rows: break for row in rows: processed_row = [] for i, value in enumerate(row): # If this field contains a JSON array string, parse it # This handles $push and $addToSet results if ( output_fields[i] != "_id" and isinstance(value, str) and value.startswith("[") and value.endswith("]") ): try: processed_row.append( neosqlite_json_loads(value) ) except Exception as e: logger.debug( f"Failed to parse JSON in aggregation result: {e}" ) processed_row.append(value) else: processed_row.append(value) results.append( dict(zip(output_fields, processed_row)) ) self._notify_tier_change("tier1_standard", pipeline) return results else: # Handle results from a regular find query # Use fetchmany to avoid loading all results into memory at once results = [] while True: rows = db_cursor.fetchmany(batch_size) if not rows: break for row in rows: # Row structure: (id, data) or (id, root_data, data) if len(row) == 3: # root_data is present, data is in row[2] results.append( self.collection._load(row[0], row[2]) ) else: # No root_data, data is in row[1] results.append( self.collection._load(row[0], row[1]) ) self._notify_tier_change("tier1_standard", pipeline) return results except Exception as e: # If SQL optimization fails, continue to next approach logger.debug( "SQL tier 1 standard aggregation optimization failed: %s", e ) # Try the temporary table approach for complex pipelines that the # current SQL optimization can't handle efficiently try: from ..temporary_table_aggregation import ( execute_2nd_tier_aggregation, ) # Use the temporary table aggregation which provides enhanced # SQL processing for complex pipelines result = execute_2nd_tier_aggregation( self, pipeline, batch_size=batch_size ) if result is not None: self._notify_tier_change("tier2", pipeline) return result except NotImplementedError as e: # Operator not yet translated to SQL — log at WARNING for visibility # during development/comparison runs, then fall back to Python tier logger.warning("SQL tier 2 aggregation fallback: %s", e) except Exception as e: # If temporary table approach fails for other reasons, # continue to fallback below logger.debug("SQL tier 2 aggregation optimization failed: %s", e) # Optimize $count in SQLite when possible if ( pipeline and isinstance(pipeline[-1], dict) and "$count" in pipeline[-1] ): count_field = pipeline[-1]["$count"] if not pipeline[:-1]: # No previous stages, count all documents count = self.estimated_document_count() return [{count_field: count}] elif len(pipeline) == 2 and "$match" in pipeline[0]: # Only $match before $count, use count_documents filter = pipeline[0]["$match"] count = self.count_documents(filter) return [{count_field: count}] # For more complex pipelines, fall back to Python # Fallback to old method for complex queries (Python implementation) docs: list[dict[str, Any]] = list(self.find(session=session)) # Store original documents for $$ROOT variable support # Each document is wrapped with metadata for variable scoping docs_with_context = [ {"__doc__": doc, "__root__": deepcopy(doc)} for doc in docs ] for stage in pipeline: if not stage: raise MalformedQueryException("Empty pipeline stage") stage_name = next(iter(stage.keys())).strip() match stage_name: case "$match": query = stage["$match"] docs_with_context = [ dc for dc in docs_with_context if self.helpers._apply_query(query, dc["__doc__"]) ] case "$sort": sort_spec = stage["$sort"] for key, direction in reversed(list(sort_spec.items())): def make_sort_key(key, dir): """ Create a sort key function for the given key and direction. """ def sort_key(dc): """ Extract sort key from document context. """ val = self.collection._get_val( dc["__doc__"], key ) # Handle None values - sort them last for ascending, first for descending if val is None: return (0 if dir == DESCENDING else 1, None) return (0, val) return sort_key sort_key_func = make_sort_key(key, direction) docs_with_context.sort( key=sort_key_func, reverse=direction == DESCENDING, ) case "$skip": count = stage["$skip"] docs_with_context = docs_with_context[count:] case "$limit": count = stage["$limit"] docs_with_context = docs_with_context[:count] case "$project": projection = stage["$project"] docs_with_context = [ { "__doc__": self.helpers._apply_projection( projection, dc["__doc__"] ), "__root__": dc["__root__"], } for dc in docs_with_context ] case "$replaceRoot" | "$replaceWith": # Handle $replaceRoot and $replaceWith if stage_name == "$replaceRoot": new_root_expr = stage["$replaceRoot"].get("newRoot") else: # $replaceWith new_root_expr = stage["$replaceWith"] new_docs_with_context = [] for dc in docs_with_context: # Evaluate the new root expression evaluator = ExprEvaluator() # Use _evaluate_operand_python to get the actual value (not forced to bool) try: new_doc = evaluator._evaluate_operand_python( new_root_expr, dc["__doc__"] ) except Exception as e: # Fallback if evaluation fails logger.debug(f"$replaceRoot evaluation failed: {e}") new_doc = dc["__doc__"] # MongoDB requirement: result MUST be an object if not isinstance(new_doc, dict): raise MalformedQueryException( f"$replaceRoot requires an object, got {type(new_doc).__name__}" ) # Ensure _id is preserved if it was present in the original if ( "_id" not in new_doc and "__doc__" in dc and "_id" in dc["__doc__"] ): new_doc["_id"] = dc["__doc__"]["_id"] new_docs_with_context.append( {"__doc__": new_doc, "__root__": dc["__root__"]} ) docs_with_context = new_docs_with_context case "$unset": # Handle $unset aggregation stage unset_spec = stage["$unset"] if isinstance(unset_spec, str): fields_to_unset = [unset_spec] elif isinstance(unset_spec, list): fields_to_unset = unset_spec else: raise MalformedQueryException( "$unset requires a string or a list of strings" ) for dc in docs_with_context: doc = dc["__doc__"] for field in fields_to_unset: # Use collection._get_val logic to navigate and pop if "." in field: parts = field.split(".") target: dict[str, Any] | None = doc for part in parts[:-1]: if ( isinstance(target, dict) and part in target ): target = target[part] else: target = None break if isinstance(target, dict): target.pop(parts[-1], None) else: doc.pop(field, None) case "$group": group_spec = stage["$group"] # For $group, we don't preserve __root__ since grouping creates new documents grouped_docs = self.helpers._process_group_stage( group_spec, [dc["__doc__"] for dc in docs_with_context] ) docs_with_context = [ {"__doc__": doc, "__root__": doc} for doc in grouped_docs ] case "$unwind": # Handle both string and object forms of $unwind unwind_spec = stage["$unwind"] if isinstance(unwind_spec, str): # Legacy string form field_path = unwind_spec.lstrip("$") include_array_index = None preserve_null_and_empty = False elif isinstance(unwind_spec, dict): # New object form with advanced options field_path = unwind_spec["path"].lstrip("$") include_array_index = unwind_spec.get( "includeArrayIndex" ) preserve_null_and_empty = unwind_spec.get( "preserveNullAndEmptyArrays", False ) else: raise MalformedQueryException( f"Invalid $unwind specification: {unwind_spec}" ) unwound_docs_with_context = [] for dc in docs_with_context: doc = dc["__doc__"] root = dc["__root__"] array_to_unwind = self.collection._get_val( doc, field_path ) # For nested fields, check if parent exists # If parent is None or missing and we're trying to unwind a nested field, # don't process this document field_parts = field_path.split(".") process_document = True if len(field_parts) > 1: # This is a nested field parent_path = ".".join(field_parts[:-1]) parent_value = self.collection._get_val( doc, parent_path ) if parent_value is None: # Parent is None or missing, don't process this document process_document = False if not process_document: continue if isinstance(array_to_unwind, list): # Handle array values if array_to_unwind: # Non-empty array - unwind normally for idx, item in enumerate(array_to_unwind): new_doc = deepcopy(doc) self.collection._set_val( new_doc, field_path, item ) # Add array index if requested if include_array_index: new_doc[include_array_index] = idx # Preserve __root__ for $$ROOT variable unwound_docs_with_context.append( {"__doc__": new_doc, "__root__": root} ) elif preserve_null_and_empty: # Empty array but preserve is requested new_doc = deepcopy(doc) # Remove the field entirely (MongoDB behavior) # For nested fields, we need to navigate to parent if "." in field_path: # Handle nested field removal parts = field_path.split(".") current = new_doc for part in parts[:-1]: if part in current: current = current[part] else: break else: # Remove the final field if parts[-1] in current: del current[parts[-1]] else: # Simple field removal if field_path in new_doc: del new_doc[field_path] # Add array index if requested if include_array_index: new_doc[include_array_index] = None unwound_docs_with_context.append( {"__doc__": new_doc, "__root__": root} ) # If empty array and preserve is False, don't add any documents elif ( not isinstance(array_to_unwind, list) and field_path in doc and preserve_null_and_empty ): # Non-array value (None, string, number, etc.) that exists in the document and preserve is requested new_doc = deepcopy(doc) # Keep the value as-is # Add array index if requested if include_array_index: new_doc[include_array_index] = None unwound_docs_with_context.append( {"__doc__": new_doc, "__root__": root} ) # Missing fields (field_path not in doc) are never preserved # Default case: non-array values are ignored unless they exist and preserveNullAndEmptyArrays is True docs_with_context = unwound_docs_with_context case "$lookup": # Python fallback implementation for $lookup lookup_spec = stage["$lookup"] from_collection_name = lookup_spec["from"] local_field = lookup_spec["localField"] foreign_field = lookup_spec["foreignField"] as_field = lookup_spec["as"] # Get the from collection from the database from_collection = self.collection._database[ from_collection_name ] # Process each document for dc in docs_with_context: doc = dc["__doc__"] # Get the local field value local_value = self.collection._get_val(doc, local_field) # Find matching documents in the from collection matching_docs = [] for match_doc in from_collection.find(session=session): foreign_value = from_collection._get_val( match_doc, foreign_field ) if local_value == foreign_value: # Add the matching document (without _id) match_doc_copy = match_doc.copy() match_doc_copy.pop("_id", None) matching_docs.append(match_doc_copy) # Add the matching documents as an array field doc[as_field] = matching_docs case "$addFields": add_fields_spec = stage["$addFields"] # Create expression evaluator for this stage evaluator_add = ExprEvaluator( data_column="data", db_connection=self.collection.db ) for dc in docs_with_context: doc = dc["__doc__"] root = dc["__root__"] # Create context for this document ctx = AggregationContext() ctx.bind_document( root ) # Bind original document as $$ROOT ctx.update_current(doc) # Set current document state for new_field, expr in add_fields_spec.items(): if _is_expression(expr): # Full expression - evaluate in Python with current context value = evaluator_add._evaluate_expr_python( expr, doc ) self.collection._set_val(doc, new_field, value) elif isinstance(expr, str) and expr.startswith("$"): # Field reference if expr.startswith("$$"): # Aggregation variable if expr == "$$ROOT": # $$ROOT always refers to original document value = root.copy() elif expr == "$$CURRENT": # $$CURRENT refers to document as it evolves value = doc.copy() else: value = None self.collection._set_val( doc, new_field, value ) else: # Regular field reference - may reference newly added field source_field_name = expr[1:] source_value = self.collection._get_val( doc, source_field_name ) self.collection._set_val( doc, new_field, source_value ) else: # Literal value self.collection._set_val(doc, new_field, expr) # Update $$CURRENT after all fields are added ctx.update_current(doc) case "$setWindowFields": from ..query_helper.window_operators import ( process_set_window_fields, ) window_spec = stage["$setWindowFields"] evaluator_window = ExprEvaluator( data_column="data", db_connection=self.collection.db ) docs_with_context = process_set_window_fields( docs_with_context, window_spec, self.collection, evaluator_window, ) case "$graphLookup": from ..query_helper.graph_lookup import process_graph_lookup graph_spec = stage["$graphLookup"] evaluator_graph = ExprEvaluator( data_column="data", db_connection=self.collection.db ) docs_with_context = process_graph_lookup( docs_with_context, graph_spec, self.collection, evaluator_graph, ) case "$fill": from ..query_helper.fill_stage import process_fill fill_spec = stage["$fill"] evaluator_fill = ExprEvaluator( data_column="data", db_connection=self.collection.db ) docs_with_context = process_fill( docs_with_context, fill_spec, self.collection, evaluator_fill, ) case "$sample": sample_spec = stage["$sample"] sample_size = sample_spec["size"] if sample_size < 0: raise MalformedQueryException( "$sample size must be non-negative" ) import random docs_with_context = random.sample( docs_with_context, min(sample_size, len(docs_with_context)), ) case "$facet": facet_spec = stage["$facet"] facet_tables: dict[str, str] = {} # Get input documents sub_docs = [dc["__doc__"] for dc in docs_with_context] # Run each sub-pipeline, streaming results to temp tables for facet_name, sub_pipeline in facet_spec.items(): result_table = self.helpers._run_subpipeline( sub_pipeline, sub_docs ) facet_tables[facet_name] = result_table # Load all results from temp tables and combine facet_results: dict[str, Any] = {} for facet_name, table_name in facet_tables.items(): cursor = self.collection.db.execute( f"SELECT data FROM {table_name}" ) from neosqlite.collection.json_helpers import ( neosqlite_json_loads, ) facet_results[facet_name] = [ neosqlite_json_loads(row[0]) for row in cursor.fetchall() ] # Clean up temp table after loading try: self.collection.db.execute( f"DROP TABLE IF EXISTS {table_name}" ) except Exception as e: logger.debug( f"Failed to drop facet temporary table '{table_name}': {e}" ) pass docs_with_context = [ {"__doc__": facet_results, "__root__": facet_results} ] case "$count": count_field = stage["$count"] docs_with_context = [ { "__doc__": {count_field: len(docs_with_context)}, "__root__": {count_field: len(docs_with_context)}, } ] case "$bucket": bucket_spec = stage["$bucket"] group_by = bucket_spec.get("groupBy", "").lstrip("$") boundaries = bucket_spec.get("boundaries", []) default_label = bucket_spec.get("default", "Other") output_spec = bucket_spec.get( "output", {"count": {"$sum": 1}} ) if not group_by or not boundaries: docs_with_context = [] break sorted_boundaries = sorted(boundaries) # Group documents by bucket # MongoDB uses the lower boundary value as _id, not a string label buckets: dict[Any, list[dict[str, Any]]] = {} for dc in docs_with_context: doc = dc["__doc__"] val = self.collection._get_val(doc, group_by) # Skip documents with None values if val is None: continue # Determine bucket - use lower boundary as key (MongoDB behavior) bucket_key: Any = default_label try: for i in range(len(sorted_boundaries) - 1): if ( sorted_boundaries[i] <= val < sorted_boundaries[i + 1] ): bucket_key = sorted_boundaries[ i ] # Use lower boundary as _id break else: # Last bucket (inclusive) - use last boundary if val >= sorted_boundaries[-1]: bucket_key = sorted_boundaries[-1] except TypeError: # Comparison failed (e.g., mixed types), use default bucket_key = default_label if bucket_key not in buckets: buckets[bucket_key] = [] buckets[bucket_key].append(doc) # Build output documents new_docs = [] for bucket_id, bucket_docs in sorted(buckets.items()): output_doc: dict[str, Any] = {"_id": bucket_id} for field_name, accumulator in output_spec.items(): if "$sum" in accumulator: sum_field = accumulator["$sum"] if sum_field == 1: output_doc[field_name] = len(bucket_docs) else: sum_field = sum_field.lstrip("$") output_doc[field_name] = sum( self.collection._get_val(d, sum_field) or 0 for d in bucket_docs ) elif "$avg" in accumulator: avg_field = accumulator["$avg"].lstrip("$") values = [ self.collection._get_val(d, avg_field) for d in bucket_docs ] output_doc[field_name] = ( sum(values) / len(values) if values else 0 ) elif "$count" in accumulator: output_doc[field_name] = len(bucket_docs) elif "$min" in accumulator: min_field = accumulator["$min"].lstrip("$") values = [ self.collection._get_val(d, min_field) for d in bucket_docs ] output_doc[field_name] = ( min(values) if values else None ) elif "$max" in accumulator: max_field = accumulator["$max"].lstrip("$") values = [ self.collection._get_val(d, max_field) for d in bucket_docs ] output_doc[field_name] = ( max(values) if values else None ) else: output_doc[field_name] = len(bucket_docs) new_docs.append(output_doc) docs_with_context = [ {"__doc__": doc, "__root__": doc} for doc in new_docs ] case "$bucketAuto": bucket_auto_spec = stage["$bucketAuto"] group_by = bucket_auto_spec.get("groupBy", "").lstrip("$") num_buckets = bucket_auto_spec.get("buckets", 10) output_spec = bucket_auto_spec.get( "output", {"count": {"$sum": 1}} ) if not group_by or num_buckets <= 0: docs_with_context = [] break # Sort documents by groupBy field def get_group_val(dc): """ Extract the value to group by for a given document. Args: dc: Document with context from the aggregation stage. Returns: The value to group by or 0 if not found. """ return ( self.collection._get_val(dc["__doc__"], group_by) or 0 ) sorted_docs = sorted( docs_with_context, key=get_group_val, ) # Distribute into buckets bucket_size = max(1, len(sorted_docs) // num_buckets) bucket_list: list[list[dict[str, Any]]] = [] bucket_bounds: list[tuple[Any, Any]] = [] current_bucket = [] current_min = None for dc in sorted_docs: current_bucket.append(dc["__doc__"]) if current_min is None: current_min = get_group_val(dc) if ( len(current_bucket) >= bucket_size and len(bucket_list) < num_buckets - 1 ): bucket_list.append(current_bucket) bucket_bounds.append( (current_min, get_group_val(dc)) ) current_bucket = [] current_min = None if current_bucket: bucket_list.append(current_bucket) bucket_bounds.append( ( current_min, ( get_group_val(sorted_docs[-1]) if sorted_docs else current_min ), ) ) # Build output documents new_docs = [] for i, bucket_docs in enumerate(bucket_list): output_doc2: dict[str, Any] = { "_id": { "min": bucket_bounds[i][0], "max": bucket_bounds[i][1], } } for field_name, accumulator in output_spec.items(): if "$sum" in accumulator: sum_field = accumulator["$sum"] if sum_field == 1: output_doc2[field_name] = len(bucket_docs) else: sum_field = sum_field.lstrip("$") output_doc2[field_name] = sum( self.collection._get_val(d, sum_field) or 0 for d in bucket_docs ) elif "$avg" in accumulator: avg_field = accumulator["$avg"].lstrip("$") values = [ self.collection._get_val(d, avg_field) for d in bucket_docs ] output_doc2[field_name] = ( sum(values) / len(values) if values else 0 ) elif "$count" in accumulator: output_doc2[field_name] = len(bucket_docs) else: output_doc2[field_name] = len(bucket_docs) new_docs.append(output_doc2) docs_with_context = [ {"__doc__": doc, "__root__": doc} for doc in new_docs ] case "$unionWith": union_spec = stage["$unionWith"] coll_name = union_spec.get("coll") pipeline = union_spec.get("pipeline", []) if not coll_name: break # Get documents from other collection other_coll = self.collection._database[coll_name] other_docs = list(other_coll.find()) # Apply pipeline if specified if pipeline: other_docs = list(other_coll.aggregate(pipeline)) # Combine documents current_docs = [dc["__doc__"] for dc in docs_with_context] combined_docs = current_docs + other_docs docs_with_context = [ {"__doc__": doc, "__root__": doc} for doc in combined_docs ] case "$merge": # $merge writes results to a collection merge_spec = stage["$merge"] # Handle different merge spec formats if isinstance(merge_spec, str): # Simple format: just collection name target_coll_name = merge_spec merge_options = {} elif isinstance(merge_spec, dict): # Full format with options into = merge_spec.get("into", "") if isinstance(into, dict): target_coll_name = ( into.get("db", "") + "." + into.get("coll", "") ) else: target_coll_name = into merge_options = { "on": merge_spec.get("on", "_id"), "whenMatched": merge_spec.get( "whenMatched", "replace" ), "whenNotMatched": merge_spec.get( "whenNotMatched", "insert" ), } else: target_coll_name = "merged" merge_options = {} # Get or create target collection if "." in target_coll_name: db_name, coll_name = target_coll_name.split(".", 1) target_coll = self.collection._database.client[db_name][ coll_name ] else: target_coll = self.collection._database[ target_coll_name ] # Process each document for dc in docs_with_context: doc = dc["__doc__"] # Get the "on" field value for matching on_field = merge_options.get("on", "_id") on_value = self.collection._get_val(doc, on_field) # Try to find existing document existing = None if on_value is not None: existing = target_coll.find_one( {on_field: on_value} ) when_matched = merge_options.get( "whenMatched", "replace" ) when_not_matched = merge_options.get( "whenNotMatched", "insert" ) if existing: match when_matched: case "replace": existing_id = existing.get("_id") new_doc = { k: v for k, v in doc.items() if k != "_id" } if existing_id is not None: new_doc["_id"] = existing_id target_coll.update_one( {on_field: on_value}, {"$set": new_doc}, ) fields_to_remove = [ k for k in existing if k not in new_doc and k != on_field ] if fields_to_remove: target_coll.update_one( {on_field: on_value}, { "$unset": { f: "" for f in fields_to_remove } }, ) case "merge": update_doc = { k: v for k, v in doc.items() if k != "_id" } target_coll.update_one( {on_field: on_value}, {"$set": update_doc}, ) case "keepExisting": # Keep existing, don't update pass case "fail": raise Exception( f"$merge failed: document with {on_field}={on_value} already exists" ) # Note: "pipeline" mode not implemented else: # Document doesn't exist - handle based on whenNotMatched match when_not_matched: case "insert": target_coll.insert_one(doc) case "fail": raise Exception( f"$merge failed: no document found with {on_field}={on_value}" ) # After merge, return empty or pass through based on requirements # MongoDB returns the merged documents for further pipeline processing pass case "$redact": # $redact filters document content based on conditions redact_spec = stage["$redact"] # Create evaluator for condition evaluation evaluator_redact = ExprEvaluator( data_column="data", db_connection=self.collection.db ) def apply_redact(doc, spec): """Recursively apply redaction to a document.""" if not isinstance(doc, dict): return doc result = {} for key, value in doc.items(): # Evaluate the redact condition for this field redact_action = evaluate_redact_condition( spec, doc, key, value ) if redact_action == "$$KEEP": # Keep the field as-is result[key] = value elif redact_action == "$$DESCEND": # Keep and process sub-fields if isinstance(value, dict): result[key] = apply_redact(value, spec) elif isinstance(value, list): result[key] = [ ( apply_redact(item, spec) if isinstance(item, dict) else item ) for item in value ] else: result[key] = value elif redact_action == "$$PRUNE": # Remove this field (don't add to result) pass return result def evaluate_redact_condition(spec, doc, key, value): """Evaluate the redact condition and return KEEP/DESCEND/PRUNE.""" if "$cond" in spec: cond = spec["$cond"] if_expr = cond.get("if", {}) then_expr = cond.get("then", "$$DESCEND") else_expr = cond.get("else", "$$DESCEND") # Evaluate the condition try: cond_result = ( evaluator_redact._evaluate_expr_python( if_expr, doc ) ) if cond_result: return then_expr else: return else_expr except Exception as e: logger.debug( f"Redaction evaluation failed: {e}" ) return "$$DESCEND" # If spec is a direct expression, evaluate it if ( spec.startswith("$") if isinstance(spec, str) else False ): try: result = evaluator_redact._evaluate_expr_python( spec, doc ) if result in ("$$KEEP", "$$DESCEND", "$$PRUNE"): return result except Exception as e: logger.debug( f"Redaction expression evaluation failed: {e}" ) pass return "$$DESCEND" # Apply redaction to each document new_docs_with_context = [] for dc in docs_with_context: doc = dc["__doc__"] root = dc["__root__"] # Apply redaction redacted_doc = apply_redact(doc, redact_spec) # Check if document should be kept (not fully pruned) if redacted_doc: # Non-empty document new_docs_with_context.append( {"__doc__": redacted_doc, "__root__": root} ) docs_with_context = new_docs_with_context case "$densify": # $densify fills gaps in sequential data densify_spec = stage["$densify"] field = densify_spec.get("field") range_spec = densify_spec.get("range", {}) partition_by = densify_spec.get("partitionByFields", []) output_spec = densify_spec.get("output", {}) if not field: # No field specified, pass through pass else: # Get bounds bounds = range_spec.get("bounds") step = range_spec.get("step", 1) unit = range_spec.get("unit", None) # For dates # Determine if we're working with dates or numbers is_date = unit is not None # Group documents by partition fields partitions: dict[tuple, list[dict[str, Any]]] = {} for dc in docs_with_context: doc = dc["__doc__"] # Get partition key if partition_by: partition_key = tuple( self.collection._get_val(doc, pf) for pf in partition_by ) else: partition_key = () if partition_key not in partitions: partitions[partition_key] = [] field_val = self.collection._get_val(doc, field) partitions[partition_key].append( {"doc": doc, "field_val": field_val, "dc": dc} ) # Generate densified output for each partition new_docs_with_context = [] for partition_key, items in partitions.items(): # Get all field values in this partition field_values = [ item["field_val"] for item in items if item["field_val"] is not None ] if not field_values: continue # Determine range if bounds == "full": min_val = min(field_values) max_val = max(field_values) elif isinstance(bounds, list) and len(bounds) >= 2: min_val, max_val = bounds[0], bounds[1] else: min_val = min(field_values) max_val = max(field_values) # Generate all values in range existing_values = set(field_values) all_values = [] if is_date: # Handle date ranges from datetime import timedelta current = min_val while current <= max_val: all_values.append(current) match unit: case "year": current = current.replace( year=current.year + step ) case "month": new_month = current.month + step new_year = ( current.year + (new_month - 1) // 12 ) new_month = ( (new_month - 1) % 12 ) + 1 try: current = current.replace( year=new_year, month=new_month, ) except ValueError: # Handle month-end edge break case "day": current = current + timedelta( days=step ) case "hour": current = current + timedelta( hours=step ) case "minute": current = current + timedelta( minutes=step ) case "second": current = current + timedelta( seconds=step ) case _: break else: # Handle numeric ranges current = min_val while current <= max_val: all_values.append(current) current = current + step # Create documents for all values for val in all_values: if val in existing_values: # Use existing document for item in items: if item["field_val"] == val: new_docs_with_context.append( item["dc"] ) break else: # Create new document with filled value for item in items: base_doc = deepcopy(item["doc"]) base_doc[field] = val # Apply output spec for additional fields for ( out_field, out_expr, ) in output_spec.items(): if out_field != field: base_doc[out_field] = ( out_expr # Could evaluate expression ) new_docs_with_context.append( { "__doc__": base_doc, "__root__": base_doc, } ) break # Use first item as template docs_with_context = new_docs_with_context case "$collStats": from ...sql_utils import quote_table_name coll_stats_spec = ( stage.get("$collStats") or stage.get(" $collStats") or {} ) table_name = self.collection.name quoted_table = quote_table_name(table_name) db = self.collection.db count_cursor = db.execute( f"SELECT COUNT(*) FROM {quoted_table}" ) count = count_cursor.fetchone()[0] or 0 size = 0 try: size_cursor = db.execute( f"SELECT SUM(LENGTH(data)) FROM {quoted_table}" ) size = size_cursor.fetchone()[0] or 0 except Exception as e: logger.debug( f"Failed to calculate collection size for stats: {e}" ) pass avg_obj_size = size / count if count > 0 else 0 storage_size = 0 total_index_size = 0 index_sizes: dict[str, int] = {} try: db.execute( "CREATE VIRTUAL TABLE IF NOT EXISTS temp.dbstat USING dbstat(main)" ) storage_cursor = db.execute( "SELECT SUM(pgsize) FROM dbstat WHERE name = ?", (table_name,), ) storage_size = storage_cursor.fetchone()[0] or 0 index_cursor = db.execute( "SELECT name, SUM(pgsize) as size FROM dbstat " "WHERE tbl_name = ? AND type = 'index' GROUP BY name", (table_name,), ) for row in index_cursor.fetchall(): idx_name, idx_size = row if idx_name and idx_size: index_sizes[idx_name] = idx_size total_index_size += idx_size except Exception as e: logger.debug( f"Failed to calculate storage/index sizes for stats: {e}" ) pass db_name = ( self.collection._database.name if self.collection._database else "unknown" ) stats_result: dict[str, Any] = { "ns": f"{db_name}.{table_name}", "count": count, "size": size, "avgObjSize": avg_obj_size, "storageSize": storage_size, "totalIndexSize": total_index_size, "indexSizes": index_sizes, } if coll_stats_spec and "count" in coll_stats_spec: stats_result = {"count": count} elif coll_stats_spec and "storageStats" in coll_stats_spec: stats_result = { "ns": f"{db_name}.{table_name}", "storageStats": { "count": count, "size": size, "avgObjSize": avg_obj_size, "storageSize": storage_size, "totalIndexSize": total_index_size, "indexSizes": index_sizes, }, } docs_with_context = [ {"__doc__": stats_result, "__root__": stats_result} ] case _: raise MalformedQueryException( f"Aggregation stage '{stage_name}' not supported" ) self._notify_tier_change("tier3", pipeline) return [dc["__doc__"] for dc in docs_with_context]
[docs] def explain_aggregation( self, pipeline: list[dict[str, Any]], session: ClientSession | None = None, ) -> dict[str, Any]: """ Explain the execution plan for an aggregation pipeline. Args: pipeline (list[dict[str, Any]]): The aggregation pipeline to explain. session (ClientSession, optional): A ClientSession for transactions. Returns: dict[str, Any]: The execution plan explanation. """ # 1. Try SQL Tier 1 optimization if self.sql_tier_aggregator.can_optimize_pipeline(pipeline): sql, params = self.sql_tier_aggregator.build_pipeline_sql(pipeline) if sql is not None: # Use EXPLAIN QUERY PLAN to get SQLite's plan explain_sql = f"EXPLAIN QUERY PLAN {sql}" db_cursor = self.collection.db.execute(explain_sql, params) plan = db_cursor.fetchall() return { "tier": 1, "type": "SQL Tier 1 (CTE-based)", "sql": sql, "params": params, "sqlite_plan": plan, } # 2. Try legacy SQL optimization query_result = self.helpers._build_aggregation_query(pipeline) if query_result is not None: cmd, params, _ = query_result explain_sql = f"EXPLAIN QUERY PLAN {cmd}" db_cursor = self.collection.db.execute(explain_sql, params) plan = db_cursor.fetchall() return { "tier": 1, "type": "SQL Tier 1.5 (Non-CTE-based)", "sql": cmd, "params": params, "sqlite_plan": plan, } # 3. Check if Tier 2 (Temp Table) can handle it from ..temporary_table_aggregation import ( can_process_with_temporary_tables, ) if can_process_with_temporary_tables(pipeline): return { "tier": 2, "type": "Temporary Table Aggregation", "pipeline": pipeline, } # 4. Fallback to Python return { "tier": 3, "type": "Python Fallback", "pipeline": pipeline, }
[docs] def aggregate_raw_batches( self, pipeline: list[dict[str, Any]], batch_size: int = 100, session: ClientSession | None = None, ) -> RawBatchCursor: """ Perform aggregation and retrieve batches of raw JSON. Similar to the :meth:`aggregate` method but returns a :class:`~neosqlite.raw_batch_cursor.RawBatchCursor`. This method returns raw JSON batches which can be more efficient for certain use cases where you want to process data in batches rather than individual documents. Args: pipeline (list[dict[str, Any]]): A list of aggregation pipeline stages to apply. batch_size (int): The number of documents to include in each batch. session (ClientSession, optional): A ClientSession for transactions. Returns: RawBatchCursor instance. """ validate_session(session, self.collection._database) return RawBatchCursor( self.collection, None, None, None, batch_size, pipeline=pipeline, session=session, )
# --- Bulk Write methods ---
[docs] def bulk_write( self, requests: list[Any], ordered: bool = True, session: ClientSession | None = None, ) -> BulkWriteResult: """ Execute bulk write operations on the collection. Args: requests: List of write operations to execute. ordered: If true, operations will be performed in order and will raise an exception if a single operation fails. session (ClientSession, optional): A ClientSession for transactions. Returns: BulkWriteResult: A result object containing the number of matched, modified, and inserted documents. """ validate_session(session, self.collection._database) inserted_count = 0 matched_count = 0 modified_count = 0 deleted_count = 0 upserted_count = 0 released = False self.collection.db.execute("SAVEPOINT bulk_write") try: for req in requests: match req: case InsertOne(document=doc): self.insert_one(doc, session=session) inserted_count += 1 case UpdateOne(filter=f, update=u, upsert=up): update_res = self.update_one(f, u, up, session=session) matched_count += update_res.matched_count modified_count += update_res.modified_count if update_res.upserted_id: upserted_count += 1 case DeleteOne(filter=f): delete_res = self.delete_one(f, session=session) deleted_count += delete_res.deleted_count self.collection.db.execute("RELEASE SAVEPOINT bulk_write") released = True except Exception as e: logger.debug(f"Error in bulk_write: {e}") self.collection.db.execute("ROLLBACK TO SAVEPOINT bulk_write") raise e finally: if not released: try: self.collection.db.execute("RELEASE SAVEPOINT bulk_write") except Exception as e: logger.debug(f"Failed to release bulk_write savepoint: {e}") pass return BulkWriteResult( inserted_count=inserted_count, matched_count=matched_count, modified_count=modified_count, deleted_count=deleted_count, upserted_count=upserted_count, )
[docs] def _aggregate_with_quez( self, pipeline: list[dict[str, Any]], batch_size: int = 101 ) -> CompressedQueue: """ Process aggregation pipeline with quez compressed queue for memory efficiency. Args: pipeline (list[dict[str, Any]]): A list of aggregation pipeline stages to apply. batch_size (int): The batch size for quez queue processing. Returns: CompressedQueue: A compressed queue containing the results. """ try: if _HAS_QUEZ: from quez import CompressedQueue # Create a compressed queue for results with a reasonable size # Use unbounded queue to avoid blocking during population result_queue = CompressedQueue() # Get results from normal aggregation results = self.aggregate(pipeline) # Add all results to the compressed queue for result in results: result_queue.put(result) return result_queue except ImportError: # If quez is not available, fall back to normal processing # This should never happen since we check for quez availability before calling this method raise RuntimeError("Quez is not available but was expected to be")
[docs] def initialize_ordered_bulk_op(self) -> BulkOperationExecutor: """Initialize an ordered bulk operation. Returns: BulkOperationExecutor: An executor for ordered bulk operations. """ return BulkOperationExecutor(self.collection, ordered=True)
[docs] def initialize_unordered_bulk_op(self) -> BulkOperationExecutor: """Initialize an unordered bulk operation. Returns: BulkOperationExecutor: An executor for unordered bulk operations. """ return BulkOperationExecutor(self.collection, ordered=False)