Source code for neosqlite.collection.query_engine.query_methods
"""Query methods for the QueryEngine."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ..client_session import ClientSession
import json
from neosqlite.collection.json_helpers import (
neosqlite_json_dumps,
neosqlite_json_loads,
)
from ...sql_utils import quote_table_name
from ..json_path_utils import parse_json_path
from ..type_utils import validate_session
from .base import QueryEngineProtocol
[docs]
class QueryMethodsMixin(QueryEngineProtocol):
"""Mixin class providing query methods for QueryEngine."""
[docs]
def count_documents(
self, filter: dict[str, Any], session: ClientSession | None = None
) -> int:
"""
Return the count of documents that match the given filter.
Args:
filter (dict[str, Any]): A dictionary specifying the query filter.
session (ClientSession, optional): A ClientSession for transactions.
Returns:
int: The number of documents matching the filter.
"""
validate_session(session, self.collection._database)
# Apply ID type normalization to handle cases where users query 'id' with ObjectId
filter = self.helpers._normalize_id_query(filter)
# Try to use SQLTranslator for the WHERE clause
where_clause, params = self.sql_translator.translate_match(filter)
if where_clause is not None:
cmd = f"SELECT COUNT(id) FROM {quote_table_name(self.collection.name)} {where_clause}"
row = self.collection.db.execute(cmd, params).fetchone()
return row[0] if row else 0
return len(list(self.find(filter, session=session)))
[docs]
def estimated_document_count(
self, session: ClientSession | None = None
) -> int:
"""
Return the estimated number of documents in the collection.
Args:
session (ClientSession, optional): A ClientSession for transactions.
Returns:
int: The estimated number of documents.
"""
validate_session(session, self.collection._database)
row = self.collection.db.execute(
f"SELECT COUNT(1) FROM {quote_table_name(self.collection.name)}"
).fetchone()
return row[0] if row else 0
[docs]
def distinct(
self,
key: str,
filter: dict[str, Any] | None = None,
session: ClientSession | None = None,
) -> list[Any]:
"""
Return a list of distinct values from the specified key in the documents
of this collection, optionally filtered by a query.
Args:
key (str): The field name to extract distinct values from.
filter (dict[str, Any] | None): An optional query filter to apply to the documents.
session (ClientSession, optional): A ClientSession for transactions.
Returns:
list[Any]: A list containing the distinct values from the specified key.
"""
validate_session(session, self.collection._database)
# Apply ID type normalization to handle cases where users query 'id' with ObjectId
if filter is not None:
filter = self.helpers._normalize_id_query(filter)
params: list[Any] = []
where_clause: str = ""
if filter:
# Try to use SQLTranslator for the WHERE clause
result = self.sql_translator.translate_match(filter)
if result[0] is not None:
where_clause = result[0]
params = result[1]
# For distinct operations, always use json_* functions to avoid binary data issues
# Even if JSONB is supported, we use json_* for distinct to ensure proper text output
func_prefix = "json"
cmd = (
f"SELECT DISTINCT {func_prefix}_extract(data, '{parse_json_path(key)}') "
f"FROM {quote_table_name(self.collection.name)} {where_clause}"
)
cursor = self.collection.db.execute(cmd, params)
results: set[Any] = set()
for row in cursor.fetchall():
if row[0] is None:
continue
try:
val = neosqlite_json_loads(row[0])
match val:
case list():
results.add(tuple(val))
case dict():
results.add(neosqlite_json_dumps(val, sort_keys=True))
case _:
results.add(val)
except (json.JSONDecodeError, TypeError):
results.add(row[0])
return list(results)