# Source code for neosqlite.collection.type_utils

"""
Shared type checking and conversion utilities for the collection package.

This module consolidates type-related utility functions that are used across
multiple submodules (expr_evaluator, query_helper, etc.) to avoid code duplication
and provide a single source of truth for type operations.
"""

from __future__ import annotations

import re
from typing import Any

# =============================================================================
# Type Conversion Functions
# =============================================================================


[docs] def _convert_to_int(value: Any) -> Any: """Convert value to int.""" return int(value)
[docs] def _convert_to_long(value: Any) -> Any: """Convert value to long (64-bit int).""" return int(value)
[docs] def _convert_to_double(value: Any) -> Any: """Convert value to double (float).""" return float(value)
[docs] def _convert_to_decimal(value: Any) -> Any: """Convert value to decimal (float, as SQLite lacks Decimal128).""" return float(value)
[docs] def _convert_to_string(value: Any) -> Any: """Convert value to string.""" return str(value)
[docs] def _convert_to_bool(value: Any) -> Any: """Convert value to bool.""" return bool(value)
def _convert_to_objectid(value: Any) -> Any:
    """Build an ``ObjectId`` from the string form of *value*.

    Args:
        value: Anything whose ``str()`` is accepted by ``ObjectId``.

    Returns:
        An ``ObjectId``, or ``None`` when *value* is falsy.
    """
    # Imported lazily, as in the rest of this module's project-type helpers.
    from neosqlite.objectid import ObjectId

    if not value:
        return None
    return ObjectId(str(value))
def _convert_to_bindata(value: Any) -> Any:
    """Wrap *value* in a ``Binary`` (BSON binData).

    Strings are UTF-8 encoded first; any other value is handed to ``Binary``
    as-is. ``None`` passes through unchanged.
    """
    from neosqlite.binary import Binary

    if value is None:
        return None
    payload = value.encode("utf-8") if isinstance(value, str) else value
    return Binary(payload)
def _convert_to_bsonbindata(value: Any) -> Any:
    """Wrap *value* in a ``Binary`` for the ``bsonBinData`` type alias.

    Behaves like the ``binData`` conversion: ``None`` is returned unchanged,
    strings are UTF-8 encoded, and anything else is passed to ``Binary``.
    """
    from neosqlite.binary import Binary

    if value is None:
        return None
    if isinstance(value, str):
        # Binary needs bytes, so text is encoded explicitly.
        value = value.encode("utf-8")
    return Binary(value)
[docs] def _convert_to_regex(value: Any) -> Any: """Convert value to regex pattern.""" return re.compile(str(value)) if value else None
[docs] def _convert_to_bsonregex(value: Any) -> Any: """Convert value to regex pattern (bsonRegex).""" return re.compile(str(value)) if value else None
[docs] def _convert_to_date(value: Any) -> Any: """Convert value to date (returns as-is; proper conversion requires parsing).""" return value
[docs] def _convert_to_null(value: Any) -> None: """Convert any value to None.""" return None
def get_bson_type(value: Any) -> str:
    """
    Map a Python value onto its BSON type name.

    Args:
        value: The value to classify

    Returns:
        One of 'null', 'bool', 'int', 'double', 'string', 'array', 'object',
        or 'unknown' for any other Python type.
    """
    if value is None:
        return "null"
    # bool is a subclass of int, so it has to be recognised first.
    if isinstance(value, bool):
        return "bool"
    type_table = (
        (int, "int"),
        (float, "double"),
        (str, "string"),
        (list, "array"),
        (dict, "object"),
    )
    for python_type, bson_name in type_table:
        if isinstance(value, python_type):
            return bson_name
    return "unknown"
# ============================================================================= # Type Checking Helpers # =============================================================================
[docs] def _is_expression(value: Any) -> bool: """ Check if value is an aggregation expression. An expression is a dict with exactly one key starting with '$' that is not a reserved field name. Args: value: Value to check Returns: True if value is an expression, False otherwise Examples: >>> _is_expression({"$sin": "$angle"}) True >>> _is_expression({"$field": "value"}) # Reserved False >>> _is_expression("$field") False >>> _is_expression(42) False """ # Reserved field names that are NOT operators (copied from expr_evaluator.constants # to avoid circular import) RESERVED_FIELDS = {"$field", "$index"} if not isinstance(value, dict): return False if len(value) != 1: return False # Could be a literal dict key = next(iter(value.keys())) return key.startswith("$") and key not in RESERVED_FIELDS
[docs] def _is_field_reference(value: Any) -> bool: """ Check if value is a field reference. Field references start with '$' but are not expressions (i.e., they're simple strings like "$field" or "$nested.field"). Args: value: Value to check Returns: True if value is a field reference, False otherwise Examples: >>> _is_field_reference("$field") True >>> _is_field_reference("$nested.field") True >>> _is_field_reference("$$ROOT") False >>> _is_field_reference({"$sin": "$angle"}) False """ return ( isinstance(value, str) and value.startswith("$") and not value.startswith("$$") )
def _is_literal(value: Any) -> bool:
    """
    Check if value is a literal (not an expression or field reference).

    Literals include: numbers, booleans, None, arrays, strings that do not
    start with '$', and plain dicts. A single-key operator dict such as
    {"$sin": ...} is an expression and therefore not a literal.

    Args:
        value: Value to check

    Returns:
        True if value is a literal, False otherwise

    Examples:
        >>> _is_literal(42)
        True
        >>> _is_literal("string")
        True
        >>> _is_literal(None)
        True
        >>> _is_literal([1, 2, 3])
        True
        >>> _is_literal({"a": 1})
        True
        >>> _is_literal("$field")
        False
        >>> _is_literal({"$sin": "$angle"})
        False
    """
    if isinstance(value, str):
        # Strings starting with $ are field refs or variables, not literals
        return not value.startswith("$")
    if isinstance(value, dict):
        # Operator dicts are evaluated rather than taken literally; reuse the
        # module's own expression test so both helpers agree.
        return not _is_expression(value)
    # All other types are literals
    return True
[docs] def _is_numeric_value(value: Any) -> bool: """ Check if a value is numeric (int or float) or can be converted to a numeric value. This function determines if a value can be safely used in arithmetic operations like $inc and $mul. It considers: - int and float values as numeric (excluding bool, NaN, and infinity) - None as non-numeric (would cause issues in arithmetic) - String representations of numbers as non-numeric (to match MongoDB behavior) Args: value: The value to check Returns: bool: True if the value is numeric, False otherwise """ # Explicitly exclude boolean values (even though bool is subclass of int in Python) if isinstance(value, bool): return False # Check for actual numeric types if isinstance(value, (int, float)): # Special case: check for NaN and infinity if isinstance(value, float): import math if math.isnan(value) or math.isinf(value): return False return True # Everything else is considered non-numeric for MongoDB compatibility return False
[docs] def validate_session(session: Any | None, connection: Any) -> None: """ Validate that the session belongs to this connection. Args: session: ClientSession instance or None connection: The parent Connection or sqlite3.Connection object to validate against Raises: ValueError: If the session belongs to a different connection """ if session is not None: # connection could be neosqlite.Connection or sqlite3.Connection # session.client is neosqlite.Connection # neosqlite.Connection.db is sqlite3.Connection if session.client != connection and session.client.db != connection: raise ValueError("Session belongs to a different Connection")
__all__ = [
    # Type conversion functions
    "_convert_to_int",
    "_convert_to_long",
    "_convert_to_double",
    "_convert_to_decimal",
    "_convert_to_string",
    "_convert_to_bool",
    "_convert_to_objectid",
    "_convert_to_bindata",
    "_convert_to_bsonbindata",
    "_convert_to_regex",
    "_convert_to_bsonregex",
    "_convert_to_date",
    "_convert_to_null",
    # BSON type introspection
    "get_bson_type",
    # Type checking helpers
    "_is_expression",
    "_is_field_reference",
    "_is_literal",
    "_is_numeric_value",
    # Session helpers
    "validate_session",
]