Source code for neosqlite.collection.query_helper.query_optimizer

"""
Query optimization utilities for NeoSQLite collections.

This module provides a mixin class with methods for query cost estimation,
index analysis, and pipeline optimization.
"""

from typing import TYPE_CHECKING, Any

from ...sql_utils import quote_table_name

if TYPE_CHECKING:
    from .. import Collection


[docs] class QueryOptimizerMixin: """ A mixin class providing query optimization methods. This mixin assumes it will be used with a class that has: - self.collection (with db and name attributes) - self._jsonb_supported - self._json_function_prefix """ collection: "Collection" _jsonb_supported: bool _json_function_prefix: str
[docs] def _get_indexed_fields(self) -> list[str]: """ Get a list of indexed fields for this collection. Returns: list[str]: A list of field names that have indexes. """ # Get indexes that match our naming convention cmd = ( "SELECT name FROM sqlite_master WHERE type='index' AND name LIKE ?" ) like_pattern = f"idx_{quote_table_name(self.collection.name)}_%" indexes = self.collection.db.execute(cmd, (like_pattern,)).fetchall() indexed_fields = [] for idx in indexes: # Extract key name from index name (idx_collection_key -> key) key_name = idx[0][ len(f"idx_{quote_table_name(self.collection.name)}_") : ] # Convert underscores back to dots for nested keys key_name = key_name.replace("_", ".") # Skip the automatically created _id index since it should be hidden # like MongoDB's automatic _id index if key_name == "id": # This corresponds to the _id column index continue indexed_fields.append(key_name) return indexed_fields
[docs] def _estimate_result_size(self, pipeline: list[dict[str, Any]]) -> int: """ Estimate the size of the aggregation result in bytes. This method analyzes the pipeline to estimate the size of the result set. Args: pipeline: The aggregation pipeline to analyze Returns: Estimated size in bytes """ # Get the base collection size base_count = self.collection.estimated_document_count() # Apply pipeline stages to estimate result size estimated_count = base_count estimated_avg_doc_size = 1024 # Default estimate of 1KB per document for stage in pipeline: stage_name = next(iter(stage.keys())) match stage_name: case "$match": # Matches typically reduce the result set # For now, we'll use a rough estimate estimated_count = max(1, int(estimated_count * 0.5)) case "$limit": limit_count = stage["$limit"] estimated_count = min(estimated_count, limit_count) case "$skip": skip_count = stage["$skip"] estimated_count = max(0, estimated_count - skip_count) case "$unwind": # Unwind operations can multiply the result set # This is a very rough estimate estimated_count = ( estimated_count * 3 ) # Assume 3 elements per array on average case "$group": # Group operations typically reduce the result set # This is a very rough estimate estimated_count = max(1, int(estimated_count * 0.1)) case _: # For other operations, we'll assume they don't significantly change the size pass # Apply some limits to prevent extreme estimates estimated_count = min( estimated_count, base_count * 10 ) # Cap at 10x the base count estimated_count = max(estimated_count, 0) # Ensure non-negative return estimated_count * estimated_avg_doc_size
[docs] def _estimate_query_cost(self, query: dict[str, Any]) -> float: """ Estimate the cost of executing a query based on index availability. Lower cost values indicate more efficient queries. Args: query (dict[str, Any]): A dictionary representing the query criteria. Returns: float: Estimated cost of the query (lower is better). """ # Get indexed fields indexed_fields = self._get_indexed_fields() # Base cost cost = 1.0 # Check if we can use indexes for any fields in the query for field, value in query.items(): if field in ("$and", "$or", "$nor", "$not"): # Handle logical operators recursively if isinstance(value, list): for subquery in value: if isinstance(subquery, dict): cost *= self._estimate_query_cost(subquery) elif isinstance(value, dict): cost *= self._estimate_query_cost(value) elif field == "_id": # _id field is always indexed (it's a column) cost *= 0.1 # Very low cost for _id queries elif field in indexed_fields: # Field is indexed, reduce cost cost *= 0.3 # Lower cost when using an index else: # Field is not indexed, increase cost cost *= 1.0 # No change for non-indexed fields return cost
[docs] def _estimate_pipeline_cost(self, pipeline: list[dict[str, Any]]) -> float: """ Estimate the total cost of executing an aggregation pipeline. Lower cost values indicate more efficient pipelines. This method considers data flow - earlier stages affect more documents. Args: pipeline (list[dict[str, Any]]): A list of aggregation pipeline stages. Returns: float: Estimated cost of the pipeline (lower is better). """ total_cost = 0.0 cumulative_multiplier = ( 1.0 # Represents how much data flows through each stage ) for i, stage in enumerate(pipeline): stage_name = next(iter(stage.keys())) stage_cost = 0.0 match stage_name: case "$match": # Estimate cost of match stage query = stage["$match"] stage_cost = self._estimate_query_cost(query) # Matches early in the pipeline are more beneficial because they reduce # the amount of data flowing to later stages stage_cost *= cumulative_multiplier # Update data flow multiplier based on selectivity # Assume matches reduce data by 50% on average cumulative_multiplier *= 0.5 case "$sort": # Sort operations have moderate cost, weighted by data volume stage_cost = 1.0 * cumulative_multiplier case "$skip": # Skip operations have low cost stage_cost = 0.1 * cumulative_multiplier case "$limit": # Limit operations have low cost but dramatically reduce data flow stage_cost = 0.1 * cumulative_multiplier # Limits significantly reduce data flow to subsequent stages cumulative_multiplier *= ( 0.1 # Assume limits reduce data by 90% ) case "$group": # Group operations have high cost (require processing all data) stage_cost = 5.0 * cumulative_multiplier # Groups typically reduce data significantly cumulative_multiplier *= ( 0.2 # Assume groups reduce data by 80% ) case "$unwind": # Unwind operations multiply the data size, increasing cost and data flow stage_cost = 2.0 * cumulative_multiplier # Unwinds increase data volume (assume 5x increase on average) cumulative_multiplier *= 5.0 case "$lookup": # Lookup operations have high cost (joins) stage_cost = 3.0 * cumulative_multiplier # Lookups may increase data slightly cumulative_multiplier *= 1.2 case _: # Unknown operations have moderate cost stage_cost = 1.5 * cumulative_multiplier # Assume unknown operations don't significantly change data volume # cumulative_multiplier stays the same total_cost += stage_cost return total_cost
[docs] def _optimize_match_pushdown( self, pipeline: list[dict[str, Any]] ) -> list[dict[str, Any]]: """ Optimize pipeline by pushing $match stages down to earlier positions when beneficial. This optimization moves $match stages earlier in the pipeline when they can filter data before expensive operations like $unwind or $group. Note: $match stages with $text search are NOT pushed down when they follow $unwind stages, as the text search semantics depend on the unwound data. Args: pipeline (list[dict[str, Any]]): The pipeline stages to optimize. Returns: list[dict[str, Any]]: The optimized pipeline. """ if len(pipeline) < 2: return pipeline # Look for patterns where we can push matches down optimized = pipeline.copy() # Find all $match stages match_stages = [] other_stages = [] for i, stage in enumerate(optimized): stage_name = next(iter(stage.keys())) if stage_name == "$match": # Don't push down $match with $text search or $jsonSchema - they need to operate # on the data as transformed by previous stages or required specific handling match_spec = stage["$match"] if "$text" in match_spec or "$jsonSchema" in match_spec: other_stages.append((i, stage)) else: match_stages.append((i, stage)) else: other_stages.append((i, stage)) # If we have matches and expensive operations, consider reordering expensive_ops = {"$unwind", "$group", "$lookup"} has_expensive_ops = any( next(iter(stage.keys())) in expensive_ops for _, stage in other_stages ) if match_stages and has_expensive_ops: # Move matches to the front to filter early match_stage_items = [stage for _, stage in match_stages] other_stage_items = [stage for _, stage in other_stages] return match_stage_items + other_stage_items return optimized
[docs] def _is_datetime_indexed_field(self, field: str) -> bool: """ Check if a field has a datetime index by looking for it in the database indexes. Datetime indexes are created with the pattern: idx_{collection}_{field}_utc Args: field: The field name to check for datetime indexing Returns: bool: True if the field has a datetime index, False otherwise """ # Construct the expected index name for datetime indexes # Convert dots to underscores in field name field_name_for_index = field.replace(".", "_") expected_datetime_index_name = f"idx_{quote_table_name(self.collection.name)}_{field_name_for_index}_utc" # Query the SQLite master table to check if this specific index exists cursor = self.collection.db.execute( "SELECT name FROM sqlite_master WHERE type='index' AND name = ?", (expected_datetime_index_name,), ) return cursor.fetchone() is not None
[docs] def _reorder_pipeline_for_indexes( self, pipeline: list[dict[str, Any]] ) -> list[dict[str, Any]]: """ Reorder pipeline stages to optimize performance based on index availability. Moves $match stages with indexed fields to the beginning of the pipeline to take advantage of index-based filtering. Args: pipeline (list[dict[str, Any]]): The original pipeline stages. Returns: list[dict[str, Any]]: The reordered pipeline stages. """ if not pipeline: return pipeline # Get indexed fields indexed_fields = set(self._get_indexed_fields()) # Separate match stages with indexed fields from others indexed_matches = [] other_stages = [] for stage in pipeline: stage_name = next(iter(stage.keys())) if stage_name == "$match": # Check if this match uses indexed fields match_query = stage["$match"] has_indexed_field = False # Simple check for direct field references for field in match_query.keys(): if field in indexed_fields or field == "_id": has_indexed_field = True break # For logical operators, check nested fields if not has_indexed_field: for field, value in match_query.items(): if field in ("$and", "$or") and isinstance(value, list): for condition in value: if isinstance(condition, dict): for subfield in condition.keys(): if ( subfield in indexed_fields or subfield == "_id" ): has_indexed_field = True break if has_indexed_field: break elif field == "_id": has_indexed_field = True if has_indexed_field: indexed_matches.append(stage) else: other_stages.append(stage) else: other_stages.append(stage) # Return reordered pipeline: indexed matches first, then other stages return indexed_matches + other_stages