Source code for neosqlite.collection.raw_batch_cursor

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Iterator

logger = logging.getLogger(__name__)

from .json_helpers import neosqlite_json_dumps
from .json_path_utils import parse_json_path
from .jsonb_support import _get_json_function_prefix, json_data_column

if TYPE_CHECKING:
    from ..client_session import ClientSession
    from . import Collection


class RawBatchCursor:
    """A cursor that returns raw batches of JSON data instead of individual documents.

    Iterating the cursor yields ``bytes`` objects. Each one holds up to
    ``batch_size`` documents serialized with ``neosqlite_json_dumps``, joined
    with newline characters and encoded as UTF-8.
    """

    def __init__(
        self,
        collection: Collection,
        filter: dict[str, Any] | None = None,
        projection: dict[str, Any] | None = None,
        hint: str | None = None,
        batch_size: int = 100,
        pipeline: list[dict[str, Any]] | None = None,
        session: ClientSession | None = None,
    ):
        """
        Initialize a RawBatchCursor object.

        Args:
            collection (Collection): The collection associated with this cursor.
            filter (dict[str, Any]): A dictionary representing the filter
                criteria for the documents. ``None`` is treated as ``{}``.
            projection (dict[str, Any]): A dictionary representing the
                projection criteria for the documents. ``None`` is treated
                as ``{}``.
            hint (str): A string hinting at the index to use for the query.
            batch_size (int): The number of documents to return in each batch.
            pipeline (list[dict[str, Any]]): An optional aggregation pipeline
                to execute. When given, iteration runs the pipeline instead
                of the filter query.
            session (ClientSession, optional): A ClientSession for transactions.
        """
        self._collection = collection
        self._query_helpers = collection.query_engine.helpers
        self._filter = filter or {}
        self._projection = projection or {}
        self._hint = hint
        self._batch_size = batch_size
        # Skip/limit/sort have no setters in this class; they keep these
        # defaults unless assigned from outside.
        self._skip = 0
        self._limit: int | None = None
        self._sort: dict[str, int] | None = None
        self._pipeline = pipeline
        self._session = session
        # Tables reported by the WHERE-clause builder; dropped in _cleanup().
        self._tables_to_cleanup: list[str] = []

    def batch_size(self, batch_size: int) -> RawBatchCursor:
        """
        Set the batch size for this cursor.

        Args:
            batch_size (int): The number of documents to return in each batch.

        Returns:
            RawBatchCursor: This cursor object, for method chaining.
        """
        self._batch_size = batch_size
        return self

    def __iter__(self) -> Iterator[bytes]:
        """
        Return an iterator over raw batches of JSON data.

        Three strategies are tried in order:

        1. If a pipeline was supplied, run it through
           ``collection.aggregate`` and batch the results.
        2. If the query helpers can turn the filter into a WHERE clause,
           page through the table with ``LIMIT``/``OFFSET`` queries.
        3. Otherwise fall back to ``collection.find`` and batch its results.

        Returns:
            Iterator[bytes]: An iterator that yields raw batches of JSON data.
        """
        if self._pipeline is not None:
            # Aggregation results are fully materialized before batching.
            results = list(self._collection.aggregate(self._pipeline))
            yield from self._iter_encoded_batches(results)
            return

        where_result = self._query_helpers._build_simple_where_clause(
            self._filter
        )
        if where_result is not None:
            yield from self._iter_sql_batches(where_result)
            return

        # Fallback for filters the SQL builder cannot express: delegate to
        # the collection's find() and carry over this cursor's modifiers.
        cursor = self._collection.find(
            self._filter, self._projection, self._hint
        )
        if self._sort:
            cursor._sort = self._sort
        cursor._skip = self._skip
        cursor._limit = self._limit
        yield from self._iter_encoded_batches(list(cursor))

    @staticmethod
    def _encode_batch(docs: list[Any]) -> bytes:
        """Serialize ``docs`` one JSON document per line and encode as UTF-8."""
        return "\n".join(neosqlite_json_dumps(doc) for doc in docs).encode(
            "utf-8"
        )

    def _iter_encoded_batches(self, docs: list[Any]) -> Iterator[bytes]:
        """Yield ``docs`` in encoded chunks of at most ``self._batch_size``."""
        for start in range(0, len(docs), self._batch_size):
            yield self._encode_batch(docs[start : start + self._batch_size])

    def _iter_sql_batches(self, where_result: Any) -> Iterator[bytes]:
        """
        Yield encoded batches by paging through the table with SQL.

        Args:
            where_result: The ``(where_clause, params, tables)`` triple
                returned by ``_build_simple_where_clause``.
        """
        # NOTE(review): this path never reads self._projection, self._hint or
        # self._session, unlike the find() fallback -- confirm that is intended.
        where_clause, params, tables = where_result

        # Remember any tables the WHERE builder reported so _cleanup() can
        # drop them later.
        if tables:
            self._tables_to_cleanup.extend(tables)

        # The collection's JSONB support flag decides how data is selected.
        jsonb = self._collection.query_engine._jsonb_supported
        json_func = _get_json_function_prefix(jsonb)

        order_by = ""
        if self._sort:
            order_by = "ORDER BY " + ", ".join(
                f"{json_func}_extract(data, '{parse_json_path(key)}') "
                f"{'DESC' if direction == -1 else 'ASC'}"
                for key, direction in self._sort.items()
            )

        # Only splice in the WHERE clause when it is non-blank.
        if where_clause and where_clause.strip():
            cmd = (
                f"SELECT id, {json_data_column(jsonb)} as data "
                f"FROM {self._collection.name} {where_clause} {order_by}"
            )
        else:
            cmd = (
                f"SELECT id, {json_data_column(jsonb)} as data "
                f"FROM {self._collection.name} {order_by}"
            )

        offset = self._skip
        total_returned = 0
        while True:
            # Never ask for more rows than the overall limit still allows.
            batch_limit = self._batch_size
            if self._limit is not None:
                remaining_limit = self._limit - total_returned
                if remaining_limit <= 0:
                    break
                batch_limit = min(batch_limit, remaining_limit)

            batch_cmd = f"{cmd} LIMIT {batch_limit} OFFSET {offset}"
            rows = self._collection.db.execute(batch_cmd, params).fetchall()
            if not rows:
                break

            docs = [self._collection._load(row[0], row[1]) for row in rows]
            yield self._encode_batch(docs)

            returned_count = len(rows)
            total_returned += returned_count
            offset += returned_count

            # A short page means the table is exhausted. Reaching the limit
            # is caught by the check at the top of the next iteration,
            # before another query is issued.
            if returned_count < batch_limit:
                break

    def __del__(self) -> None:
        """Clean up resources on garbage collection."""
        self._cleanup()

    def _cleanup(self) -> None:
        """
        Drop the tables recorded in ``_tables_to_cleanup`` (best effort).

        Failures to drop a table are logged at debug level and otherwise
        ignored. The list is reset afterwards.
        """
        # __del__ also runs on instances whose __init__ failed part-way, so
        # the attribute may not exist yet.
        tables = getattr(self, "_tables_to_cleanup", None)
        if not tables:
            return

        for table in tables:
            try:
                # NOTE(review): iteration runs queries via self._collection.db
                # while cleanup uses self._collection._database -- confirm
                # both reach the same connection.
                self._collection._database.execute(
                    f"DROP TABLE IF EXISTS {table}"
                )
            except Exception as e:
                # Deliberately swallowed: cleanup may run during garbage
                # collection, where raising is not useful.
                logger.debug(
                    "Failed to drop temporary table %s in RawBatchCursor cleanup: %s",
                    table,
                    e,
                )
        self._tables_to_cleanup = []