from __future__ import annotations
import json
import logging
import re
from contextlib import contextmanager
from typing import Any, Iterator, Literal
from ._sqlite import sqlite3
from .client_session import ClientSession
from .collection import Collection
from .collection.aggregation_cursor import AggregationCursor
from .exceptions import CollectionInvalid
from .migration import migrate_autovacuum, needs_migration, should_migrate
from .options import AutoVacuumMode, JournalMode, WriteConcern
from .sql_utils import quote_table_name
logger = logging.getLogger(__name__)
[docs]
class Connection:
"""
Represents a connection to an NeoSQLite database.
Provides methods for managing collections, executing SQL queries,
and handling database lifecycle events. Supports PyMongo-like API.
"""
DEFAULT_TRANSLATION_CACHE_SIZE = 100
[docs]
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Initialize a new database connection.
Args:
*args: Positional arguments passed to sqlite3.connect().
**kwargs: Keyword arguments passed to sqlite3.connect().
Special kwargs:
- tokenizers: List of tuples (name, path) for FTS5 tokenizers to load
- debug: Boolean flag to enable debug printing
- name: Optional name for the database (for PyMongo API compatibility)
- _is_clone: Internal flag for cloning (not for public use)
- journal_mode: Optional journal mode (default: "WAL")
- auto_vacuum: Auto vacuum mode (default: INCREMENTAL).
Can be 0/NONE, 1/FULL, 2/INCREMENTAL, or "NONE"/"FULL"/"INCREMENTAL".
If database has different auto_vacuum setting, migration may be triggered.
- translation_cache: SQL translation cache size (default: 100, 0 to disable)
"""
self._collections: dict[str, Collection] = {}
self._tokenizers: list[tuple[str, str]] = kwargs.pop("tokenizers", [])
self.debug: bool = kwargs.pop("debug", False)
self._is_clone = kwargs.pop("_is_clone", False)
self._codec_options = kwargs.pop("codec_options", None)
self._read_preference = kwargs.pop("read_preference", None)
self._write_concern = kwargs.pop("write_concern", None)
self._read_concern = kwargs.pop("read_concern", None)
self.journal_mode = JournalMode.validate(
kwargs.pop("journal_mode", "WAL")
)
self.auto_vacuum = AutoVacuumMode.validate(
kwargs.pop("auto_vacuum", AutoVacuumMode.INCREMENTAL)
)
self._translation_cache_size: int | None = kwargs.pop(
"translation_cache", self.DEFAULT_TRANSLATION_CACHE_SIZE
)
self.name: str = kwargs.pop("name", None)
self._db_path = args[0] if args else ":memory:"
if self.name is None:
self.name = (
self._db_path if self._db_path != ":memory:" else "memory"
)
self._closed = False
if not self._is_clone:
self.connect(*args, **kwargs)
@property
def db_path(self) -> str:
"""
Get the path to the database file.
Returns:
str: The database file path.
"""
return self._db_path
[docs]
def connect(self, *args: Any, **kwargs: Any) -> None:
"""
Establish a connection to the SQLite database.
Configures the database connection with the provided arguments, sets up
SQLite-specific settings like isolation level and journal mode, and loads
custom FTS5 tokenizers if specified. This method does not return a value.
Args:
*args: Positional arguments passed to sqlite3.connect().
**kwargs: Keyword arguments passed to sqlite3.connect().
"""
self.db = sqlite3.connect(*args, **kwargs)
# Safely access attributes to allow calling connect() on partially initialized objects
auto_vacuum = getattr(self, "auto_vacuum", AutoVacuumMode.INCREMENTAL)
journal_mode = getattr(self, "journal_mode", JournalMode.WAL)
self.db.execute(f"PRAGMA auto_vacuum={auto_vacuum}")
self.db.isolation_level = None
self.db.execute(f"PRAGMA journal_mode={journal_mode}")
self._register_custom_functions()
# Set synchronous mode to NORMAL for WAL mode by default if not specified by write concern
# This provides a good balance of performance and safety in WAL mode
if journal_mode == "WAL":
self.db.execute("PRAGMA synchronous = NORMAL")
write_concern = getattr(self, "_write_concern", None)
if write_concern:
self._apply_write_concern(write_concern)
if self._tokenizers:
self.db.enable_load_extension(True)
for name, path in self._tokenizers:
# Use parameterized query to prevent SQL injection in path
self.db.execute("SELECT load_extension(?)", (path,))
self._check_and_migrate_autovacuum(*args, **kwargs)
[docs]
def _register_custom_functions(self) -> None:
"""Register custom SQLite functions, including regex operators."""
def _regexp(pattern, text):
if text is None:
return 0
return 1 if re.search(str(pattern), str(text)) is not None else 0
def _regexp_find(pattern, text):
if text is None or pattern is None:
return None
match = re.search(str(pattern), str(text))
if match:
return json.dumps(
{
"match": match.group(0),
"idx": match.start(),
"captures": list(match.groups()),
}
)
return None
def _regexp_find_all(pattern, text):
if text is None or pattern is None:
return "[]"
matches = []
for match in re.finditer(str(pattern), str(text)):
matches.append(
{
"match": match.group(0),
"idx": match.start(),
"captures": list(match.groups()),
}
)
return json.dumps(matches)
def _regexp_replace(text, pattern, replacement, count=0):
if text is None:
return None
if pattern is None or replacement is None:
return text
# count=0 in re.sub means replace all
return re.sub(
str(pattern), str(replacement), str(text), count=int(count)
)
self.db.create_function("REGEXP", 2, _regexp)
self.db.create_function("REGEXP_FIND", 2, _regexp_find)
self.db.create_function("REGEXP_FIND_ALL", 2, _regexp_find_all)
self.db.create_function("REGEXP_REPLACE", 4, _regexp_replace)
[docs]
def _check_and_migrate_autovacuum(self, *args: Any, **kwargs: Any) -> None:
"""
Check auto_vacuum setting and migrate if needed.
Called after initial connection is established.
"""
if self._is_clone or self._db_path == ":memory:":
return
if not needs_migration(self.db, self.auto_vacuum):
return
if not should_migrate():
return
self._migrate_to_autovacuum(*args, **kwargs)
[docs]
def _migrate_to_autovacuum(self, *args: Any, **kwargs: Any) -> None:
"""
Migrate database to a new auto_vacuum mode.
Delegates to migration.migrate_autovacuum() and then reconnects.
"""
old_path = self._db_path
try:
self.db.execute("PRAGMA wal_checkpoint(FULL)")
except sqlite3.OperationalError as e:
logger.warning(
f"WAL checkpoint skipped (may be no WAL or already checkpointing): {e}"
)
except Exception as e:
logger.error(
f"Unexpected error during WAL checkpoint: {e}", exc_info=True
)
try:
self.db.execute("COMMIT")
except sqlite3.OperationalError as e:
logger.debug(f"Commit skipped (no active transaction): {e}")
except Exception as e:
logger.error(f"Unexpected error during commit: {e}", exc_info=True)
try:
self.db.close()
except Exception as e:
logger.error(
f"Unexpected error closing database: {e}", exc_info=True
)
migrate_autovacuum(
db_path=old_path,
target_autovacuum=self.auto_vacuum,
target_journal_mode=self.journal_mode,
extra_conn_kwargs=kwargs,
)
self.db = sqlite3.connect(*args, **kwargs)
# Safely access attributes to allow calling connect() on partially initialized objects
auto_vacuum = getattr(self, "auto_vacuum", AutoVacuumMode.INCREMENTAL)
journal_mode = getattr(self, "journal_mode", JournalMode.WAL)
self.db.execute(f"PRAGMA auto_vacuum={auto_vacuum}")
self.db.isolation_level = None
self.db.execute(f"PRAGMA journal_mode={journal_mode}")
self._register_custom_functions()
# Set synchronous mode to NORMAL for WAL mode by default if not specified by write concern
# This provides a good balance of performance and safety in WAL mode
if journal_mode == "WAL":
self.db.execute("PRAGMA synchronous = NORMAL")
write_concern = getattr(self, "_write_concern", None)
if write_concern:
self._apply_write_concern(write_concern)
if self._tokenizers:
self.db.enable_load_extension(True)
for name, path in self._tokenizers:
# Use parameterized query to prevent SQL injection in path
self.db.execute("SELECT load_extension(?)", (path,))
[docs]
def cleanup(self) -> None:
"""Clean up all collection resources associated with this connection."""
if hasattr(self, "_collections"):
for collection in self._collections.values():
collection.cleanup()
[docs]
def close(self) -> None:
"""
Close the database connection.
Commits any pending transaction and properly closes the underlying SQLite
connection. This method ensures resources are released and the connection
is no longer usable after being called.
"""
if getattr(self, "_is_clone", False) or getattr(self, "_closed", False):
return
# Clean up collections before closing the database
self.cleanup()
if self.db is not None:
try:
if self.db.in_transaction:
logger.warning(
"Closing connection with pending transaction; committing automatically"
)
self.db.commit()
except (sqlite3.ProgrammingError, sqlite3.OperationalError) as e:
logger.debug(f"{e=}")
pass
try:
self.db.close()
except (sqlite3.ProgrammingError, sqlite3.OperationalError) as e:
logger.debug(f"{e=}")
pass
self._closed = True
@property
def client(self) -> Connection:
"""
Get the MongoClient instance (returns self for NeoSQLite).
Returns:
Connection: The connection instance itself.
"""
return self
@property
def codec_options(self) -> Any:
"""
Get the codec options for this connection.
Returns:
Any: The codec options.
"""
return self._codec_options
@property
def read_preference(self) -> Any:
"""
Get the read preference for this connection.
Returns:
Any: The read preference.
"""
return self._read_preference
@property
def write_concern(self) -> Any:
"""
Get the write concern for this connection.
Returns:
Any: The write concern.
"""
return self._write_concern
@property
def read_concern(self) -> Any:
"""
Get the read concern for this connection.
Returns:
Any: The read concern.
"""
return self._read_concern
def __getitem__(self, name: str) -> Collection:
"""
Access a collection by name.
Allows retrieving or creating a collection associated with this connection
using dictionary-style access. If the collection does not exist, it will be
created automatically.
Args:
name (str): The name of the collection to access.
Returns:
Collection: The collection instance associated with the given name.
"""
if name not in self._collections:
self._collections[name] = Collection(self.db, name, database=self)
return self._collections[name]
def __getattr__(self, name: str) -> Any:
"""
Proxy attribute access to collection lookup.
When an attribute is not found in the instance's dictionary, this method
attempts to retrieve it using the dictionary-style collection access (via
`__getitem__`). This enables both attribute and dictionary access to collections.
Returns:
Any: The value retrieved from the collection, or the attribute if it exists.
"""
if name in self.__dict__:
return self.__dict__[name]
return self[name]
[docs]
def start_session(
self,
causal_consistency: bool | None = None,
default_transaction_options: dict[str, Any] | None = None,
**kwargs: Any,
) -> ClientSession:
"""
Start a new client session for transactions.
This method provides PyMongo-compatible session management by wrapping
SQLite's native ACID transactions.
Args:
causal_consistency (bool, optional): Whether to enable causal consistency.
Ignored in NeoSQLite (stored for API compatibility).
default_transaction_options (dict, optional): Default transaction options.
**kwargs: Additional session arguments.
Returns:
ClientSession: A new ClientSession instance.
"""
options = kwargs.copy()
if causal_consistency is not None:
options["causal_consistency"] = causal_consistency
if default_transaction_options:
options["default_transaction_options"] = default_transaction_options
return ClientSession(self, options=options)
def __enter__(self) -> Connection:
"""
Allow the connection to be used in a context manager.
Returns:
Connection: The connection instance itself, enabling the 'with' statement
to manage the connection's lifecycle.
"""
return self
def __exit__(
self, exc_type: Any, exc_val: Any, exc_traceback: Any
) -> Literal[False]:
"""
Ensure the connection is properly closed when exiting a context manager.
Returns:
Literal[False]: Indicates that the method does not handle exceptions
and the connection should be closed.
"""
self.close()
return False
[docs]
def drop_collection(self, name: str) -> None:
"""
Drop a collection (table) from the database.
Args:
name (str): The name of the collection (table) to drop. If the table
does not exist, the operation is silently ignored due to
the use of `IF EXISTS` in the SQL command.
"""
self.db.execute(f"DROP TABLE IF EXISTS {quote_table_name(name)}")
[docs]
def create_collection(self, name: str, **kwargs) -> Collection:
"""
Create a new collection with specific options.
Args:
name (str): The name of the collection to create.
**kwargs: Additional options for collection creation.
Returns:
Collection: The newly created collection.
Raises:
CollectionInvalid: If a collection with the given name already exists.
"""
if name in self._collections:
raise CollectionInvalid(f"Collection {name} already exists")
collection = Collection(
self.db, name, create=True, database=self, **kwargs
)
self._collections[name] = collection
return collection
[docs]
def get_collection(self, name: str, **kwargs) -> Collection:
"""
Get a collection by name.
Args:
name (str): The name of the collection to get.
**kwargs: Additional options for collection access.
Returns:
Collection: The collection instance.
"""
if name not in self._collections:
self._collections[name] = Collection(
self.db, name, create=False, database=self, **kwargs
)
return self._collections[name]
[docs]
def rename_collection(self, old_name: str, new_name: str) -> None:
"""
Rename a collection.
Args:
old_name (str): The current name of the collection.
new_name (str): The new name for the collection.
Raises:
CollectionInvalid: If the old collection doesn't exist or new name already exists.
"""
if old_name not in self._collections:
raise CollectionInvalid(f"Collection {old_name} does not exist")
# Check if new name already exists
cursor = self.db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
(new_name,),
)
if cursor.fetchone():
raise CollectionInvalid(f"Collection {new_name} already exists")
# Rename the collection
self._collections[old_name].rename(new_name)
self._collections[new_name] = self._collections.pop(old_name)
[docs]
def list_collection_names(self) -> list[str]:
"""
List all collection names in the database.
Returns:
list[str]: A list of all collection names in the database,
excluding internal SQLite tables (sqlite_sequence, sqlite_stat*, etc.)
"""
cursor = self.db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)
return [row[0] for row in cursor.fetchall()]
[docs]
def list_collections(self) -> list[dict[str, Any]]:
"""
Get detailed information about collections in the database.
Returns:
list[dict[str, Any]]: A list of dictionaries containing collection information.
Each dictionary has 'name' and 'options' keys.
"""
cursor = self.db.execute(
"SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)
return [
{"name": row[0], "options": row[1]} for row in cursor.fetchall()
]
[docs]
def command(
self, command: str | dict[str, Any], value: Any = None, **kwargs: Any
) -> dict[str, Any]:
"""
Issue a database command and return the response.
This method provides PyMongo-compatible command execution for SQLite.
It supports various commands including PRAGMA commands, introspection,
and utility commands.
Args:
command: The command to execute. Can be:
- A string (e.g., "table_info", "integrity_check")
- A dict with command name as key (e.g., {"ping": 1})
value: Optional command value (for string commands)
**kwargs: Additional command arguments
Returns:
dict[str, Any]: Command response
Supported Commands:
- "ping" or {"ping": 1} - Returns {"ok": 1}
- "serverStatus" - Returns SQLite version info
- "listCollections" - Returns collection list
- "table_info" - Returns table schema (PRAGMA table_info)
- "integrity_check" - Returns integrity check results
- "validate" - Returns integrity check for a collection
- "reIndex" - Rebuilds indexes for a collection
- "foreign_key_check" - Returns foreign key check results
- "index_list" - Returns index list for a table
- "vacuum" - Runs full VACUUM command
- "compact" - MongoDB-compatible compact (see below)
- "wal_checkpoint" - WAL checkpoint control (NeoSQLite extension)
- "cache_size" - Get/set cache size (NeoSQLite extension)
- "busy_timeout" - Get/set busy timeout (NeoSQLite extension)
- "query_only" - Get/set read-only mode (NeoSQLite extension)
- "analyze" - Runs ANALYZE command
compact Command:
MongoDB-compatible compact implementation with NeoSQLite extensions.
Args:
collection: Collection name (ignored, operates on entire database)
dryRun: If true, returns estimate without actually compacting
freeSpaceTargetMB: Target in MB (default: 20). Only compacts if
free space >= threshold. Uses incremental vacuum in batches.
Set to 0 for full VACUUM instead of incremental.
comment: Optional comment (ignored, for MongoDB compatibility)
Returns:
dryRun: {"estimatedBytesFreed": <bytes>, "ok": 1}
compact: {"bytesFreed": <bytes>, "ok": 1}
Examples:
>>> db.command("compact", "collection") # 20MB threshold + incremental
>>> db.command("compact", "collection", dryRun=True) # Estimate only
>>> db.command("compact", "collection", freeSpaceTargetMB=10) # 10MB threshold + incremental
>>> db.command("compact", "collection", freeSpaceTargetMB=0) # Full vacuum
Example:
>>> db = Connection("test.db")
>>> result = db.command("ping")
>>> print(result)
{'ok': 1.0}
"""
# If command is a string and value is provided, it's equivalent to {command: value}
if isinstance(command, str) and value is not None:
command = {command: value}
# Handle string commands
if isinstance(command, str):
cmd_name = command.lower()
elif isinstance(command, dict):
# Handle dict commands (PyMongo style)
cmd_name = next(iter(command.keys())).lower()
else:
raise TypeError("command must be a string or dict")
try:
match cmd_name:
case "ping":
return {"ok": 1}
case "serverstatus":
import sqlite3
return {
"ok": 1,
"version": sqlite3.sqlite_version,
"python_sqlite_version": getattr(
sqlite3, "version", "unknown"
),
"process": "neosqlite",
"pid": 1,
}
case "listcollections":
collections = self.list_collection_names()
return {
"ok": 1,
"collections": [{"name": name} for name in collections],
}
case "table_info":
table_name = kwargs.get("table")
if not table_name and isinstance(command, dict):
table_name = command.get("table_info")
if not table_name:
raise ValueError(
"table_info requires 'table' parameter"
)
cursor = self.db.execute(
f"PRAGMA table_info({quote_table_name(table_name)})"
)
columns = [
{
"cid": row[0],
"name": row[1],
"type": row[2],
"notnull": bool(row[3]),
"default": row[4],
"pk": bool(row[5]),
}
for row in cursor.fetchall()
]
return {"ok": 1, "columns": columns}
case "integrity_check":
cursor = self.db.execute("PRAGMA integrity_check")
result = cursor.fetchall()
return {"ok": 1, "result": [row[0] for row in result]}
case "validate":
collection_name = kwargs.get("validate")
if not collection_name and isinstance(command, dict):
collection_name = command.get("validate")
if collection_name:
cursor = self.db.execute(
f"PRAGMA integrity_check({quote_table_name(collection_name)})"
)
else:
cursor = self.db.execute("PRAGMA integrity_check")
result = cursor.fetchall()
errors = [row[0] for row in result if row[0] != "ok"]
return {
"ok": 1 if not errors else 0,
"result": [row[0] for row in result],
"errors": errors,
"valid": len(errors) == 0,
}
case "foreign_key_check":
table_name = kwargs.get("table")
if table_name:
cursor = self.db.execute(
f"PRAGMA foreign_key_check({quote_table_name(table_name)})"
)
else:
cursor = self.db.execute("PRAGMA foreign_key_check")
result = cursor.fetchall()
return {
"ok": 1,
"violations": [
{
"table": row[0],
"rowid": row[1],
"parent": row[2],
"fkid": row[3],
}
for row in result
],
}
case "index_list":
table_name = kwargs.get("table")
if not table_name and isinstance(command, dict):
table_name = command.get("index_list")
if not table_name:
raise ValueError(
"index_list requires 'table' parameter"
)
cursor = self.db.execute(
f"PRAGMA index_list({quote_table_name(table_name)})"
)
indexes = [
{
"seq": row[0],
"name": row[1],
"unique": bool(row[2]),
"origin": row[3] if len(row) > 3 else "c",
"partial": bool(row[4]) if len(row) > 4 else False,
}
for row in cursor.fetchall()
]
return {"ok": 1, "indexes": indexes}
case "vacuum":
self.db.execute("VACUUM")
return {"ok": 1, "message": "VACUUM completed"}
case "compact":
_ = kwargs.get("compact")
_ = kwargs.get("comment")
dry_run = kwargs.get("dryRun", False)
free_space_target_mb = kwargs.get("freeSpaceTargetMB")
free_pages = self.db.execute(
"PRAGMA freelist_count"
).fetchone()[0]
page_size = self.db.execute("PRAGMA page_size").fetchone()[
0
]
free_bytes = free_pages * page_size
if dry_run:
return {
"estimatedBytesFreed": free_bytes,
"ok": 1,
}
if free_space_target_mb == 0:
page_count_before = self.db.execute(
"PRAGMA page_count"
).fetchone()[0]
self.db.execute("VACUUM")
page_count_after = self.db.execute(
"PRAGMA page_count"
).fetchone()[0]
bytes_freed = (
page_count_before - page_count_after
) * page_size
return {"bytesFreed": bytes_freed, "ok": 1}
if free_space_target_mb is None:
free_space_target_mb = 20
target_bytes = free_space_target_mb * 1024 * 1024
if free_bytes < target_bytes:
return {"bytesFreed": 0, "ok": 1}
page_count_before = self.db.execute(
"PRAGMA page_count"
).fetchone()[0]
target_pages = max(1, target_bytes // page_size)
while free_pages > 0:
reclaim = min(target_pages, free_pages)
self.db.execute(f"PRAGMA incremental_vacuum({reclaim})")
free_pages = self.db.execute(
"PRAGMA freelist_count"
).fetchone()[0]
page_count_after = self.db.execute(
"PRAGMA page_count"
).fetchone()[0]
bytes_freed = (
page_count_before - page_count_after
) * page_size
return {"bytesFreed": bytes_freed, "ok": 1}
case "wal_checkpoint":
mode = str(kwargs.get("mode", "PASSIVE")).upper()
if mode not in ("PASSIVE", "FULL", "RESTART", "TRUNCATE"):
return {
"ok": 0,
"errmsg": f"Invalid wal_checkpoint mode: {mode}",
}
result = self.db.execute(
f"PRAGMA wal_checkpoint({mode})"
).fetchone()
return {
"ok": 1,
"mode": mode,
"busy": result[0],
"log": result[1],
"checkpointed": result[2],
}
case "cache_size":
pages = kwargs.get("pages")
if pages is not None:
self.db.execute(f"PRAGMA cache_size = {pages}")
return {
"ok": 1,
"message": f"cache_size set to {pages}",
}
else:
current = self.db.execute(
"PRAGMA cache_size"
).fetchone()[0]
return {"ok": 1, "cache_size": current}
case "busy_timeout":
ms = kwargs.get("milliseconds")
if ms is not None:
self.db.execute(f"PRAGMA busy_timeout = {ms}")
return {
"ok": 1,
"message": f"busy_timeout set to {ms}ms",
}
else:
current = self.db.execute(
"PRAGMA busy_timeout"
).fetchone()[0]
return {"ok": 1, "busy_timeout": current}
case "analyze":
self.db.execute("ANALYZE")
return {"ok": 1, "message": "ANALYZE completed"}
case "reindex":
collection_name = kwargs.get("reIndex")
if not collection_name and isinstance(command, dict):
collection_name = command.get("reindex")
if collection_name:
self.db.execute(
f"REINDEX {quote_table_name(collection_name)}"
)
else:
self.db.execute("REINDEX")
return {"ok": 1, "message": "REINDEX completed"}
case "collstats":
collection_name = kwargs.get("collection")
if not collection_name and isinstance(command, dict):
collection_name = command.get("collstats")
if not collection_name:
raise ValueError(
"collstats requires 'collection' parameter"
)
quoted_table = quote_table_name(collection_name)
cursor = self.db.execute(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
(collection_name,),
)
if not cursor.fetchone():
return {
"ok": 1,
"ns": collection_name,
"count": 0,
"size": 0,
"avgObjSize": 0,
"storageSize": 0,
"totalIndexSize": 0,
"indexSizes": {},
}
cursor = self.db.execute(
f"SELECT COUNT(*) FROM {quoted_table}"
)
count = cursor.fetchone()[0] or 0
size = 0
try:
size_cursor = self.db.execute(
f"SELECT SUM(LENGTH(data)) FROM {quoted_table}"
)
size = size_cursor.fetchone()[0] or 0
except Exception as e:
logger.debug(
f"Failed to calculate collection size: {e}"
)
pass
avg_obj_size = size / count if count > 0 else 0
storage_size = 0
total_index_size = 0
index_sizes: dict = {}
try:
self.db.execute(
"CREATE VIRTUAL TABLE IF NOT EXISTS temp.dbstat USING dbstat(main)"
)
storage_cursor = self.db.execute(
"SELECT SUM(pgsize) FROM dbstat WHERE name = ?",
(collection_name,),
)
storage_size = storage_cursor.fetchone()[0] or 0
index_cursor = self.db.execute(
"SELECT name, SUM(pgsize) as size FROM dbstat "
"WHERE tbl_name = ? AND type = 'index' GROUP BY name",
(collection_name,),
)
for row in index_cursor.fetchall():
idx_name, idx_size = row
if idx_name and idx_size:
index_sizes[idx_name] = idx_size
total_index_size += idx_size
except Exception as e:
logger.debug(f"{e=}")
pass
return {
"ok": 1,
"ns": collection_name,
"count": count,
"size": size,
"avgObjSize": avg_obj_size,
"storageSize": storage_size,
"totalIndexSize": total_index_size,
"indexSizes": index_sizes,
}
case "dbstats":
page_count = self.db.execute(
"PRAGMA page_count"
).fetchone()[0]
page_size = self.db.execute("PRAGMA page_size").fetchone()[
0
]
collections = self.list_collection_names()
total_objects = 0
total_indexes = 0
for coll_name in collections:
cursor = self.db.execute(
"SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
(coll_name,),
)
if not (row := cursor.fetchone()):
continue
sql = row[0] or ""
if (
"VIRTUAL TABLE" in sql.upper()
and "fts" in sql.lower()
):
continue
try:
count = self.db.execute(
f"SELECT COUNT(*) FROM {quote_table_name(coll_name)}"
).fetchone()[0]
total_objects += count
except Exception as e:
logger.debug(
f"Failed to count documents in '{coll_name}': {e}"
)
pass
try:
indexes = self.db.execute(
f"PRAGMA index_list({quote_table_name(coll_name)})"
).fetchall()
total_indexes += len(indexes)
except Exception as e:
logger.debug(
f"Failed to get indexes for '{coll_name}': {e}"
)
pass
storage_size = page_count * page_size
try:
self.db.execute(
"CREATE VIRTUAL TABLE IF NOT EXISTS temp.dbstat USING dbstat(main)"
)
cursor = self.db.execute(
"SELECT SUM(pgsize) FROM dbstat WHERE name LIKE 'idx_%' OR name LIKE 'sqlite_autoindex%'"
)
index_size = cursor.fetchone()[0] or 0
except Exception as e:
logger.debug(
f"Failed to calculate index size, using estimate: {e}"
)
index_size = int(storage_size * 0.2)
data_size = storage_size - index_size
views_count = self.db.execute(
"SELECT COUNT(*) FROM sqlite_master WHERE type='view'"
).fetchone()[0]
import os
import shutil
fs_total = 0
fs_used = 0
db_file_size = 0
if self._db_path and self._db_path != ":memory:":
try:
db_dir = os.path.dirname(self._db_path) or "."
fs_usage = shutil.disk_usage(db_dir)
fs_total = fs_usage.total
fs_used = fs_usage.used
db_file_size = os.path.getsize(self._db_path)
wal_path = self._db_path + "-wal"
if os.path.exists(wal_path):
db_file_size += os.path.getsize(wal_path)
shm_path = self._db_path + "-shm"
if os.path.exists(shm_path):
db_file_size += os.path.getsize(shm_path)
except OSError as e:
logger.debug(f"{e=}")
pass
return {
"ok": 1,
"db": self.name,
"collections": len(collections),
"views": views_count,
"objects": total_objects,
"avgObjSize": (
int(data_size / total_objects)
if total_objects
else 0
),
"dataSize": data_size,
"storageSize": storage_size,
"indexes": total_indexes,
"indexSize": index_size,
"totalSize": data_size + index_size,
"fsTotalSize": fs_total,
"fsUsedSize": fs_used,
"scaleFactor": 1,
}
case "aggregate":
collection_name = kwargs.get("aggregate")
if not collection_name and isinstance(command, dict):
collection_name = command.get("aggregate")
if not collection_name:
raise ValueError(
"aggregate requires 'aggregate' parameter"
)
pipeline = kwargs.get("pipeline", [])
explain = kwargs.get("explain", False)
kwargs.get("allowDiskUse", False)
if explain:
collection = self[collection_name]
return collection.query_engine.explain_aggregation(
pipeline, session=None
)
else:
collection = self[collection_name]
cursor_result = collection.aggregate(pipeline)
return {"ok": 1, "result": list(cursor_result)}
case "query_only":
pragma_value = (
command.get("query_only")
if isinstance(command, dict)
else None
)
if pragma_value is not None:
if isinstance(pragma_value, bool):
val = 1 if pragma_value else 0
elif isinstance(pragma_value, str):
val = (
1
if pragma_value.upper() in ("ON", "TRUE", "1")
else 0
)
elif isinstance(pragma_value, (int, float)):
val = 1 if pragma_value else 0
else:
return {
"ok": 0,
"errmsg": f"Invalid query_only value: {pragma_value!r}",
}
self.db.execute(f"PRAGMA query_only = {val}")
return {"ok": 1, "query_only": bool(val)}
else:
cursor = self.db.execute("PRAGMA query_only")
return {
"ok": 1,
"query_only": bool(cursor.fetchone()[0]),
}
case _:
try:
# Validate cmd_name to prevent SQL injection
if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", cmd_name):
return {
"ok": 0,
"errmsg": f"Invalid command name: {cmd_name}",
}
cursor = self.db.execute(f"PRAGMA {cmd_name}")
result = cursor.fetchall()
return {
"ok": 1,
"result": [
dict(
zip([d[0] for d in cursor.description], row)
)
for row in result
],
}
except Exception as e:
logger.debug(f"{e=}")
return {
"ok": 0,
"errmsg": f"Unknown command: {cmd_name}",
"error": str(e),
}
except Exception as e:
logger.debug(f"{e=}")
return {"ok": 0, "errmsg": str(e), "code": 1}
[docs]
def cursor_command(
self, command: str | dict[str, Any], value: Any = None, **kwargs: Any
) -> AggregationCursor:
"""
Execute a database command and return a cursor for its results.
This method provides PyMongo-compatible cursor-based command execution.
It wraps the results of `command()` in an `AggregationCursor`, allowing
them to be iterated.
Args:
command: The command to execute.
value: Optional command value.
**kwargs: Additional command arguments.
Returns:
AggregationCursor: A cursor over the command results.
"""
result = self.command(command, value=value, **kwargs)
match result:
case {"collections": _}:
items = result["collections"]
case {"columns": _}:
items = result["columns"]
case {"result": _} as r:
items = r["result"]
if not isinstance(items, list):
items = [items]
case {"cursor": {"firstBatch": _}}:
items = result["cursor"]["firstBatch"]
case _:
items = [result]
# Use a safe dummy collection name for AggregationCursor
# We don't want to use existing collection names because some might be reserved (e.g. sqlite_sequence)
collection = self["__command_results__"]
# Create an AggregationCursor with the pre-computed results
cursor = AggregationCursor(collection, [])
cursor._results = items
cursor._executed = True
return cursor
[docs]
def dereference(self, dbref: dict[str, Any]) -> dict[str, Any] | None:
"""
Resolve a DBRef object.
This method provides PyMongo-compatible DBRef resolution by performing
a `find_one` on the target collection using the provided `$id`.
Args:
dbref: A dictionary representing a DBRef with '$ref' and '$id' keys.
Returns:
dict[str, Any] | None: The referenced document, or None if not found.
"""
if not isinstance(dbref, dict):
return None
collection_name = dbref.get("$ref")
document_id = dbref.get("$id")
if not collection_name or document_id is None:
return None
return self[collection_name].find_one({"_id": document_id})
[docs]
def with_options(
self,
codec_options: Any | None = None,
read_preference: Any | None = None,
write_concern: Any | None = None,
read_concern: Any | None = None,
) -> Connection:
"""
Get a clone of this database with different options.
This method returns a new Connection instance with the specified options.
For PyMongo API compatibility, the options are stored but not actively
used since SQLite has different semantics.
Args:
codec_options (Any, optional): Codec options for encoding/decoding.
Ignored in NeoSQLite (stored for API compatibility).
read_preference (Any, optional): Read preference for replica sets.
Ignored in NeoSQLite (stored for API compatibility).
write_concern (Any, optional): Write concern for durability settings.
Stored for API compatibility.
read_concern (Any, optional): Read concern for consistency settings.
Ignored in NeoSQLite (stored for API compatibility).
Returns:
Connection: A new Connection instance with the same underlying database
but with the specified options stored.
Note:
NeoSQLite stores these options for PyMongo API compatibility, but
they don't affect SQLite behavior since SQLite doesn't have replica
sets, codec options, or the same consistency/durability model.
Example:
>>> db = Connection("test.db")
>>> db_with_options = db.with_options(
... write_concern={"w": "majority"},
... read_preference={"mode": "primaryPreferred"}
... )
"""
# Return a new Connection instance that shares the same database connection.
# This is by design: SQLite uses a single writable connection per database file,
# so clones naturally share the same sqlite3.Connection and _collections dict.
# The clone only differs in its stored PyMongo-compatible options
# (codec_options, read_preference, write_concern, read_concern).
clone = Connection(name=self.name, _is_clone=True)
clone.db = self.db
clone._tokenizers = self._tokenizers
clone.debug = self.debug
clone._collections = self._collections
clone.journal_mode = self.journal_mode
# Store the new options
clone._codec_options = codec_options
clone._read_preference = read_preference
clone._write_concern = write_concern
clone._read_concern = read_concern
# NOTE: We intentionally do NOT call _apply_write_concern() here.
# The clone shares the same sqlite3.Connection with the original, so
# changing PRAGMAs would affect the original too. The write_concern
# is stored for API compatibility but does not mutate the shared connection.
return clone
[docs]
def _apply_write_concern(self, write_concern: Any) -> None:
"""
Apply write concern to the underlying SQLite connection via PRAGMAs.
- w: 0 -> PRAGMA synchronous = OFF
- w: 1 -> PRAGMA synchronous = NORMAL
- j: True -> PRAGMA synchronous = FULL
Args:
write_concern (Any): The write concern to apply (dict or WriteConcern object).
"""
if not write_concern:
return
match write_concern:
case WriteConcern():
wc_doc = write_concern.document
case dict():
wc_doc = write_concern
case _:
return
w = wc_doc.get("w")
j = wc_doc.get("j")
try:
if w == 0:
self.db.execute("PRAGMA synchronous = OFF")
elif w == 1:
self.db.execute("PRAGMA synchronous = NORMAL")
if j:
self.db.execute("PRAGMA synchronous = FULL")
except Exception as e:
logger.warning(f"Error applying write concern: {e}")
[docs]
@contextmanager
def transaction(self) -> Iterator[None]:
"""
Context manager for handling database transactions.
Ensures atomicity by beginning a transaction on entry, committing on
successful exit, and rolling back in case of exceptions. This allows
using the connection in a 'with' statement to manage transaction
boundaries safely.
Yields control to the block, and automatically commits or rolls back
based on execution outcome.
"""
try:
self.db.execute("BEGIN")
yield
self.db.commit()
except Exception as e:
logger.debug(f"Transaction failed, rolling back: {e}")
self.db.rollback()
raise
def __del__(self):
"""
Ensure the database connection is closed when the object is garbage collected.
Wrapped in try/except to avoid crashes during interpreter shutdown when
module-level imports (like sqlite3) may already be None.
"""
try:
self.close()
except Exception as e:
if logger is not None:
logger.debug(f"{e=}")
pass