"""
Aggregation pipeline methods for NeoSQLite.
This module contains the AggregationMixin class, which provides methods for
building and executing MongoDB-like aggregation pipelines using SQL.
"""
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from ...sql_utils import quote_table_name
from ..cursor import DESCENDING
from ..expr_evaluator import (
AggregationContext,
ExprEvaluator,
_is_expression,
)
from ..json_path_utils import (
parse_json_path,
)
logger = logging.getLogger(__name__)
# Import utility functions
from .utils import (
get_force_fallback,
)
if TYPE_CHECKING:
from .. import Collection
[docs]
class AggregationMixin:
"""
Mixin class providing aggregation pipeline methods.
This mixin assumes it will be used with a class that has the following:
Attributes:
self.collection: A collection instance with:
- db: Database connection
- name: Collection name
- _load: Method to load documents
- _get_val: Method to get values from documents
- _set_val: Method to set values in documents
self._jsonb_supported: Whether JSONB is supported
self._json_function_prefix: "json" or "jsonb"
self._json_each_function: "json_each" or "jsonb_each"
self._build_simple_where_clause: Method to build WHERE clauses
self._reorder_pipeline_for_indexes: Method to reorder pipelines
self._estimate_pipeline_cost: Method to estimate costs
self._optimize_match_pushdown: Method to optimize match pushdown
self._is_datetime_indexed_field: Method to check datetime indexes
self._build_group_query: Method to build group queries
self._apply_query: Method to apply queries to documents
"""
collection: "Collection"
_jsonb_supported: bool
_json_function_prefix: str
_json_each_function: str
_build_simple_where_clause: Any
_reorder_pipeline_for_indexes: Any
_estimate_pipeline_cost: Any
_optimize_match_pushdown: Any
_is_datetime_indexed_field: Any
_apply_query: Any
[docs]
def _build_aggregation_query(
self,
pipeline: list[dict[str, Any]],
) -> tuple[str, list[Any], list[str] | None] | None:
"""
Builds a SQL query for the given MongoDB-like aggregation pipeline.
This method constructs a SQL query based on the stages provided in the
aggregation pipeline. It currently handles $match, $sort, $skip,
and $limit stages, while $group stages are handled in Python. The method
returns a tuple containing the SQL command and a list of parameters.
Args:
pipeline (list[dict[str, Any]]): A list of aggregation pipeline stages.
Returns:
tuple[str, list[Any]] | None: A tuple containing the SQL command and
a list of parameters, or None if the
pipeline contains unsupported stages
or complex queries.
"""
# Check if we should force fallback for benchmarking/debugging
if get_force_fallback():
return None # Force fallback to Python implementation
# Try to optimize the pipeline by reordering for better index usage
optimized_pipeline = self._reorder_pipeline_for_indexes(pipeline)
# Estimate costs for both original and optimized pipelines
original_cost = self._estimate_pipeline_cost(pipeline)
optimized_cost = self._estimate_pipeline_cost(optimized_pipeline)
# Use the better pipeline based on cost estimation
if optimized_cost < original_cost:
# Use optimized pipeline
effective_pipeline = optimized_pipeline
else:
# Use original pipeline
effective_pipeline = pipeline
# Additional optimization: Check if we can push match filters down into SQL operations
effective_pipeline = self._optimize_match_pushdown(effective_pipeline)
where_clause = ""
params: list[Any] = []
order_by = ""
limit = ""
offset = ""
group_by = ""
select_clause = "SELECT id, data"
output_fields: list[str] | None = None
for i, stage in enumerate(effective_pipeline):
stage_name = next(iter(stage.keys()))
match stage_name:
case "$match":
query = stage["$match"]
where_result = self._build_simple_where_clause(query)
if where_result is None:
return None # Fallback for complex queries
where_clause, params, tables = where_result
case "$sort":
sort_spec = stage["$sort"]
sort_clauses = []
for key, direction in sort_spec.items():
# When sorting after a group stage, we sort by the output field name
if group_by:
sort_clauses.append(
f"{key} {'DESC' if direction == DESCENDING else 'ASC'}"
)
else:
sort_clauses.append(
f"{self._json_function_prefix}_extract(data, '{parse_json_path(key)}') "
f"{'DESC' if direction == DESCENDING else 'ASC'}"
)
order_by = "ORDER BY " + ", ".join(sort_clauses)
case "$skip":
count = stage["$skip"]
offset = f"OFFSET {count}"
case "$limit":
count = stage["$limit"]
limit = f"LIMIT {count}"
case "$group":
# Check if this is a $unwind + $group pattern we can optimize
optimization_result = self._optimize_unwind_group_pattern(
i, pipeline
)
if optimization_result is not None:
return optimization_result
# A group stage must be the first stage or after a match stage
if i > 1 or (i == 1 and "$match" not in pipeline[0]):
return None
group_spec = stage["$group"]
group_result = self._build_group_query(group_spec)
if group_result is None:
return None
select_clause, group_by, output_fields = group_result
case "$unwind":
# Check if this is followed by $match with $text - fall back to temp tables
# The text search on unwound elements requires special handling that
# the single-SQL optimization cannot provide
# Also check for multiple consecutive $unwind stages
if len(pipeline) > i + 1:
# Count consecutive $unwind stages starting from position i
unwind_count = 0
match_with_text_idx = -1
j = i
while j < len(pipeline) and "$unwind" in pipeline[j]:
unwind_count += 1
j += 1
# Check if next stage after unwinds is $match with $text
if (
j < len(pipeline)
and "$match" in pipeline[j]
and "$text" in pipeline[j]["$match"]
):
match_with_text_idx = j
# Fall back for: single unwind + text, multiple unwinds, or multiple unwinds + text
if match_with_text_idx >= 0 or unwind_count > 1:
return None # Fall back to temp table approach
# Check if this is part of an $unwind + $group pattern we can optimize
# Case 1: $unwind is first stage followed by $group
if i == 0 and len(pipeline) > 1 and "$group" in pipeline[1]:
# $unwind followed by $group - try to optimize with SQL
group_stage = pipeline[1]["$group"]
unwind_field = stage["$unwind"]
if (
isinstance(unwind_field, str)
and unwind_field.startswith("$")
and isinstance(group_stage.get("_id"), str)
and group_stage.get("_id").startswith("$")
):
unwind_field_name = unwind_field[
1:
] # Remove leading $
group_id_field = group_stage["_id"][
1:
] # Remove leading $
# Build SELECT and GROUP BY clauses for unwound data
select_expressions = []
output_fields = ["_id"]
# Handle _id field
if group_id_field == unwind_field_name:
# Grouping by the unwound field
select_expressions.append("je.value AS _id")
group_by_clause = "GROUP BY je.value"
else:
# Grouping by another field
select_expressions.append(
f"{self._json_function_prefix}_extract({quote_table_name(self.collection.name)}.data, '{parse_json_path(group_id_field)}') AS _id"
)
group_by_clause = f"GROUP BY {self._json_function_prefix}_extract({quote_table_name(self.collection.name)}.data, '{parse_json_path(group_id_field)}')"
# Try to build the group query using the general method
# This supports all accumulator operations including $avg, $min, $max
group_result = self._build_group_query(group_stage)
if group_result is not None:
(
select_clause,
group_by_clause,
group_output_fields,
) = group_result
# Modify the SELECT clause to work with the unwound data
# Replace json_extract(data, '$.field') with appropriate expressions
table_name = quote_table_name(
self.collection.name
)
# For the _id field, if it matches the unwind field, use je.value
group_id_field = group_stage["_id"][
1:
] # Remove leading $
if group_id_field == unwind_field_name:
# Replace the _id extraction with je.value
modified_select = select_clause.replace(
f"{self._json_function_prefix}_extract(data, '{parse_json_path(unwind_field_name)}') AS _id",
"je.value AS _id",
)
# For GROUP BY clause, use je.value when grouping by the unwind field
group_by_clause = "GROUP BY je.value"
else:
modified_select = select_clause
# Keep the original GROUP BY but ensure it references the correct table
group_by_clause = group_by_clause.replace(
f"{self._json_function_prefix}_extract(data,",
f"{self._json_function_prefix}_extract({table_name}.data,",
)
# Replace all other json_extract(data, ...) with json_extract(table.data, ...)
# to properly reference the table column in the JOIN context
modified_select = modified_select.replace(
f"{self._json_function_prefix}_extract(data,",
f"{self._json_function_prefix}_extract({table_name}.data,",
)
# For fields that reference the unwind field (e.g., $push: "$tags" when unwinding "$tags"),
# replace with je.value to get the unwound value instead of the full array
modified_select = modified_select.replace(
f"{self._json_function_prefix}_extract({table_name}.data, '{parse_json_path(unwind_field_name)}')",
"je.value",
)
# Build the FROM clause with json_each for unwinding
from_clause = f"FROM {table_name}, {self._json_each_function}({self._json_function_prefix}_extract({table_name}.data, '{parse_json_path(unwind_field_name)}')) as je"
# Add ordering by _id for consistent results
order_by_clause = "ORDER BY _id"
# Construct the full SQL command
cmd = f"{modified_select} {from_clause} {group_by_clause} {order_by_clause}"
return cmd, [], group_output_fields
# Fall back to Python for complex unwind scenarios
return None
case _:
return None # Fallback for unsupported stages
cmd = f"{select_clause} FROM {quote_table_name(self.collection.name)} {where_clause} {group_by} {order_by} {limit} {offset}"
return cmd, params, output_fields
[docs]
def _optimize_unwind_group_pattern(
self, group_stage_index: int, pipeline: list[dict[str, Any]]
) -> tuple[str, list[Any], list[str]] | None:
"""
Optimize $unwind + $group pattern with SQL-based processing.
This method handles the specific optimization pattern where a $unwind stage
is immediately followed by a $group stage. It supports all accumulator
operations by leveraging the general _build_group_query method while
handling the $unwind optimization.
Args:
group_stage_index: Index of the $group stage in the pipeline
pipeline: The complete aggregation pipeline
Returns:
tuple[str, list[Any], list[str]] | None: SQL command, params, and output fields
if optimization is possible, None otherwise
"""
# Check if this is a $unwind + $group pattern we can optimize
if group_stage_index == 1 and "$unwind" in pipeline[0]:
# $unwind followed by $group - try to optimize with SQL
unwind_stage = pipeline[0]["$unwind"]
group_spec = pipeline[group_stage_index]["$group"]
if (
isinstance(unwind_stage, str)
and unwind_stage.startswith("$")
and isinstance(group_spec.get("_id"), str)
and group_spec.get("_id").startswith("$")
):
unwind_field = unwind_stage[1:] # Remove leading $
table_name = quote_table_name(self.collection.name)
# Try to build the group query using the general method
group_result = self._build_group_query(group_spec)
if group_result is not None:
select_clause, group_by_clause, output_fields = group_result
# Modify the SELECT clause to work with the unwound data
# Replace json_extract(data, '$.field') with appropriate expressions
# For the _id field, if it matches the unwind field, use je.value
group_id_field = group_spec["_id"][1:] # Remove leading $
if group_id_field == unwind_field:
# Replace the _id extraction with je.value
modified_select = select_clause.replace(
f"{self._json_function_prefix}_extract(data, '{parse_json_path(unwind_field)}') AS _id",
"je.value AS _id",
)
# Also replace any other references to the unwind field in the SELECT clause
modified_select = modified_select.replace(
f"{self._json_function_prefix}_extract(data, '{parse_json_path(unwind_field)}')",
"je.value",
)
# For GROUP BY clause, use je.value when grouping by the unwind field
modified_group_by = "GROUP BY je.value"
else:
modified_select = select_clause
# Keep the original GROUP BY but ensure it references the correct table
modified_group_by = group_by_clause.replace(
f"{self._json_function_prefix}_extract(data,",
f"{self._json_function_prefix}_extract({table_name}.data,",
)
# Replace all other json_extract(data, ...) with json_extract(table.data, ...)
# to properly reference the table column in the JOIN context
# This is needed for fields that aren't the unwind field (e.g., $push: $name)
modified_select = modified_select.replace(
f"{self._json_function_prefix}_extract(data,",
f"{self._json_function_prefix}_extract({table_name}.data,",
)
# Build the FROM clause with json_each for unwinding
from_clause = f"FROM {table_name}, {self._json_each_function}({self._json_function_prefix}_extract({table_name}.data, '{parse_json_path(unwind_field)}')) as je"
# Add ordering by _id for consistent results
order_by_clause = "ORDER BY _id"
# Construct the full SQL command
cmd = f"{modified_select} {from_clause} {modified_group_by} {order_by_clause}"
return cmd, [], output_fields
else:
# If we can't build the group query, fall back to Python
return None
return None
[docs]
def _build_unwind_query(
self,
pipeline_index: int,
pipeline: list[dict[str, Any]],
unwind_stages: list[str],
) -> tuple[str, list[Any], list[str] | None] | None:
"""
Builds a SQL query for a sequence of $unwind stages.
This method constructs a SQL query to handle one or more consecutive $unwind
stages in an aggregation pipeline. It processes array fields by joining
with SQLite's `json_each`/`jsonb_each` function to "unwind" the arrays into separate rows.
The method also handles necessary array type checks and integrates with
other pipeline stages like $match, $sort, $skip, and $limit.
Args:
pipeline_index (int): The index of the first $unwind stage in the pipeline.
pipeline (list[dict[str, Any]]): The full aggregation pipeline.
unwind_stages (list[str]): A list of field paths to unwind,
each prefixed with '$'.
Returns:
tuple[str, list[Any], list[str] | None] | None: A tuple containing:
- The constructed SQL command string.
- A list of parameters for the SQL query.
- A list of output field names (None if not applicable).
Returns None if the unwind stages cannot be processed with SQL and a
fallback to Python is required.
"""
field_names = []
for field in unwind_stages:
if (
not isinstance(field, str)
or not field.startswith("$")
or len(field) == 1
):
return None # Fallback to Python implementation
field_names.append(field[1:])
# Build SELECT clause with nested json_set calls
select_parts = [f"{quote_table_name(self.collection.name)}.data"]
for i, field_name in enumerate(field_names):
select_parts.insert(0, "json_set(")
select_parts.append(
f", '{parse_json_path(field_name)}', je{i + 1}.value)"
)
select_expr = "".join(select_parts)
select_clause = f"SELECT {quote_table_name(self.collection.name)}.id, {select_expr} as data"
# Build FROM clause with multiple json_each calls
from_clause, unwound_fields = self._build_unwind_from_clause(
field_names
)
# Handle $match stage and array type checks
all_where_clauses = []
params: list[Any] = []
if pipeline_index == 1 and "$match" in pipeline[0]:
match_query = pipeline[0]["$match"]
where_result = self._build_simple_where_clause(match_query)
if where_result and where_result[0]:
all_where_clauses.append(
where_result[0].replace("WHERE ", "", 1)
)
params.extend(where_result[1])
for field_name in field_names:
parent_field, parent_alias = self._find_parent_unwind(
field_name, unwound_fields
)
if parent_field and parent_alias:
nested_path = field_name[len(parent_field) + 1 :]
all_where_clauses.append(
f"json_type({self._json_function_prefix}_extract({parent_alias}.value, '{parse_json_path(nested_path)}')) = 'array'"
)
else:
all_where_clauses.append(
f"json_type({self._json_function_prefix}_extract({quote_table_name(self.collection.name)}.data, '{parse_json_path(field_name)}')) = 'array'"
)
where_clause = ""
if all_where_clauses:
where_clause = "WHERE " + " AND ".join(all_where_clauses)
# Handle sort, skip, and limit operations
start_index = pipeline_index + len(unwind_stages)
end_index = len(pipeline)
order_by, limit, offset = self._build_sort_skip_limit_clauses(
pipeline, start_index, end_index, unwound_fields
)
cmd = f"{select_clause} {from_clause} {where_clause} {order_by} {limit} {offset}"
return cmd, params, None
[docs]
def _build_unwind_from_clause(
self, field_names: list[str]
) -> tuple[str, dict[str, str]]:
"""
Builds the FROM clause for a SQL query with one or more $unwind stages.
This method constructs the FROM clause needed to handle multiple $unwind
operations in an aggregation pipeline. It creates joins with SQLite's
`json_each`/`jsonb_each` function for each field to be unwound, allowing array elements
to be processed as separate rows. It also manages nested unwinds by
identifying parent-child relationships between fields.
Args:
field_names (list[str]): A list of field paths to unwind. Each path
should be a string without the leading '$'.
Returns:
tuple[str, dict[str, str]]: A tuple containing:
- The constructed FROM clause as a string.
- A dictionary mapping each unwound field path to its corresponding
alias (e.g., 'je1', 'je2').
"""
from_clause, unwound_fields = self._build_unwind_from_clause_impl(
field_names
)
return from_clause, unwound_fields
[docs]
def _build_unwind_from_clause_impl(
self, field_names: list[str]
) -> tuple[str, dict[str, str]]:
"""
Internal implementation for building the FROM clause.
Args:
field_names (list[str]): A list of field paths to unwind.
Returns:
tuple[str, dict[str, str]]: A tuple containing the FROM clause and
unwound fields mapping.
"""
from_parts = [f"FROM {quote_table_name(self.collection.name)}"]
unwound_fields: dict[str, str] = {}
for i, field_name in enumerate(field_names):
je_alias = f"je{i + 1}"
parent_field, parent_alias = self._find_parent_unwind(
field_name, unwound_fields
)
if parent_field and parent_alias:
nested_path = field_name[len(parent_field) + 1 :]
from_parts.append(
f", {self._json_each_function}({self._json_function_prefix}_extract({parent_alias}.value, '{parse_json_path(nested_path)}')) as {je_alias}"
)
else:
from_parts.append(
f", {self._json_each_function}({self._json_function_prefix}_extract({quote_table_name(self.collection.name)}.data, '{parse_json_path(field_name)}')) as {je_alias}"
)
unwound_fields[field_name] = je_alias
return " ".join(from_parts), unwound_fields
[docs]
def _find_parent_unwind(
self, field_name: str, unwound_fields: dict[str, str]
) -> tuple[str | None, str | None]:
"""
Find the parent unwind field for a nested unwind.
This method searches through already processed unwind fields to find a
parent field that the current field is nested within. This is used to
properly construct SQL joins for nested array unwinding operations.
Args:
field_name (str): The field name to find the parent for.
unwound_fields (dict[str, str]): A dictionary mapping field paths to
their aliases.
Returns:
tuple[str | None, str | None]: A tuple containing the parent field
name and its alias, or (None, None)
if no parent is found.
"""
parent_field = None
parent_alias = None
longest_match_len = -1
for p_field, p_alias in unwound_fields.items():
prefix = p_field + "."
if field_name.startswith(prefix):
if len(p_field) > longest_match_len:
longest_match_len = len(p_field)
parent_field = p_field
parent_alias = p_alias
return parent_field, parent_alias
[docs]
def _build_sort_skip_limit_clauses(
self,
pipeline: list[dict[str, Any]],
start_index: int,
end_index: int,
unwound_fields: dict[str, str],
) -> tuple[str, str, str]:
"""
Build ORDER BY, LIMIT, and OFFSET clauses for aggregation queries.
This method constructs the SQL clauses for sorting, skipping, and limiting
results in an aggregation pipeline. It handles both regular fields and
fields that have been unwound from arrays, ensuring proper SQL generation
for nested array elements.
Args:
pipeline (list[dict[str, Any]]): The aggregation pipeline stages.
start_index (int): The starting index in the pipeline to process stages from.
end_index (int): The ending index in the pipeline to process stages to.
unwound_fields (dict[str, str]): A mapping of field names to their aliases
for unwound fields.
Returns:
tuple[str, str, str]: A tuple containing:
- The ORDER BY clause (empty string if no sorting)
- The LIMIT clause (empty string if no limit)
- The OFFSET clause (empty string if no offset)
"""
local_order_by = ""
local_limit = ""
local_offset = ""
sort_stages = []
skip_value = 0
limit_value = None
for stage_idx in range(start_index, end_index):
stage = pipeline[stage_idx]
if "$sort" in stage:
sort_stages.append(stage["$sort"])
elif "$skip" in stage:
skip_value = stage["$skip"]
elif "$limit" in stage:
limit_value = stage["$limit"]
if sort_stages:
sort_clauses = []
for sort_spec in sort_stages:
for key, direction in sort_spec.items():
parent_field, parent_alias = self._find_parent_unwind(
key, unwound_fields
)
if parent_field and parent_alias:
nested_path = key[len(parent_field) + 1 :]
sort_clauses.append(
f"{self._json_function_prefix}_extract({parent_alias}.value, '{parse_json_path(nested_path)}') "
f"{'DESC' if direction == DESCENDING else 'ASC'}"
)
elif key in unwound_fields:
unwound_alias = unwound_fields[key]
sort_clauses.append(
f"{unwound_alias}.value {'DESC' if direction == DESCENDING else 'ASC'}"
)
else:
sort_clauses.append(
f"{self._json_function_prefix}_extract({quote_table_name(self.collection.name)}.data, '{parse_json_path(key)}') "
f"{'DESC' if direction == DESCENDING else 'ASC'}"
)
if sort_clauses:
local_order_by = "ORDER BY " + ", ".join(sort_clauses)
if limit_value is not None:
local_limit = f"LIMIT {limit_value}"
if skip_value > 0:
local_offset = f"OFFSET {skip_value}"
elif skip_value > 0:
# SQLite requires LIMIT when using OFFSET
local_limit = "LIMIT -1"
local_offset = f"OFFSET {skip_value}"
return local_order_by, local_limit, local_offset
[docs]
def _build_group_query(
self, group_spec: dict[str, Any]
) -> tuple[str, str, list[str]] | None:
"""
Builds the SELECT and GROUP BY clauses for a $group stage.
This method constructs SQL SELECT and GROUP BY clauses for MongoDB-like
$group aggregation stages that can be handled directly with SQL. It supports
grouping by a single field and various accumulator operations like $sum,
$avg, $min, $max, $count, $push, and $addToSet.
Args:
group_spec (dict[str, Any]): A dictionary representing the $group stage
specification. It should contain an "_id"
field for grouping and accumulator operations
for other fields.
Returns:
tuple[str, str, list[str]] | None: A tuple containing:
- The SELECT clause string with all required expressions
- The GROUP BY clause string
- A list of output field names
Returns None if the group specification contains unsupported operations
that require Python-based processing.
"""
group_id_expr = group_spec.get("_id")
if group_id_expr is None:
group_by_clause = ""
select_expressions = ["NULL AS _id"]
output_fields = ["_id"]
elif isinstance(group_id_expr, str) and group_id_expr.startswith("$"):
group_by_field = group_id_expr[1:]
group_by_clause = f"GROUP BY {self._json_function_prefix}_extract(data, '{parse_json_path(group_by_field)}')"
select_expressions = [
f"{self._json_function_prefix}_extract(data, '{parse_json_path(group_by_field)}') AS _id"
]
output_fields = ["_id"]
else:
return None # Fallback for complex _id expressions
for field, accumulator in group_spec.items():
if field == "_id":
continue
if not isinstance(accumulator, dict) or len(accumulator) != 1:
return None
op, expr = next(iter(accumulator.items()))
if op == "$count":
select_expressions.append(f"COUNT(*) AS {field}")
output_fields.append(field)
continue
if op == "$push":
# Handle $push accumulator
if not isinstance(expr, str) or not expr.startswith("$"):
return None # Fallback for complex accumulator expressions
field_name = expr[1:]
select_expressions.append(
f"json_group_array({self._json_function_prefix}_extract(data, '{parse_json_path(field_name)}')) AS \"{field}\""
)
output_fields.append(field)
continue
if op == "$addToSet":
# Handle $addToSet accumulator
if not isinstance(expr, str) or not expr.startswith("$"):
return None # Fallback for complex accumulator expressions
field_name = expr[1:]
select_expressions.append(
f"json_group_array(DISTINCT {self._json_function_prefix}_extract(data, '{parse_json_path(field_name)}')) AS \"{field}\""
)
output_fields.append(field)
continue
# Handle special case for $sum with integer literal 1 (count operation)
if op == "$sum" and isinstance(expr, int) and expr == 1:
select_expressions.append(f"COUNT(*) AS {field}")
output_fields.append(field)
continue
# Handle field-based operations
if not isinstance(expr, str) or not expr.startswith("$"):
return None # Fallback for complex accumulator expressions
field_name = expr[1:]
sql_func = {
"$sum": "SUM",
"$avg": "AVG",
"$min": "MIN",
"$max": "MAX",
}.get(op)
if not sql_func:
return None # Unsupported accumulator
select_expressions.append(
f"{sql_func}({self._json_function_prefix}_extract(data, '{parse_json_path(field_name)}')) AS {field}"
)
output_fields.append(field)
select_clause = "SELECT " + ", ".join(select_expressions)
return select_clause, group_by_clause, output_fields
[docs]
def _process_group_stage(
self,
group_query: dict[str, Any],
docs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""
Process the $group stage of an aggregation pipeline.
This method groups documents by a specified field and performs specified
accumulator operations on other fields.
Args:
group_query (dict[str, Any]): A dictionary representing the $group
stage of the aggregation pipeline.
docs (list[dict[str, Any]]): A list of documents to be grouped.
Returns:
list[dict[str, Any]]: A list of grouped documents with applied
accumulator operations.
"""
grouped_docs: dict[Any, dict[str, Any]] = {}
group_id_key = group_query.get("_id")
# Create a copy of group_query without _id for processing accumulator operations
accumulators = {k: v for k, v in group_query.items() if k != "_id"}
# Create expression evaluator for evaluating expressions in accumulators
evaluator = ExprEvaluator(
data_column="data", db_connection=self.collection.db
)
for doc in docs:
if group_id_key is None:
group_id = None
elif _is_expression(group_id_key):
# Evaluate expression for group key
group_id = evaluator._evaluate_expr_python(group_id_key, doc)
else:
group_id = self.collection._get_val(doc, group_id_key)
group = grouped_docs.setdefault(group_id, {"_id": group_id})
for field, accumulator in accumulators.items():
# Check if accumulator is a valid dictionary format
if not isinstance(accumulator, dict) or len(accumulator) != 1:
# Invalid accumulator format, skip this field
continue
op, key = next(iter(accumulator.items()))
# Check for unsupported operators
if op == "$accumulator":
raise NotImplementedError(
"The '$accumulator' operator is not supported in NeoSQLite. "
"Please use built-in accumulators ($sum, $avg, $min, $max, $count, $push, $addToSet, $first, $last), "
"or post-process results in Python."
)
if op == "$count":
group[field] = group.get(field, 0) + 1
continue
# Handle expressions in accumulators
if _is_expression(key):
# Evaluate expression for each document
value = evaluator._evaluate_expr_python(key, doc)
# Handle literal values (e.g., $sum: 1 for counting)
elif isinstance(key, (int, float)):
value = key
elif isinstance(key, dict):
# Check if this is one of our new N-value operators
if op in {"$firstN", "$lastN", "$minN", "$maxN"}:
# These operators use dict format with "input" field
# Extract the input field and get its value
input_field = key.get("input", key.get("values", ""))
if input_field:
value = self.collection._get_val(doc, input_field)
else:
value = None
else:
# Complex expression like {"$multiply": [...]}, not supported in Python fallback
continue
else:
value = self.collection._get_val(doc, key)
match op:
case "$sum":
group[field] = (group.get(field, 0) or 0) + (value or 0)
case "$avg":
avg_info = group.get(field, {"sum": 0, "count": 0})
avg_info["sum"] += value or 0
avg_info["count"] += 1
group[field] = avg_info
case "$min":
current = group.get(field, value)
if current is not None and value is not None:
group[field] = min(current, value)
elif value is not None:
group[field] = value
elif current is not None:
group[field] = current
else:
group[field] = None
case "$max":
current = group.get(field, value)
if current is not None and value is not None:
group[field] = max(current, value)
elif value is not None:
group[field] = value
elif current is not None:
group[field] = current
else:
group[field] = None
case "$push":
group.setdefault(field, []).append(value)
case "$addToSet":
# Initialize the list if it doesn't exist
if field not in group:
group[field] = []
# Only add the value if it's not already in the list
if value not in group[field]:
group[field].append(value)
case "$first":
# Only set the value if it hasn't been set yet (first document in group)
if field not in group:
group[field] = value
case "$last":
# Always update with the latest value (last document in group)
group[field] = value
case "$mergeObjects":
# Merge objects from all documents in the group
# Last value wins for conflicting fields
if field not in group:
group[field] = {}
if isinstance(value, dict):
group[field] |= value
case "$stdDevPop":
# Track sum, sum of squares, and count for population standard deviation
if field not in group:
group[field] = {
"sum": 0,
"sum_squares": 0,
"count": 0,
"type": "stdDevPop",
}
if value is not None:
group[field]["sum"] += value
group[field]["sum_squares"] += value * value
group[field]["count"] += 1
case "$stdDevSamp":
# Track sum, sum of squares, and count for sample standard deviation
if field not in group:
group[field] = {
"sum": 0,
"sum_squares": 0,
"count": 0,
"type": "stdDevSamp",
}
if value is not None:
group[field]["sum"] += value
group[field]["sum_squares"] += value * value
group[field]["count"] += 1
case "$firstN" | "$lastN" | "$minN" | "$maxN":
# Handle N-value operators
if not isinstance(key, dict) or "n" not in key:
continue
n_value = key["n"]
if field not in group:
group[field] = {
"type": op,
"n": n_value,
"values": [],
}
# Add value to the list
if value is not None:
group[field]["values"].append(value)
# Keep only the top N values based on operator type
if len(group[field]["values"]) > n_value:
if op == "$firstN":
# Keep first N values (already in order)
group[field]["values"] = group[field][
"values"
][:n_value]
elif op == "$lastN":
# Keep last N values
group[field]["values"] = group[field][
"values"
][-n_value:]
elif op == "$minN":
# Keep N smallest values
group[field]["values"] = sorted(
group[field]["values"]
)[:n_value]
elif op == "$maxN":
# Keep N largest values
group[field]["values"] = sorted(
group[field]["values"], reverse=True
)[:n_value]
# Finalize $avg calculations
for group in grouped_docs.values():
for field, value in group.items():
if field == "_id":
continue
# Skip if this is a std dev calculation (has "type" key)
if isinstance(value, dict) and value.get("type") in {
"stdDevPop",
"stdDevSamp",
}:
continue
# Finalize $avg calculations
if (
isinstance(value, dict)
and "sum" in value
and "count" in value
):
if value["count"] > 0:
group[field] = value["sum"] / value["count"]
else:
group[field] = None
# Finalize standard deviation calculations
import math
for group in grouped_docs.values():
for field, value in group.items():
if field == "_id":
continue
if isinstance(value, dict) and value.get("type") in {
"stdDevPop",
"stdDevSamp",
}:
n = value["count"]
if n > 0:
mean = value["sum"] / n
variance = (value["sum_squares"] / n) - (mean * mean)
if value["type"] == "stdDevSamp" and n > 1:
# Sample standard deviation uses Bessel's correction
variance = (
value["sum_squares"] - (value["sum"] ** 2) / n
) / (n - 1)
if variance < 0:
# Handle floating point errors
variance = 0
group[field] = math.sqrt(variance)
else:
group[field] = None
# Finalize N-value operators
for group in grouped_docs.values():
for field, value in group.items():
if field == "_id":
continue
if isinstance(value, dict) and value.get("type") in {
"$firstN",
"$lastN",
"$minN",
"$maxN",
}:
if value["type"] == "$minN":
# Sort in ascending order and take first N values
sorted_values = sorted(value["values"])
group[field] = sorted_values[: value["n"]]
elif value["type"] == "$maxN":
# Sort in descending order and take first N values
sorted_values = sorted(value["values"], reverse=True)
group[field] = sorted_values[: value["n"]]
else:
# For firstN and lastN, values are already in correct order
group[field] = value["values"]
return list(grouped_docs.values())
[docs]
def _run_subpipeline(
self,
sub_pipeline: list[dict[str, Any]],
docs: list[dict[str, Any]],
batch_size: int = 101,
) -> str:
"""
Run a sub-pipeline (e.g., for $facet) on a list of documents.
Uses tier optimization (Tier-1/Tier-2/Tier-3) for each sub-pipeline.
Results are streamed to a temporary table in batches to avoid memory issues.
Args:
sub_pipeline: List of pipeline stages to execute
docs: Input documents
batch_size: Number of documents to process in each batch
Returns:
Name of the temporary table containing results
"""
# Create a temporary in-memory collection to run the sub-pipeline
# This allows each sub-pipeline to use Tier-1/Tier-2 optimization
import uuid
from .. import Collection
# Create temp collection for processing this batch
temp_collection_name = f"_facet_batch_{uuid.uuid4().hex[:12]}"
temp_collection = Collection(
db=self.collection.db,
name=temp_collection_name,
create=True,
database=self.collection._database,
)
# Create result temp table to store sub-pipeline results
result_table = f"_facet_result_{uuid.uuid4().hex[:12]}"
self.collection.db.execute(f"""
CREATE TEMP TABLE {result_table} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data TEXT
)
""")
try:
# Process input docs in batches
for i in range(0, len(docs), batch_size):
batch = docs[i : i + batch_size]
# Strip __doc__ wrapper if present
docs_to_insert = []
for doc in batch:
if isinstance(doc, dict) and "__doc__" in doc:
docs_to_insert.append(doc["__doc__"])
else:
docs_to_insert.append(doc)
if not docs_to_insert:
continue
# Insert batch into temp collection
temp_collection.insert_many(docs_to_insert)
# Run sub-pipeline through normal aggregation (uses Tier-1/Tier-2/Tier-3)
result = list(
temp_collection.aggregate(
sub_pipeline, batchSize=batch_size
)
)
# Insert results into result temp table
for doc in result:
from neosqlite.collection.json_helpers import (
neosqlite_json_dumps,
)
self.collection.db.execute(
f"INSERT INTO {result_table} (data) VALUES (?)",
(neosqlite_json_dumps(doc),),
)
# Clear temp collection for next batch
temp_collection.delete_many({})
return result_table
finally:
# Clean up temporary collection
try:
temp_collection.drop()
except Exception as e:
logger.debug(
f"Failed to drop temporary collection '{temp_collection.name}': {e}"
)
pass # Ignore cleanup errors
[docs]
def _apply_projection(
self,
projection: dict[str, Any],
document: dict[str, Any],
) -> dict[str, Any]:
"""
Applies the projection to the document, selecting or excluding fields
based on the projection criteria.
Args:
projection (dict[str, Any]): A dictionary specifying which fields to
include or exclude.
document (dict[str, Any]): The document to apply the projection to.
Returns:
dict[str, Any]: The document with fields applied based on the projection.
"""
from ..expr_evaluator import (
REMOVE_SENTINEL,
)
if not projection:
return document
doc = deepcopy(document)
projected_doc: dict[str, Any] = {}
include_id = projection.get("_id", 1) == 1
# Check if this is an inclusion projection with expressions or aggregation variables
has_expressions = any(
_is_expression(value)
or (isinstance(value, str) and value.startswith("$"))
for value in projection.values()
)
if has_expressions:
# Inclusion mode with expressions - evaluate each field
evaluator = ExprEvaluator(
data_column="data", db_connection=self.collection.db
)
ctx = AggregationContext()
ctx.bind_document(document)
for key, value in projection.items():
if key == "_id":
if include_id and "_id" in doc:
projected_doc["_id"] = doc["_id"]
continue
if _is_expression(value):
# Evaluate expression
projected_value = evaluator._evaluate_expr_python(
value, document
)
# Check for $$REMOVE sentinel
if projected_value is REMOVE_SENTINEL:
# Skip this field (remove it)
continue
projected_doc[key] = projected_value
elif isinstance(value, str) and value.startswith("$"):
# Field reference or aggregation variable
if value.startswith("$$"):
# Aggregation variable
if value == "$$ROOT":
projected_doc[key] = document.copy()
elif value == "$$CURRENT":
projected_doc[key] = document.copy()
elif value == "$$REMOVE":
# Skip this field (remove it)
continue
else:
projected_doc[key] = None
else:
# Regular field reference
field_name = value[1:]
projected_doc[key] = self.collection._get_val(
document, field_name
)
elif value == 1:
# Simple inclusion
if key in doc:
projected_doc[key] = doc[key]
# value == 0 is exclusion, skip it
if include_id and "_id" in doc:
projected_doc["_id"] = doc["_id"]
return projected_doc
# Inclusion mode (no expressions)
if any(v == 1 for v in projection.values()):
for key, value in projection.items():
if value == 1 and key in doc:
projected_doc[key] = doc[key]
if include_id and "_id" in doc:
projected_doc["_id"] = doc["_id"]
return projected_doc
# Exclusion mode
for key, value in projection.items():
if value == 0 and key in doc:
doc.pop(key, None)
if not include_id and "_id" in doc:
doc.pop("_id", None)
return doc