"""Source code for neosqlite.gridfs.grid_file: GridFS file objects (GridIn, GridOut, GridOutCursor)."""

from __future__ import annotations

import datetime
import hashlib
import logging
from typing import Any

from .._sqlite import sqlite3
from ..collection.schema_utils import column_exists

# Import ObjectId for GridIn and GridOut
from ..objectid import ObjectId
from ..sql_utils import quote_table_name
from .errors import NoFile
from .utils import (
    deserialize_aliases,
    deserialize_metadata,
    force_sync_if_needed,
    serialize_aliases,
    serialize_metadata,
)

logger = logging.getLogger(__name__)


class GridIn:
    """
    A file-like object for writing data to GridFS.

    Bytes passed to :meth:`write` are buffered and stored as fixed-size rows
    in the ``<bucket_name>_chunks`` table. The row describing the file in
    ``<bucket_name>_files`` is inserted before the first chunk is stored (or
    by :meth:`close` for uploads that never fill a whole chunk), and
    :meth:`close` records the final length and MD5 digest.
    """

    def __init__(
        self,
        db: sqlite3.Connection,
        bucket_name: str,
        chunk_size_bytes: int,
        filename: str,
        metadata: dict[str, Any] | None = None,
        file_id: ObjectId | int | None = None,
        disable_md5: bool = False,
        write_concern: dict[str, Any] | None = None,
        content_type: str | None = None,
        aliases: list[str] | None = None,
    ):
        """
        Initialize a new GridIn instance.

        Args:
            db: SQLite database connection
            bucket_name: The bucket name for the GridFS files
            chunk_size_bytes: The chunk size in bytes (must be positive)
            filename: The name of the file
            metadata: Optional metadata for the file
            file_id: Optional custom file ID (ObjectId or integer); an
                ObjectId is generated when omitted
            disable_md5: Disable MD5 checksum calculation for performance
            write_concern: Write concern settings (simulated for compatibility)
            content_type: Optional MIME type of the file
            aliases: Optional list of alternative names for the file

        Raises:
            ValueError: If ``chunk_size_bytes`` is not positive. A zero or
                negative chunk size would make :meth:`write` insert empty
                chunks forever.
        """
        if chunk_size_bytes <= 0:
            raise ValueError("chunk_size_bytes must be a positive integer")

        self._db = db
        self._bucket_name = bucket_name
        self._chunk_size_bytes = chunk_size_bytes
        self._filename = filename
        self._metadata = metadata
        self._file_id = file_id
        self._files_collection = quote_table_name(f"{bucket_name}_files")
        self._chunks_collection = quote_table_name(f"{bucket_name}_chunks")
        self._disable_md5 = disable_md5
        self._write_concern = write_concern or {}
        self._content_type = content_type
        self._aliases = aliases

        # Stream state
        self._buffer = bytearray()
        self._chunk_number = 0
        self._position = 0
        self._closed = False
        self._md5_hasher = None if disable_md5 else hashlib.md5()
        # True once this upload's row exists in the files table.
        self._file_document_created = False
        # Integer primary key of the files row, resolved lazily and cached.
        self._int_file_id: int | None = None

    def _serialize_aliases(self) -> str | None:
        """
        Serialize aliases to JSON string.

        Returns:
            JSON string representation or None
        """
        return serialize_aliases(self._aliases)

    def _serialize_metadata(
        self, metadata: dict[str, Any] | None
    ) -> str | None:
        """
        Serialize metadata to JSON string.

        Args:
            metadata: Metadata dictionary to serialize

        Returns:
            JSON string representation or None
        """
        return serialize_metadata(metadata)

    def _deserialize_metadata(
        self, metadata_str: str | None
    ) -> dict[str, Any] | None:
        """
        Deserialize metadata from JSON string.

        Args:
            metadata_str: JSON string representation of metadata

        Returns:
            Metadata dictionary or None
        """
        return deserialize_metadata(metadata_str)

    def _deserialize_aliases(self, aliases_str: str | None) -> list[str] | None:
        """
        Deserialize aliases from JSON string.

        Args:
            aliases_str: JSON string representation of aliases

        Returns:
            List of alias strings or None
        """
        return deserialize_aliases(aliases_str)

    def _force_sync_if_needed(self):
        """Force database synchronization if write concern requires it."""
        force_sync_if_needed(self._db, self._write_concern)

    def write(self, data: bytes | bytearray) -> int:
        """
        Write data to the GridIn stream.

        Data is buffered; every time the buffer holds at least one full chunk,
        that chunk is written to the database.

        Args:
            data: The data to write

        Returns:
            The number of bytes written

        Raises:
            ValueError: If the stream is closed.
            TypeError: If ``data`` is not bytes or bytearray.
        """
        if self._closed:
            raise ValueError("I/O operation on closed file")
        if not isinstance(data, (bytes, bytearray)):
            raise TypeError("data must be bytes or bytearray")

        self._buffer.extend(data)
        self._position += len(data)
        if self._md5_hasher is not None:
            self._md5_hasher.update(data)

        # Store every complete chunk now; the remainder waits for more data
        # or for close().
        while len(self._buffer) >= self._chunk_size_bytes:
            self._flush_chunk()

        return len(data)

    def _flush_chunk(self) -> None:
        """
        Move one full chunk from the buffer into the chunks table.

        Does nothing when the buffer holds less than one chunk.
        """
        if len(self._buffer) < self._chunk_size_bytes:
            return
        chunk_data = bytes(self._buffer[: self._chunk_size_bytes])
        del self._buffer[: self._chunk_size_bytes]
        self._insert_chunk(chunk_data)

    def _insert_chunk(self, data: bytes) -> None:
        """
        Store ``data`` as the next numbered chunk of this file.

        The files row is created first if it does not exist yet, so chunks
        are never written for a file that has no files row. This applies
        whether the ID was generated or supplied by the caller.
        """
        self._ensure_file_document()
        self._db.execute(
            f"""
            INSERT INTO {self._chunks_collection} (files_id, n, data)
            VALUES (?, ?, ?)
            """,
            (self._get_file_id(), self._chunk_number, data),
        )
        self._chunk_number += 1

    def _ensure_file_document(self) -> None:
        """Create this upload's files row unless it has already been created."""
        if not self._file_document_created:
            self._create_file_document()

    def _create_file_document(self) -> None:
        """
        Insert the files-table row for this upload.

        Stores the filename, chunk size, upload date (UTC, ISO 8601),
        serialized metadata, content type and serialized aliases. If no file
        ID was supplied, a new ObjectId is generated and kept as the file ID.
        Length and MD5 are filled in later by :meth:`close`.
        """
        upload_date = datetime.datetime.now(datetime.timezone.utc).isoformat()
        file_id = self._file_id if self._file_id is not None else ObjectId()

        # An ObjectId is stored as hex text in ``_id`` and SQLite assigns the
        # integer ``id`` (None binds as NULL). Any other ID is used for ``id``
        # directly, with its string form mirrored into ``_id`` for consistency.
        int_id = None if isinstance(file_id, ObjectId) else file_id

        self._db.execute(
            f"""
            INSERT INTO {self._files_collection}
            (id, _id, filename, chunkSize, uploadDate, metadata, content_type, aliases)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                int_id,
                str(file_id),
                self._filename,
                self._chunk_size_bytes,
                upload_date,
                serialize_metadata(self._metadata),
                self._content_type,
                serialize_aliases(self._aliases),
            ),
        )
        self._file_id = file_id
        self._file_document_created = True

    def _get_file_id(self) -> int:
        """
        Return the integer primary key of this file's files row.

        Creates the files row first if needed. Integer file IDs are returned
        as-is. Other IDs (e.g. ObjectId) are resolved by looking up the
        stored ``_id`` text. The result is cached because it cannot change
        once the row exists.

        Returns:
            int: The integer ID of the file for database operations

        Raises:
            RuntimeError: If the file cannot be found in the database
        """
        if self._int_file_id is not None:
            return self._int_file_id

        self._ensure_file_document()

        if isinstance(self._file_id, int):
            self._int_file_id = self._file_id
            return self._int_file_id

        row = self._db.execute(
            f"SELECT id FROM {self._files_collection} WHERE _id = ?",
            (str(self._file_id),),
        ).fetchone()
        if row is None:
            label = "ObjectId" if isinstance(self._file_id, ObjectId) else "ID"
            raise RuntimeError(
                f"File with {label} {self._file_id} not found in database"
            )
        self._int_file_id = row[0]
        return self._int_file_id

    def close(self) -> None:
        """
        Close the GridIn stream and finalize the file storage.

        Ensures the files row exists (even for empty uploads) and writes any
        buffered bytes as the final, possibly short, chunk. It then records
        the total length and MD5 digest (NULL when MD5 is disabled) and
        applies the write-concern sync. Calling close() again is a no-op.
        """
        if self._closed:
            return

        self._ensure_file_document()

        if self._buffer:
            self._insert_chunk(bytes(self._buffer))
            # Clear so a retried close() cannot store the tail chunk twice.
            self._buffer.clear()

        md5_hash = (
            self._md5_hasher.hexdigest() if self._md5_hasher is not None else None
        )
        self._db.execute(
            f"""
            UPDATE {self._files_collection}
            SET length = ?, md5 = ?
            WHERE id = ?
            """,
            (self._position, md5_hash, self._get_file_id()),
        )

        self._force_sync_if_needed()
        self._closed = True

    def __enter__(self) -> GridIn:
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit; always closes (finalizes) the upload."""
        self.close()
class GridOut:
    """
    A file-like object for reading data from GridFS.

    The file's row in ``<bucket_name>_files`` is read once at construction.
    Chunk data from ``<bucket_name>_chunks`` is fetched lazily, one chunk at
    a time, as :meth:`read` advances through the file.
    """

    def __init__(
        self,
        db: sqlite3.Connection,
        bucket_name: str,
        file_id: ObjectId | int,
    ):
        """
        Initialize a new GridOut instance.

        Args:
            db: SQLite database connection
            bucket_name: The bucket name for the GridFS files
            file_id: The ID of the file to read (ObjectId or integer)

        Raises:
            RuntimeError: If ``file_id`` is an ObjectId with no matching row.
            NoFile: If no files row has the resolved integer ID.
        """
        self._db = db
        self._bucket_name = bucket_name
        # quote_table_name protects against SQL injection via bucket_name.
        self._files_collection = quote_table_name(f"{bucket_name}_files")
        self._chunks_collection = quote_table_name(f"{bucket_name}_chunks")

        # All internal queries use the integer primary key.
        if isinstance(file_id, ObjectId):
            row = self._db.execute(
                f"SELECT id FROM {self._files_collection} WHERE _id = ?",
                (str(file_id),),
            ).fetchone()
            if row is None:
                raise RuntimeError(
                    f"File with ObjectId {file_id} not found in database"
                )
            self._int_file_id = row[0]
        else:
            self._int_file_id = file_id

        # Older schemas lack these columns; only select what exists.
        has_content_type, has_aliases = self._detect_optional_columns()

        select_columns = [
            "filename",
            "length",
            "chunkSize",
            "uploadDate",
            "md5",
            "json(metadata)",
            "_id",
        ]
        if has_content_type:
            select_columns.append("content_type")
        if has_aliases:
            select_columns.append("json(aliases)")

        query = f"""
            SELECT {", ".join(select_columns)}
            FROM {self._files_collection}
            WHERE id = ?
        """
        row = self._db.execute(query, (self._int_file_id,)).fetchone()
        if row is None:
            raise NoFile(f"File with id {file_id} not found")

        # The first seven columns are always present, in this order.
        self._filename = row[0]
        self._length = row[1]
        self._chunk_size = row[2]
        self._upload_date = row[3]
        self._md5 = row[4]
        metadata_str = row[5]
        self._stored_oid = row[6]  # raw value of the _id column

        # Optional columns follow, in the order they were appended above.
        next_idx = 7
        self._content_type = None
        if has_content_type:
            self._content_type = row[next_idx]
            next_idx += 1
        aliases_str = row[next_idx] if has_aliases else None

        self._actual_id: ObjectId | int | str = self._coerce_stored_id(
            self._stored_oid, file_id
        )

        self._metadata = deserialize_metadata(metadata_str)
        self._aliases = deserialize_aliases(aliases_str)

        self._position = 0
        self._current_chunk_data = b""
        self._current_chunk_index = -1
        self._closed = False

    def _detect_optional_columns(self) -> tuple[bool, bool]:
        """
        Report whether the files table has ``content_type`` / ``aliases``.

        AttributeError/TypeError (e.g. from mocked connections in tests) are
        logged at debug level. Any column not confirmed before the error is
        treated as absent.
        """
        has_content_type = False
        has_aliases = False
        try:
            has_content_type = column_exists(
                self._db, self._files_collection, "content_type"
            )
            has_aliases = column_exists(
                self._db, self._files_collection, "aliases"
            )
        except (AttributeError, TypeError) as e:
            logger.debug(f"{e=}")
        return has_content_type, has_aliases

    @staticmethod
    def _coerce_stored_id(
        stored: Any, fallback: ObjectId | int
    ) -> ObjectId | int | str:
        """
        Convert the raw ``_id`` column value to the ID type users expect.

        Conversion rules:

        - NULL yields ``fallback`` (the ID passed to the constructor).
        - Integers and non-string values are returned unchanged.
        - 24-character strings are tried as ObjectId, then as int.
        - Other strings are tried as int.
        - A string that converts to neither is returned unchanged.
        """
        if stored is None:
            return fallback
        if not isinstance(stored, str):
            return stored

        if len(stored) == 24:
            try:
                return ObjectId(stored)
            except ValueError as e:
                logger.debug(
                    f"Failed to convert GridFS ID '{stored}' to ObjectId: {e}"
                )
            try:
                return int(stored)
            except ValueError as e:
                logger.debug(
                    f"Failed to convert GridFS ID '{stored}' to integer: {e}"
                )
                return stored

        try:
            return int(stored)
        except ValueError as e:
            logger.debug(
                f"Failed to convert non-24 char GridFS ID '{stored}' to integer: {e}"
            )
            return stored

    @property
    def _id(self):
        """Get the file's actual ID, which may be an ObjectId or integer."""
        return self._actual_id

    @property
    def _file_id(self):
        """Get the file's actual ID that represents what the user expects as the ID.

        For compatibility with the original API, this returns the actual ID
        (ObjectId or int).
        """
        return self._actual_id

    def _deserialize_metadata(
        self, metadata_str: str | None
    ) -> dict[str, Any] | None:
        """
        Deserialize metadata from JSON string.

        Args:
            metadata_str: JSON string representation of metadata

        Returns:
            Metadata dictionary or None
        """
        return deserialize_metadata(metadata_str)

    def _deserialize_aliases(self, aliases_str: str | None) -> list[str] | None:
        """
        Deserialize aliases from JSON string.

        Args:
            aliases_str: JSON string representation of aliases

        Returns:
            List of alias strings or None
        """
        return deserialize_aliases(aliases_str)

    def read(self, size: int | None = -1) -> bytes:
        """
        Read data from the GridOut stream.

        Args:
            size: The maximum number of bytes to read. ``None`` or any
                negative value reads all remaining data, following the usual
                Python file-object convention.

        Returns:
            The data read from the stream (``b""`` at end of file)

        Raises:
            ValueError: If the stream is closed.
            NoFile: If a needed chunk is missing or shorter than the stored
                file length implies. Previously a short chunk could make this
                loop forever.
        """
        if self._closed:
            raise ValueError("I/O operation on closed file")

        remaining = self._length - self._position
        if size is None or size < 0:
            size = remaining
        else:
            size = min(size, remaining)
        if size <= 0:
            return b""

        result = bytearray()
        while len(result) < size:
            # _load_chunk derives the chunk index from _position, so it
            # switches to the next chunk automatically at a boundary.
            self._load_chunk()
            chunk_offset = self._position % self._chunk_size
            available = len(self._current_chunk_data) - chunk_offset
            if available <= 0:
                raise NoFile(
                    f"Chunk {self._current_chunk_index} for file id "
                    f"{self._int_file_id} is truncated"
                )
            take = min(size - len(result), available)
            result += self._current_chunk_data[chunk_offset : chunk_offset + take]
            self._position += take

        return bytes(result)

    def _load_chunk(self) -> None:
        """Load the chunk containing the current position (no-op if cached)."""
        chunk_index = self._position // self._chunk_size

        if chunk_index == self._current_chunk_index:
            return

        row = self._db.execute(
            f"""
            SELECT data FROM {self._chunks_collection}
            WHERE files_id = ? AND n = ?
            """,
            (self._int_file_id, chunk_index),
        ).fetchone()

        if row is None:
            raise NoFile(
                f"Chunk {chunk_index} for file id {self._int_file_id} not found"
            )

        self._current_chunk_data = row[0]
        self._current_chunk_index = chunk_index

    @property
    def filename(self) -> str:
        """Get the filename."""
        return self._filename

    @property
    def length(self) -> int:
        """Get the length of the file in bytes."""
        return self._length

    @property
    def chunk_size(self) -> int:
        """Get the chunk size in bytes."""
        return self._chunk_size

    @property
    def upload_date(self) -> str:
        """Get the upload date."""
        return self._upload_date

    @property
    def md5(self) -> str | None:
        """Get the MD5 hash of the file (None when MD5 was disabled)."""
        return self._md5

    @property
    def metadata(self) -> dict[str, Any] | None:
        """Get the metadata of the file."""
        return self._metadata

    @property
    def content_type(self) -> str | None:
        """Get the content type (MIME type) of the file."""
        return self._content_type

    @property
    def aliases(self) -> list[str] | None:
        """Get the aliases (alternative names) for the file."""
        return self._aliases

    def close(self) -> None:
        """Close the GridOut stream; further reads raise ValueError."""
        self._closed = True

    def __enter__(self) -> GridOut:
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit."""
        self.close()
[docs] class GridOutCursor: """ A cursor for iterating over GridFS files. This class provides an iterator interface for retrieving file documents from GridFS. """
[docs] def __init__( self, db: sqlite3.Connection, bucket_name: str, filter: dict[str, Any], ): """ Initialize a new GridOutCursor instance. Args: db: SQLite database connection bucket_name: The bucket name for the GridFS files filter: The filter to apply when searching for files """ self._db = db self._bucket_name = bucket_name self._filter = filter self._files_collection = quote_table_name(f"{bucket_name}_files") # Build query based on filter where_clause = "" params = [] if filter: where_conditions = [] for key, value in filter.items(): match key: case "_id": # Handle ObjectId hex strings and other ID formats if isinstance(value, ObjectId): where_conditions.append("_id = ?") params.append(str(value)) elif isinstance(value, str) and len(value) == 24: # Check if it's a valid ObjectId hex string try: ObjectId(value) where_conditions.append("_id = ?") params.append(value) except ValueError as e: # Not a valid ObjectId, treat as regular string logger.debug( f"ID '{value}' is not a valid ObjectId: {e}" ) where_conditions.append("_id = ?") params.append(value) else: # Handle other types where_conditions.append("_id = ?") params.append(value) case "id": # For 'id' queries, we look in the integer id column where_conditions.append("id = ?") params.append(value) case "filename": if isinstance(value, dict): # Handle operators like {"$regex": "pattern"}, {"$ne": "name"}, etc. for op, val in value.items(): match op: case "$regex": where_conditions.append( "filename LIKE ?" ) params.append(f"%{val}%") case "$ne": where_conditions.append("filename != ?") params.append(val) case "$eq": where_conditions.append("filename = ?") params.append(val) case _: # For unsupported operators, fall back to exact match where_conditions.append("filename = ?") params.append(str(value)) else: # Direct value comparison where_conditions.append("filename = ?") params.append(value) case "length": if isinstance(value, dict): # Handle operators like {"$gt": 1000}, {"$lt": 5000}, etc. 
for op, val in value.items(): match op: case "$gt": where_conditions.append("length > ?") params.append(val) case "$gte": where_conditions.append("length >= ?") params.append(val) case "$lt": where_conditions.append("length < ?") params.append(val) case "$lte": where_conditions.append("length <= ?") params.append(val) case "$eq": where_conditions.append("length = ?") params.append(val) case "$ne": where_conditions.append("length != ?") params.append(val) else: # Direct value comparison where_conditions.append("length = ?") params.append(value) case "chunkSize": if isinstance(value, dict): # Handle operators like {"$gt": 1000}, {"$lt": 5000}, etc. for op, val in value.items(): match op: case "$gt": where_conditions.append("chunkSize > ?") params.append(val) case "$gte": where_conditions.append( "chunkSize >= ?" ) params.append(val) case "$lt": where_conditions.append("chunkSize < ?") params.append(val) case "$lte": where_conditions.append( "chunkSize <= ?" ) params.append(val) case "$eq": where_conditions.append("chunkSize = ?") params.append(val) case "$ne": where_conditions.append( "chunkSize != ?" ) params.append(val) else: # Direct value comparison where_conditions.append("chunkSize = ?") params.append(value) case "uploadDate": if isinstance(value, dict): # Handle operators like {"$gt": date}, {"$lt": date}, etc. for op, val in value.items(): match op: case "$gt": where_conditions.append( "uploadDate > ?" ) params.append(val) case "$gte": where_conditions.append( "uploadDate >= ?" ) params.append(val) case "$lt": where_conditions.append( "uploadDate < ?" ) params.append(val) case "$lte": where_conditions.append( "uploadDate <= ?" ) params.append(val) case "$eq": where_conditions.append( "uploadDate = ?" ) params.append(val) case "$ne": where_conditions.append( "uploadDate != ?" 
) params.append(val) else: # Direct value comparison where_conditions.append("uploadDate = ?") params.append(value) case "md5": if isinstance(value, dict) and "$ne" in value: where_conditions.append("md5 != ?") params.append(value["$ne"]) else: # Direct value comparison where_conditions.append("md5 = ?") params.append(value) # For metadata, we do a simple string match (basic implementation) # In a full implementation, we'd parse the JSON, but for now we'll do substring matching case "metadata": if isinstance(value, dict): # Handle metadata queries with operators for op, val in value.items(): match op: case "$regex": where_conditions.append( "metadata LIKE ?" ) params.append(f"%{val}%") case "$ne": where_conditions.append("metadata != ?") params.append( str(val) if not isinstance(val, str) else val ) case _: # For other operators, convert to string and match where_conditions.append( "metadata LIKE ?" ) params.append(f"%{op}%{val}%") else: # Direct metadata string matching where_conditions.append("metadata LIKE ?") params.append(f"%{value}%") case "aliases": # Check if the value is in the aliases JSON array where_conditions.append( "EXISTS (SELECT 1 FROM json_each(aliases) WHERE value = ?)" ) params.append(value) case "content_type": if isinstance(value, dict) and "$ne" in value: where_conditions.append("content_type != ?") params.append(value["$ne"]) else: # Direct value comparison where_conditions.append("content_type = ?") params.append(value) if where_conditions: where_clause = "WHERE " + " AND ".join(where_conditions) # Execute query to get integer file IDs (changed from _id to id for internal use) query = f"SELECT id FROM {self._files_collection} {where_clause}" cursor = self._db.execute(query, params) self._file_ids = [row[0] for row in cursor.fetchall()] self._index = 0
def __iter__(self) -> GridOutCursor: """Return the iterator object.""" return self def __next__(self) -> GridOut: """Get the next GridOut object.""" if self._index >= len(self._file_ids): raise StopIteration file_id = self._file_ids[self._index] self._index += 1 return GridOut(self._db, self._bucket_name, file_id)