Source code for neosqlite.client_session
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from .exceptions import InvalidOperation
if TYPE_CHECKING:
from .connection import Connection
logger = logging.getLogger(__name__)
[docs]
class ClientSession:
"""
Represents a client session for transactions in NeoSQLite.
This class provides PyMongo-compatible session and transaction management
by wrapping SQLite's native ACID transactions.
"""
[docs]
def __init__(
self, client: Connection, options: dict[str, Any] | None = None
):
"""
Initialize a new ClientSession.
Args:
client (Connection): The connection instance that created this session.
options (dict, optional): Session options.
"""
self.client = client
self.options = options or {}
self._in_transaction = False
@property
def in_transaction(self) -> bool:
"""
Check if the session is currently in a transaction.
Returns:
bool: True if in a transaction, False otherwise.
"""
return self._in_transaction
[docs]
def start_transaction(self, write_concern: dict[str, Any] | None = None):
"""
Start a new transaction.
Args:
write_concern (dict, optional): Write concern for the transaction.
"""
if self._in_transaction:
raise InvalidOperation("Transaction already in progress")
# SQLite transaction logic
if self.client.db.in_transaction:
# Use SAVEPOINT for nested transactions
self._savepoint_name = f"session_tx_{id(self)}"
self.client.db.execute(f"SAVEPOINT {self._savepoint_name}")
self._is_savepoint = True
else:
# Start a normal transaction
self.client.db.execute("BEGIN IMMEDIATE")
self._is_savepoint = False
self._in_transaction = True
[docs]
def commit_transaction(self):
"""
Commit the current transaction.
"""
if not self._in_transaction:
raise InvalidOperation("No transaction in progress")
if self._is_savepoint:
self.client.db.execute(f"RELEASE SAVEPOINT {self._savepoint_name}")
else:
self.client.db.commit()
self._in_transaction = False
[docs]
def abort_transaction(self):
"""
Abort (rollback) the current transaction.
"""
if not self._in_transaction:
raise InvalidOperation("No transaction in progress")
if self._is_savepoint:
self.client.db.execute(
f"ROLLBACK TO SAVEPOINT {self._savepoint_name}"
)
self.client.db.execute(f"RELEASE SAVEPOINT {self._savepoint_name}")
else:
self.client.db.rollback()
self._in_transaction = False
[docs]
def end_session(self):
"""
End the session. If in a transaction, it will be aborted.
"""
if self._in_transaction:
try:
self.abort_transaction()
except Exception as e:
logger.warning(
f"Failed to abort transaction during session close: {e}"
)
pass
[docs]
def with_transaction(
self,
callback: Callable[[ClientSession], Any],
read_concern: Any | None = None,
write_concern: Any | None = None,
read_preference: Any | None = None,
max_commit_time_ms: int | None = None,
) -> Any:
"""
Execute a callback in a transaction.
This method automatically starts a transaction, executes the callback,
and commits the transaction if the callback succeeds. If the callback
raises an exception, the transaction is aborted.
Args:
callback: A function that takes a ClientSession as its only argument.
read_concern (optional): Unused in NeoSQLite.
write_concern (optional): Unused in NeoSQLite.
read_preference (optional): Unused in NeoSQLite.
max_commit_time_ms (optional): Unused in NeoSQLite.
Returns:
The return value of the callback.
"""
self.start_transaction()
try:
result = callback(self)
self.commit_transaction()
return result
except Exception as e:
logger.debug(f"Transaction context manager failed: {e}")
if self._in_transaction:
self.abort_transaction()
raise
def __enter__(self) -> ClientSession:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and self._in_transaction:
self.abort_transaction()
self.end_session()