from __future__ import annotations
import datetime
import hashlib
import logging
from typing import Any
from .._sqlite import sqlite3
from ..collection.schema_utils import column_exists
# Import ObjectId for GridIn and GridOut
from ..objectid import ObjectId
from ..sql_utils import quote_table_name
from .errors import NoFile
from .utils import (
deserialize_aliases,
deserialize_metadata,
force_sync_if_needed,
serialize_aliases,
serialize_metadata,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# --- Writable side of the GridFS file API ---
class GridIn:
    """
    A file-like object for writing data to GridFS.

    This class provides a writable interface for storing files in GridFS.
    Written bytes are buffered in memory; every full ``chunk_size_bytes``
    slice is inserted into the bucket's chunks table. The matching row in
    the bucket's files table is inserted exactly once (before the first
    chunk is stored, or on close for short/empty files) and is completed
    with ``length`` and ``md5`` by :meth:`close`.
    """

    # Class-level defaults so the bookkeeping below is always defined.
    # True once the files-table row for this upload has been inserted.
    _file_created: bool = False
    # Cached integer primary key of the files-table row (None until known).
    _int_file_id: int | None = None

    def __init__(
        self,
        db: sqlite3.Connection,
        bucket_name: str,
        chunk_size_bytes: int,
        filename: str,
        metadata: dict[str, Any] | None = None,
        file_id: ObjectId | int | None = None,
        disable_md5: bool = False,
        write_concern: dict[str, Any] | None = None,
        content_type: str | None = None,
        aliases: list[str] | None = None,
    ):
        """
        Initialize a new GridIn instance.

        Args:
            db: SQLite database connection
            bucket_name: The bucket name for the GridFS files
            chunk_size_bytes: The chunk size in bytes (must be positive)
            filename: The name of the file
            metadata: Optional metadata for the file
            file_id: Optional custom file ID
            disable_md5: Disable MD5 checksum calculation for performance
            write_concern: Write concern settings (simulated for compatibility)
            content_type: Optional MIME type of the file
            aliases: Optional list of alternative names for the file

        Raises:
            ValueError: If ``chunk_size_bytes`` is not positive
        """
        # A non-positive chunk size would make write() loop forever, because
        # its "buffer holds at least one chunk" test could never become false.
        if chunk_size_bytes <= 0:
            raise ValueError("chunk_size_bytes must be a positive integer")

        self._db = db
        self._bucket_name = bucket_name
        self._chunk_size_bytes = chunk_size_bytes
        self._filename = filename
        self._metadata = metadata
        self._file_id = file_id
        self._files_collection = quote_table_name(f"{bucket_name}_files")
        self._chunks_collection = quote_table_name(f"{bucket_name}_chunks")
        self._disable_md5 = disable_md5
        self._write_concern = write_concern or {}
        self._content_type = content_type
        self._aliases = aliases

        # Stream state
        self._buffer = bytearray()
        self._chunk_number = 0
        self._position = 0
        self._closed = False
        self._md5_hasher = None if disable_md5 else hashlib.md5()
        self._file_created = False
        self._int_file_id = None

    def _serialize_aliases(self) -> str | None:
        """
        Serialize aliases to JSON string.

        Returns:
            JSON string representation or None
        """
        return serialize_aliases(self._aliases)

    def _deserialize_aliases(self, aliases_str: str | None) -> list[str] | None:
        """
        Deserialize aliases from JSON string.

        Args:
            aliases_str: JSON string representation of aliases

        Returns:
            List of alias strings or None
        """
        return deserialize_aliases(aliases_str)

    def _force_sync_if_needed(self):
        """Force database synchronization if write concern requires it."""
        force_sync_if_needed(self._db, self._write_concern)

    def write(self, data: bytes | bytearray) -> int:
        """
        Write data to the GridIn stream.

        The data is appended to the internal buffer and every complete chunk
        is flushed to the database immediately; a trailing partial chunk stays
        buffered until more data arrives or the stream is closed.

        Args:
            data: The data to write

        Returns:
            The number of bytes written

        Raises:
            ValueError: If the stream is already closed
            TypeError: If ``data`` is not ``bytes`` or ``bytearray``
        """
        if self._closed:
            raise ValueError("I/O operation on closed file")
        if not isinstance(data, (bytes, bytearray)):
            raise TypeError("data must be bytes or bytearray")

        self._buffer.extend(data)
        self._position += len(data)
        if self._md5_hasher is not None:
            self._md5_hasher.update(data)

        # Flush chunks while the buffer holds at least one full chunk
        while len(self._buffer) >= self._chunk_size_bytes:
            self._flush_chunk()
        return len(data)

    def _flush_chunk(self) -> None:
        """
        Flush one full chunk from the buffer to the database.

        Does nothing when the buffer holds less than ``chunk_size_bytes``.
        The file document is created first if it does not exist yet, both
        for generated IDs and for caller-supplied ones, so that the chunk
        can always be associated with an existing files-table row.
        """
        size = self._chunk_size_bytes
        if len(self._buffer) < size:
            return
        self._insert_chunk(self._get_file_id(), bytes(self._buffer[:size]))
        # Drop the bytes only after they were handed to the database.
        del self._buffer[:size]

    def _insert_chunk(self, file_int_id: int, chunk_data: bytes) -> None:
        """Insert ``chunk_data`` as the next numbered chunk of the file."""
        self._db.execute(
            f"""
            INSERT INTO {self._chunks_collection}
            (files_id, n, data)
            VALUES (?, ?, ?)
            """,
            (file_int_id, self._chunk_number, chunk_data),
        )
        self._chunk_number += 1

    def _ensure_file_document(self) -> None:
        """Create the file document unless this stream already created it."""
        if not self._file_created:
            self._create_file_document()

    def _create_file_document(self) -> None:
        """
        Create the file document in the files collection.

        Stores the filename, chunk size, upload date (current UTC time as an
        ISO-8601 string), serialized metadata, content type and serialized
        aliases. If no file ID was provided, a new ObjectId is generated.
        For ObjectId keys the integer ``id`` column is left NULL so SQLite
        assigns it; any other caller-supplied ID is written to ``id`` as-is.
        In both cases ``_id`` receives the string form of the ID.
        """
        upload_date = datetime.datetime.now(datetime.timezone.utc).isoformat()
        file_id = self._file_id if self._file_id is not None else ObjectId()
        int_id = None if isinstance(file_id, ObjectId) else file_id

        self._db.execute(
            f"""
            INSERT INTO {self._files_collection}
            (id, _id, filename, chunkSize, uploadDate, metadata, content_type, aliases)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                int_id,
                str(file_id),
                self._filename,
                self._chunk_size_bytes,
                upload_date,
                serialize_metadata(self._metadata),
                self._content_type,
                serialize_aliases(self._aliases),
            ),
        )
        # Only remember the ID and the created state once the INSERT succeeded.
        self._file_id = file_id
        self._file_created = True

    def _get_file_id(self) -> int:
        """
        Get the file ID, creating the file document if necessary.

        Integer file IDs are returned as-is; for any other ID the integer
        ``id`` is looked up by the stored ``_id`` string. The result is
        cached so that later chunks do not repeat the lookup.

        Returns:
            int: The integer ID of the file for database operations

        Raises:
            RuntimeError: If the file cannot be found in the database
        """
        self._ensure_file_document()
        if self._int_file_id is not None:
            return self._int_file_id

        if isinstance(self._file_id, int):
            int_id = self._file_id
        else:
            row = self._db.execute(
                f"SELECT id FROM {self._files_collection} WHERE _id = ?",
                (str(self._file_id),),
            ).fetchone()
            if row is None:
                if isinstance(self._file_id, ObjectId):
                    raise RuntimeError(
                        f"File with ObjectId {self._file_id} not found in database"
                    )
                raise RuntimeError(
                    f"File with ID {self._file_id} not found in database"
                )
            int_id = row[0]

        self._int_file_id = int_id
        return int_id

    def _finalize_file(self) -> None:
        """Store the trailing partial chunk and record length and MD5."""
        file_int_id = self._get_file_id()

        # Write the final chunk (which may be smaller than chunk_size_bytes)
        if self._buffer:
            self._insert_chunk(file_int_id, bytes(self._buffer))
            self._buffer.clear()

        md5_hash = None
        if self._md5_hasher is not None:
            md5_hash = self._md5_hasher.hexdigest()

        self._db.execute(
            f"""
            UPDATE {self._files_collection}
            SET length = ?, md5 = ?
            WHERE id = ?
            """,
            (self._position, md5_hash, file_int_id),
        )

    def close(self) -> None:
        """
        Close the GridIn stream and finalize the file storage.

        Flushes any remaining buffered data as a final (possibly short)
        chunk and always updates the file document with the total length and
        the MD5 hash (``None`` when MD5 is disabled), including when the
        data ended exactly on a chunk boundary. An empty file still gets a
        file document. Finally the write-concern sync hook is invoked.
        Closing an already closed stream is a no-op.
        """
        if self._closed:
            return
        self._finalize_file()
        # Force sync if write concern requires it
        self._force_sync_if_needed()
        self._closed = True

    def __enter__(self) -> GridIn:
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit; closes the stream even if the body raised."""
        self.close()
# --- Readable side of the GridFS file API ---
class GridOut:
    """
    A file-like object for reading data from GridFS.

    This class provides a readable interface for retrieving files from GridFS.
    The file's row is loaded from the bucket's files table when the object
    is constructed; chunk rows are fetched lazily, one at a time, as
    :meth:`read` advances through the file.
    """

    def __init__(
        self,
        db: sqlite3.Connection,
        bucket_name: str,
        file_id: ObjectId | int,
    ):
        """
        Initialize a new GridOut instance.

        Args:
            db: SQLite database connection
            bucket_name: The bucket name for the GridFS files
            file_id: The ID of the file to read (ObjectId or integer)

        Raises:
            RuntimeError: If ``file_id`` is an ObjectId with no matching row
            NoFile: If no file row exists for the resolved integer ID
        """
        self._db = db
        self._bucket_name = bucket_name
        # quote_table_name guards against SQL injection through bucket_name
        self._files_collection = quote_table_name(f"{bucket_name}_files")
        self._chunks_collection = quote_table_name(f"{bucket_name}_chunks")

        # Convert file_id to the integer primary key used internally
        if isinstance(file_id, ObjectId):
            id_row = self._db.execute(
                f"SELECT id FROM {self._files_collection} WHERE _id = ?",
                (str(file_id),),
            ).fetchone()
            if id_row is None:
                raise RuntimeError(
                    f"File with ObjectId {file_id} not found in database"
                )
            self._int_file_id = id_row[0]
        else:
            self._int_file_id = file_id

        # Check if the newer columns exist to stay compatible with old schemas
        has_content_type = False
        has_aliases = False
        try:
            has_content_type = column_exists(
                self._db, self._files_collection, "content_type"
            )
            has_aliases = column_exists(
                self._db, self._files_collection, "aliases"
            )
        except (AttributeError, TypeError) as e:
            # Handle mocked databases in tests - assume old schema
            logger.debug("e=%r", e)

        # Build the SELECT dynamically from the available columns
        select_columns = [
            "filename",
            "length",
            "chunkSize",
            "uploadDate",
            "md5",
            "json(metadata)",
            "_id",
        ]
        if has_content_type:
            select_columns.append("content_type")
        if has_aliases:
            select_columns.append("json(aliases)")

        query = f"""
            SELECT {", ".join(select_columns)}
            FROM {self._files_collection}
            WHERE id = ?
        """
        row = self._db.execute(query, (self._int_file_id,)).fetchone()
        if row is None:
            raise NoFile(f"File with id {file_id} not found")

        # The first seven columns are always present, in SELECT order
        self._filename = row[0]
        self._length = row[1]
        self._chunk_size = row[2]
        self._upload_date = row[3]
        self._md5 = row[4]
        metadata_str = row[5]
        self._stored_oid = row[6]  # The stored _id value (ObjectId hex string)

        # Optional columns follow, in the order they were appended above
        next_idx = 7
        self._content_type = None
        if has_content_type:
            self._content_type = row[next_idx]
            next_idx += 1
        aliases_str = row[next_idx] if has_aliases else None

        self._actual_id: ObjectId | int | str = self._resolve_public_id(
            self._stored_oid, file_id
        )
        self._metadata = deserialize_metadata(metadata_str)
        self._aliases = deserialize_aliases(aliases_str)

        self._position = 0
        self._current_chunk_data = b""
        self._current_chunk_index = -1
        self._closed = False

    @staticmethod
    def _resolve_public_id(stored_id, fallback):
        """
        Convert the stored ``_id`` value into the ID exposed to callers.

        Args:
            stored_id: The raw ``_id`` column value
            fallback: The ID to use when nothing is stored (the ID the
                caller asked for)

        Returns:
            ``fallback`` when ``stored_id`` is None; non-string values
            unchanged; for strings, an ObjectId when the string is 24
            characters long and accepted by ``ObjectId``, otherwise an int
            when the string parses as one, otherwise the string itself.
        """
        if stored_id is None:
            return fallback
        if not isinstance(stored_id, str):
            return stored_id

        if len(stored_id) == 24:
            try:
                return ObjectId(stored_id)
            except ValueError as e:
                logger.debug(
                    "Failed to convert GridFS ID '%s' to ObjectId: %s",
                    stored_id,
                    e,
                )
            try:
                return int(stored_id)
            except ValueError as e:
                logger.debug(
                    "Failed to convert GridFS ID '%s' to integer: %s",
                    stored_id,
                    e,
                )
                return stored_id

        try:
            return int(stored_id)
        except ValueError as e:
            logger.debug(
                "Failed to convert non-24 char GridFS ID '%s' to integer: %s",
                stored_id,
                e,
            )
            return stored_id

    @property
    def _id(self):
        """Get the file's actual ID, which may be an ObjectId or integer."""
        return self._actual_id

    @property
    def _file_id(self):
        """Get the file's actual ID that represents what the user expects as the ID.

        For compatibility with the original API, this returns the actual ID
        (ObjectId or int).
        """
        return self._actual_id

    def _deserialize_aliases(self, aliases_str: str | None) -> list[str] | None:
        """
        Deserialize aliases from JSON string.

        Args:
            aliases_str: JSON string representation of aliases

        Returns:
            List of alias strings or None
        """
        return deserialize_aliases(aliases_str)

    def read(self, size: int | None = -1) -> bytes:
        """
        Read data from the GridOut stream.

        Args:
            size: The number of bytes to read; ``None`` or any negative
                value reads all remaining data

        Returns:
            The data read from the stream (empty at end of file)

        Raises:
            ValueError: If the stream is closed
            NoFile: If a required chunk is missing or holds fewer bytes than
                the recorded file length requires
        """
        if self._closed:
            raise ValueError("I/O operation on closed file")

        # A file whose length was never recorded is treated as empty.
        remaining = (self._length or 0) - self._position
        if size is None or size < 0 or size > remaining:
            size = remaining
        if size <= 0:
            return b""

        result = bytearray()
        while len(result) < size:
            # _load_chunk picks the chunk from the current position, so no
            # manual chunk-index bookkeeping is needed at chunk boundaries.
            self._load_chunk()
            offset = self._position % self._chunk_size
            wanted = min(size - len(result), self._chunk_size - offset)
            piece = self._current_chunk_data[offset : offset + wanted]
            if not piece:
                # Without this guard a short chunk would stall the position
                # and the loop would never terminate.
                raise NoFile(
                    f"Chunk {self._current_chunk_index} for file id "
                    f"{self._int_file_id} is shorter than expected"
                )
            result += piece
            self._position += len(piece)
        return bytes(result)

    def _load_chunk(self) -> None:
        """
        Load the chunk containing the current position.

        Raises:
            NoFile: If the chunk row does not exist
        """
        chunk_index = self._position // self._chunk_size
        # If we already have the right chunk, we're done
        if chunk_index == self._current_chunk_index:
            return

        row = self._db.execute(
            f"""
            SELECT data FROM {self._chunks_collection}
            WHERE files_id = ? AND n = ?
            """,
            (self._int_file_id, chunk_index),
        ).fetchone()
        if row is None:
            raise NoFile(
                f"Chunk {chunk_index} for file id {self._int_file_id} not found"
            )
        self._current_chunk_data = row[0]
        self._current_chunk_index = chunk_index

    @property
    def filename(self) -> str:
        """Get the filename."""
        return self._filename

    @property
    def length(self) -> int:
        """Get the length of the file in bytes."""
        return self._length

    @property
    def chunk_size(self) -> int:
        """Get the chunk size in bytes."""
        return self._chunk_size

    @property
    def upload_date(self) -> str:
        """Get the upload date."""
        return self._upload_date

    @property
    def md5(self) -> str | None:
        """Get the stored MD5 hash of the file (as read from the ``md5`` column)."""
        return self._md5

    @property
    def metadata(self) -> dict[str, Any] | None:
        """Get the metadata of the file."""
        return self._metadata

    @property
    def content_type(self) -> str | None:
        """Get the content type (MIME type) of the file."""
        return self._content_type

    @property
    def aliases(self) -> list[str] | None:
        """Get the aliases (alternative names) for the file."""
        return self._aliases

    def close(self) -> None:
        """Close the GridOut stream."""
        self._closed = True

    def __enter__(self) -> GridOut:
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit."""
        self.close()
# --- Query cursor over GridFS files ---
class GridOutCursor:
"""
A cursor for iterating over GridFS files.
This class provides an iterator interface for retrieving file documents from GridFS.
"""
[docs]
def __init__(
self,
db: sqlite3.Connection,
bucket_name: str,
filter: dict[str, Any],
):
"""
Initialize a new GridOutCursor instance.
Args:
db: SQLite database connection
bucket_name: The bucket name for the GridFS files
filter: The filter to apply when searching for files
"""
self._db = db
self._bucket_name = bucket_name
self._filter = filter
self._files_collection = quote_table_name(f"{bucket_name}_files")
# Build query based on filter
where_clause = ""
params = []
if filter:
where_conditions = []
for key, value in filter.items():
match key:
case "_id":
# Handle ObjectId hex strings and other ID formats
if isinstance(value, ObjectId):
where_conditions.append("_id = ?")
params.append(str(value))
elif isinstance(value, str) and len(value) == 24:
# Check if it's a valid ObjectId hex string
try:
ObjectId(value)
where_conditions.append("_id = ?")
params.append(value)
except ValueError as e:
# Not a valid ObjectId, treat as regular string
logger.debug(
f"ID '{value}' is not a valid ObjectId: {e}"
)
where_conditions.append("_id = ?")
params.append(value)
else:
# Handle other types
where_conditions.append("_id = ?")
params.append(value)
case "id":
# For 'id' queries, we look in the integer id column
where_conditions.append("id = ?")
params.append(value)
case "filename":
if isinstance(value, dict):
# Handle operators like {"$regex": "pattern"}, {"$ne": "name"}, etc.
for op, val in value.items():
match op:
case "$regex":
where_conditions.append(
"filename LIKE ?"
)
params.append(f"%{val}%")
case "$ne":
where_conditions.append("filename != ?")
params.append(val)
case "$eq":
where_conditions.append("filename = ?")
params.append(val)
case _:
# For unsupported operators, fall back to exact match
where_conditions.append("filename = ?")
params.append(str(value))
else:
# Direct value comparison
where_conditions.append("filename = ?")
params.append(value)
case "length":
if isinstance(value, dict):
# Handle operators like {"$gt": 1000}, {"$lt": 5000}, etc.
for op, val in value.items():
match op:
case "$gt":
where_conditions.append("length > ?")
params.append(val)
case "$gte":
where_conditions.append("length >= ?")
params.append(val)
case "$lt":
where_conditions.append("length < ?")
params.append(val)
case "$lte":
where_conditions.append("length <= ?")
params.append(val)
case "$eq":
where_conditions.append("length = ?")
params.append(val)
case "$ne":
where_conditions.append("length != ?")
params.append(val)
else:
# Direct value comparison
where_conditions.append("length = ?")
params.append(value)
case "chunkSize":
if isinstance(value, dict):
# Handle operators like {"$gt": 1000}, {"$lt": 5000}, etc.
for op, val in value.items():
match op:
case "$gt":
where_conditions.append("chunkSize > ?")
params.append(val)
case "$gte":
where_conditions.append(
"chunkSize >= ?"
)
params.append(val)
case "$lt":
where_conditions.append("chunkSize < ?")
params.append(val)
case "$lte":
where_conditions.append(
"chunkSize <= ?"
)
params.append(val)
case "$eq":
where_conditions.append("chunkSize = ?")
params.append(val)
case "$ne":
where_conditions.append(
"chunkSize != ?"
)
params.append(val)
else:
# Direct value comparison
where_conditions.append("chunkSize = ?")
params.append(value)
case "uploadDate":
if isinstance(value, dict):
# Handle operators like {"$gt": date}, {"$lt": date}, etc.
for op, val in value.items():
match op:
case "$gt":
where_conditions.append(
"uploadDate > ?"
)
params.append(val)
case "$gte":
where_conditions.append(
"uploadDate >= ?"
)
params.append(val)
case "$lt":
where_conditions.append(
"uploadDate < ?"
)
params.append(val)
case "$lte":
where_conditions.append(
"uploadDate <= ?"
)
params.append(val)
case "$eq":
where_conditions.append(
"uploadDate = ?"
)
params.append(val)
case "$ne":
where_conditions.append(
"uploadDate != ?"
)
params.append(val)
else:
# Direct value comparison
where_conditions.append("uploadDate = ?")
params.append(value)
case "md5":
if isinstance(value, dict) and "$ne" in value:
where_conditions.append("md5 != ?")
params.append(value["$ne"])
else:
# Direct value comparison
where_conditions.append("md5 = ?")
params.append(value)
# For metadata, we do a simple string match (basic implementation)
# In a full implementation, we'd parse the JSON, but for now we'll do substring matching
case "metadata":
if isinstance(value, dict):
# Handle metadata queries with operators
for op, val in value.items():
match op:
case "$regex":
where_conditions.append(
"metadata LIKE ?"
)
params.append(f"%{val}%")
case "$ne":
where_conditions.append("metadata != ?")
params.append(
str(val)
if not isinstance(val, str)
else val
)
case _:
# For other operators, convert to string and match
where_conditions.append(
"metadata LIKE ?"
)
params.append(f"%{op}%{val}%")
else:
# Direct metadata string matching
where_conditions.append("metadata LIKE ?")
params.append(f"%{value}%")
case "aliases":
# Check if the value is in the aliases JSON array
where_conditions.append(
"EXISTS (SELECT 1 FROM json_each(aliases) WHERE value = ?)"
)
params.append(value)
case "content_type":
if isinstance(value, dict) and "$ne" in value:
where_conditions.append("content_type != ?")
params.append(value["$ne"])
else:
# Direct value comparison
where_conditions.append("content_type = ?")
params.append(value)
if where_conditions:
where_clause = "WHERE " + " AND ".join(where_conditions)
# Execute query to get integer file IDs (changed from _id to id for internal use)
query = f"SELECT id FROM {self._files_collection} {where_clause}"
cursor = self._db.execute(query, params)
self._file_ids = [row[0] for row in cursor.fetchall()]
self._index = 0
def __iter__(self) -> GridOutCursor:
"""Return the iterator object."""
return self
def __next__(self) -> GridOut:
"""Get the next GridOut object."""
if self._index >= len(self._file_ids):
raise StopIteration
file_id = self._file_ids[self._index]
self._index += 1
return GridOut(self._db, self._bucket_name, file_id)