Source code for neosqlite.collection.query_helper.update_operations
"""
Update operations mixin for QueryHelper.
This module contains the UpdateOperationsMixin class which provides
SQL-based and Python-based update operations for NeoSQLite collections.
"""
import logging
from copy import deepcopy
from datetime import datetime
from typing import TYPE_CHECKING, Any
from ...binary import Binary
from ...exceptions import MalformedQueryException
from ...sql_utils import quote_table_name
from ..json_helpers import (
neosqlite_json_dumps,
neosqlite_json_loads,
)
from ..json_path_utils import (
parse_json_path,
)
from ..jsonb_support import json_data_column
logger = logging.getLogger(__name__)
# Import positional update functions
from .positional_update import (
_apply_positional_update,
_set_nested_field,
)
# Import utility functions
from .utils import (
_convert_bytes_to_binary,
_get_json_function,
_supports_relative_json_indexing,
_validate_inc_mul_field_value,
get_force_fallback,
)
# Collection is imported only for static type checking
if TYPE_CHECKING:
from .. import Collection
[docs]
class UpdateOperationsMixin:
"""
A mixin class providing update operations for QueryHelper.
This mixin assumes it will be used with a class that has:
- self.collection (with db and name attributes, and a _load method)
- self._jsonb_supported
- self._json_function_prefix
- self._get_integer_id_for_oid method
- self._build_simple_where_clause method
"""
collection: "Collection"
_jsonb_supported: bool
_json_function_prefix: str
_get_integer_id_for_oid: Any
[docs]
def _internal_update(
    self,
    doc_id: Any,
    update_spec: dict[str, Any],
    original_doc: dict[str, Any],
    array_filters: list[dict[str, Any]] | None = None,
    query_filter: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], bool]:
    """
    Apply an update specification to one document.

    Strategies are tried in order: the enhanced SQL update, then the
    standard SQL update, then the Python implementation. The Python
    implementation is used directly when the force-fallback switch is on,
    when array filters are supplied, or when the specification is not
    eligible for SQL.

    Args:
        doc_id (Any): The ID of the document to update (can be ObjectId, int, etc.).
        update_spec (dict[str, Any]): The update specification.
        original_doc (dict[str, Any]): The document as it was before the update.
        array_filters (list[dict[str, Any]], optional): Filter documents for
            array positional operators.
        query_filter (dict[str, Any], optional): The query filter for the $ operator.

    Returns:
        tuple[dict[str, Any], bool]: The updated document and whether it
            differs from ``original_doc``.
    """
    # Validate $inc/$mul targets up front so every strategy rejects the same
    # input. Fields absent from the original document are skipped.
    for operator, operand in update_spec.items():
        if operator not in {"$inc", "$mul"}:
            continue
        for target in operand.keys():
            if target in original_doc:
                _validate_inc_mul_field_value(
                    target, original_doc[target], operator
                )

    def python_fallback() -> tuple[dict[str, Any], bool]:
        # Single place for the Python-path call so all branches pass the
        # same arguments.
        return self._perform_python_update(
            doc_id, update_spec, original_doc, array_filters, query_filter
        )

    # Kill switch, array filters (not handled by the SQL paths) and
    # SQL-ineligible specifications all go straight to Python. The
    # short-circuit keeps the checks in this order.
    if (
        get_force_fallback()
        or array_filters
        or not self._can_use_sql_updates(update_spec, doc_id, original_doc)
    ):
        return python_fallback()

    try:
        enhanced_result = self._perform_enhanced_sql_update(
            doc_id, update_spec, original_doc
        )
        # "Modified" means the stored document differs from the original.
        return enhanced_result, enhanced_result != original_doc
    except Exception as enhanced_error:
        logger.debug(
            f"Enhanced update failed: {enhanced_error}. Falling back to standard SQL update."
        )
        try:
            standard_result = self._perform_sql_update(doc_id, update_spec)
            return standard_result, standard_result != original_doc
        except Exception as standard_error:
            logger.debug(
                f"Standard SQL update failed: {standard_error}. Falling back to Python update."
            )
            return python_fallback()
[docs]
def _can_use_sql_updates(
    self,
    update_spec: dict[str, Any],
    doc_id: int,
    original_doc: dict[str, Any] | None = None,
) -> bool:
    """
    Decide whether an update specification may be executed through SQL.

    The answer is False when the force-fallback switch is on, when any
    field path contains "$", when ``$pop`` is requested but relative JSON
    indexing is unavailable, or when ``$pull``/``$pullAll`` targets a field
    that is not a list in ``original_doc``. Otherwise it is True only if
    ``doc_id`` is not 0, no operand value is a bytes object exposing
    ``encode_for_storage``, and every operator is in the SQL-capable set.

    Args:
        update_spec (dict[str, Any]): The update operations to be checked.
        doc_id (int): The document ID; 0 disables the SQL path.
        original_doc (dict[str, Any], optional): The current document, used
            for the ``$pull``/``$pullAll`` list check when provided.

    Returns:
        bool: True if all operations can be handled with SQL, False otherwise.
    """
    if get_force_fallback():
        return False

    # Only dict-valued operands carry field paths and values worth inspecting.
    dict_operands = [
        operand for operand in update_spec.values() if isinstance(operand, dict)
    ]

    carries_binary = any(
        isinstance(candidate, bytes) and hasattr(candidate, "encode_for_storage")
        for operand in dict_operands
        for candidate in operand.values()
    )

    # A "$" anywhere in a field path marks a positional operator.
    if any("$" in path for operand in dict_operands for path in operand.keys()):
        return False

    # $pop is built on the [#-1] relative index.
    if "$pop" in update_spec and not _supports_relative_json_indexing():
        return False

    # $pull / $pullAll need an existing list to filter.
    if original_doc is not None:
        for pull_operator in ("$pull", "$pullAll"):
            targets = update_spec.get(pull_operator)
            if not isinstance(targets, dict):
                continue
            for target in targets.keys():
                if not (
                    target in original_doc
                    and isinstance(original_doc[target], list)
                ):
                    return False

    sql_capable = {
        "$set",
        "$unset",
        "$inc",
        "$mul",
        "$min",
        "$max",
        "$pull",
        "$pullAll",
        "$currentDate",
        "$rename",
        "$setOnInsert",
        "$push",
        "$bit",
        "$pop",
        "$addToSet",
    }
    return (
        doc_id != 0
        and not carries_binary
        and all(operator in sql_capable for operator in update_spec.keys())
    )
[docs]
def _perform_sql_update(
    self,
    doc_id: int,
    update_spec: dict[str, Any],
) -> dict[str, Any]:
    """
    Apply an update with one consolidated UPDATE built from JSON functions.

    Each operator is turned into clauses by ``_build_sql_update_clause``.
    ``$unset`` clauses are passed to the JSON "remove" function and all
    other clauses to the JSON "set" function, with the removal nested
    inside the set call. The updated row is then read back and returned.

    Args:
        doc_id (int): The ID of the document to be updated.
        update_spec (dict[str, Any]): The update operations to perform.

    Returns:
        dict[str, Any]: The updated document.

    Raises:
        RuntimeError: If no clauses were produced, if the UPDATE touched no
            rows, or if the updated row cannot be read back.
    """
    removal_sql: list[str] = []
    removal_args: list[Any] = []
    assignment_sql: list[str] = []
    assignment_args: list[Any] = []

    for operator, operand in update_spec.items():
        fragments, bound = self._build_sql_update_clause(operator, operand)
        if not fragments:
            continue
        if operator == "$unset":
            removal_sql += fragments
            removal_args += bound
        else:
            assignment_sql += fragments
            assignment_args += bound

    row_id = self._get_integer_id_for_oid(doc_id)

    # Wrap "data" once per non-empty stage: remove first, then set.
    expression = "data"
    bindings: list[Any] = []
    stages = (
        ("remove", removal_sql, removal_args),
        ("set", assignment_sql, assignment_args),
    )
    for kind, fragments, bound in stages:
        if fragments:
            wrapper = _get_json_function(kind, self._jsonb_supported)
            expression = f"{wrapper}({expression}, {', '.join(fragments)})"
            bindings.extend(bound)

    if expression == "data":
        raise RuntimeError("No valid operations to perform")

    table = quote_table_name(self.collection.name)
    cursor = self.collection.db.execute(
        f"UPDATE {table} SET data = {expression} WHERE id = ?",
        bindings + [row_id],
    )
    if cursor.rowcount == 0:
        raise RuntimeError(f"No rows updated for doc_id {doc_id}")

    # Read the row back using the column form matching the JSONB flag.
    select_sql = (
        f"SELECT id, {json_data_column(self._jsonb_supported)} as data "
        f"FROM {table} WHERE id = ?"
    )
    row = self.collection.db.execute(select_sql, (row_id,)).fetchone()
    if row:
        return self.collection._load(row[0], row[1])
    raise RuntimeError("Failed to fetch updated document")
[docs]
def _perform_enhanced_sql_update(
    self,
    doc_id: Any,
    update_spec: dict[str, Any],
    original_doc: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Perform update operations using SQL JSON functions with field-level granularity.

    Each operator in ``update_spec`` is translated into clauses for the JSON
    insert / replace / set / remove functions. For ``$set`` the choice between
    insert and replace is made by checking the field name against the
    document's top-level keys. All collected clauses are then combined into
    one nested expression and applied with a single UPDATE statement.

    ``$addToSet`` is the exception: it executes one UPDATE per value
    immediately, while the clauses for the other operators are still being
    collected.

    Args:
        doc_id (Any): The ID of the document to be updated.
        update_spec (dict[str, Any]): A dictionary specifying the update
            operations to be performed.
        original_doc (dict[str, Any], optional): The original document before
            the update. If provided, its top-level keys are used as the set of
            existing fields instead of fetching the document again.

    Returns:
        dict[str, Any]: The updated document, read back from the database.

    Raises:
        RuntimeError: If the consolidated UPDATE touches no rows, or if the
            updated document cannot be read back.
    """
    # Get integer ID for the document immediately to avoid UnboundLocalError
    # (the $addToSet branch below uses it while clauses are being built)
    int_doc_id = self._get_integer_id_for_oid(doc_id)
    # First, we need to determine which fields exist in the document
    # and which are new to decide between json_insert and json_replace
    # Use original_doc if provided to avoid extra fetch
    if original_doc is not None:
        existing_fields = (
            set(original_doc.keys())
            if isinstance(original_doc, dict)
            else set()
        )
    else:
        existing_fields = self._get_document_fields(doc_id)
    # Clause lists hold SQL fragments; the matching *_params lists hold the
    # values for the "?" placeholders in the same order.
    insert_clauses = []
    insert_params = []
    replace_clauses = []
    replace_params = []
    set_clauses = []  # For backward compatibility with json_set
    set_params = []
    unset_clauses = []
    # Never appended to in this method: removal clauses are bare paths.
    unset_params: list[Any] = []
    # Build SQL update clauses for each operation
    for op, value in update_spec.items():
        match op:
            case "$set":
                # For $set, we need to determine whether to use json_insert or json_replace
                for field, field_val in value.items():
                    # Convert bytes to Binary for proper JSON serialization
                    converted_val = _convert_bytes_to_binary(field_val)
                    # If it's a Binary object, serialize it to JSON and use json() function
                    if isinstance(converted_val, Binary):
                        param_value = neosqlite_json_dumps(converted_val)
                        use_json_func = True
                    # For complex objects (dict, list), serialize them to JSON
                    elif isinstance(converted_val, (dict, list)):
                        param_value = neosqlite_json_dumps(converted_val)
                        use_json_func = True
                    else:
                        param_value = converted_val
                        use_json_func = False
                    # For dotted field names, we should use json_set
                    # (existing_fields only contains top-level keys)
                    if "." in field and field not in existing_fields:
                        json_path = f"'{parse_json_path(field)}'"
                        if use_json_func:
                            set_clauses.append(f"{json_path}, json(?)")
                        else:
                            set_clauses.append(f"{json_path}, ?")
                        set_params.append(param_value)
                    else:
                        # Check if field exists in the document
                        json_path = f"'{parse_json_path(field)}'"
                        if field in existing_fields:
                            # Use json_replace for existing fields
                            if use_json_func:
                                replace_clauses.append(
                                    f"{json_path}, json(?)"
                                )
                            else:
                                replace_clauses.append(f"{json_path}, ?")
                            replace_params.append(param_value)
                        else:
                            # Use json_insert for new fields
                            if use_json_func:
                                insert_clauses.append(
                                    f"{json_path}, json(?)"
                                )
                            else:
                                insert_clauses.append(f"{json_path}, ?")
                            insert_params.append(param_value)
            case "$unset":
                # For $unset, we use json_remove
                for field in value:
                    json_path = f"'{parse_json_path(field)}'"
                    unset_clauses.append(json_path)
            case "$push":
                # Tier 2: $push (with $each, optionally with $position and/or $slice) can use SQL
                for field, push_value in value.items():
                    # Extract values to push (handle $each)
                    values_to_push = []
                    slice_value = None
                    position_value = None
                    if (
                        isinstance(push_value, dict)
                        and "$each" in push_value
                    ):
                        # Check for $slice and $position
                        if "$slice" in push_value:
                            slice_value = push_value["$slice"]
                        if "$position" in push_value:
                            position_value = push_value["$position"]
                        values_to_push = push_value["$each"]
                        if not isinstance(values_to_push, list):
                            values_to_push = [values_to_push]
                    else:
                        # Without $each the whole operand is pushed as one value
                        values_to_push = [push_value]
                        # Check for $slice without $each
                        if (
                            isinstance(push_value, dict)
                            and "$slice" in push_value
                            and "$each" not in push_value
                        ):
                            slice_value = push_value["$slice"]
                        # Check for $position without $each
                        if (
                            isinstance(push_value, dict)
                            and "$position" in push_value
                            and "$each" not in push_value
                        ):
                            position_value = push_value["$position"]
                    json_path = f"'{parse_json_path(field)}'"
                    # Convert values to add
                    converted_values = [
                        _convert_bytes_to_binary(v) for v in values_to_push
                    ]
                    # Build JSON array of new values
                    new_values_json = neosqlite_json_dumps(converted_values)
                    # NOTE(review): in the three subquery clauses below,
                    # new_values_json (and slice_value) are interpolated into
                    # the SQL text rather than bound as parameters, and the
                    # $position variants place LIMIT inside individual arms of
                    # a UNION ALL. SQLite treats [...] as a quoted identifier
                    # and only accepts LIMIT on the last arm of a compound
                    # SELECT, so these statements presumably fail at execution
                    # and rely on the caller's fallback - TODO confirm.
                    # Handle $position: need to reconstruct array with insertion at position
                    if position_value is not None:
                        # Position 0 = insert at beginning
                        # Position N = insert after N elements
                        # Negative positions are clamped to 0
                        position = int(position_value)
                        if position < 0:
                            position = 0
                        if slice_value is not None and slice_value == 0:
                            # Slice 0 means empty array
                            set_clauses.append(f"{json_path}, json('[]')")
                        elif slice_value is not None:
                            # Both $position and $slice - apply slice after insertion
                            # Get elements before position, new values, then slice from position
                            # NOTE(review): a negative $slice is mapped to a
                            # limit of 1000000 here, i.e. effectively no limit.
                            slice_limit = (
                                slice_value if slice_value > 0 else 1000000
                            )
                            set_clauses.append(
                                f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {position} UNION ALL SELECT value FROM json_each({new_values_json}) UNION ALL SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {slice_limit} OFFSET {position}))"
                            )
                        else:
                            # Only $position, no $slice - insert at position
                            set_clauses.append(
                                f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {position} UNION ALL SELECT value FROM json_each({new_values_json}) UNION ALL SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT -1 OFFSET {position}))"
                            )
                    elif slice_value is not None:
                        # Handle $slice only (existing logic)
                        if slice_value == 0:
                            # Slice 0 means empty array
                            set_clauses.append(f"{json_path}, json('[]')")
                        else:
                            # Get existing array, concatenate with new values, then slice
                            set_clauses.append(
                                f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) UNION ALL SELECT value FROM json_each({new_values_json}) LIMIT {slice_value if slice_value > 0 else 1000000}))"
                            )
                    else:
                        # No $position or $slice - just append values using [#]
                        # One "path[#], value" pair is emitted per pushed value
                        append_path = f"'{parse_json_path(field)}[#]'"
                        for val in values_to_push:
                            converted_val = _convert_bytes_to_binary(val)
                            if isinstance(
                                converted_val, (dict, list, Binary)
                            ):
                                param_value = neosqlite_json_dumps(
                                    converted_val
                                )
                                set_clauses.append(
                                    f"{append_path}, json(?)"
                                )
                            else:
                                param_value = converted_val
                                set_clauses.append(f"{append_path}, ?")
                            set_params.append(param_value)
            case "$pop":
                # Tier 2: $pop uses json_remove with [0] or [#-1]
                # A negative direction removes the first element, otherwise the last
                for field, pop_direction in value.items():
                    index_path = (
                        "[0]" if int(pop_direction) < 0 else "[#-1]"
                    )
                    json_path = f"'{parse_json_path(field)}{index_path}'"
                    unset_clauses.append(json_path)
            case "$addToSet":
                # Tier 2: $addToSet (with or without $each) can use conditional SQL
                # Unlike the other branches, this one runs its UPDATE
                # statements right away instead of collecting clauses.
                insert_func = _get_json_function(
                    "insert", self._jsonb_supported
                )
                for field, val in value.items():
                    json_path = f"'{parse_json_path(field)}'"
                    # Handle $each modifier
                    values_to_add = []
                    if isinstance(val, dict) and "$each" in val:
                        each_values = val["$each"]
                        if not isinstance(each_values, list):
                            each_values = [each_values]
                        values_to_add = each_values
                    else:
                        values_to_add = [val]
                    # Process each value
                    for each_val in values_to_add:
                        converted_val = _convert_bytes_to_binary(each_val)
                        if isinstance(converted_val, (Binary, dict, list)):
                            param_value = neosqlite_json_dumps(
                                converted_val
                            )
                            use_json = True
                        else:
                            param_value = converted_val
                            use_json = False
                        array_path = json_path
                        append_path = f"'{parse_json_path(field)}[#]'"
                        if not use_json:
                            # Append only when no array element equals the value
                            cmd = (
                                f"UPDATE {quote_table_name(self.collection.name)} "
                                f"SET data = {insert_func}(data, {append_path}, ?) "
                                f"WHERE id = ? AND NOT EXISTS ("
                                f" SELECT 1 FROM json_each(data, {array_path}) "
                                f" WHERE value = ?"
                                f")"
                            )
                            self.collection.db.execute(
                                cmd, (param_value, int_doc_id, param_value)
                            )
                        else:
                            # For complex values (dict, list, Binary), use a more complex SQL
                            # We need to check if the JSON value already exists
                            cmd = (
                                f"UPDATE {quote_table_name(self.collection.name)} "
                                f"SET data = {insert_func}(data, {append_path}, json(?)) "
                                f"WHERE id = ? AND NOT EXISTS ("
                                f" SELECT 1 FROM json_each(data, {array_path}) "
                                f" WHERE json(value) = json(?)"
                                f")"
                            )
                            self.collection.db.execute(
                                cmd, (param_value, int_doc_id, param_value)
                            )
            case "$bit":
                # Tier 2: $bit using bitwise operators
                # A missing field is read as 0 via COALESCE; and/or/xor are
                # applied in that fixed order when present.
                for field, bit_spec in value.items():
                    json_path = f"'{parse_json_path(field)}'"
                    extract_func = _get_json_function(
                        "extract", self._jsonb_supported
                    )
                    current_expr = (
                        f"COALESCE({extract_func}(data, {json_path}), 0)"
                    )
                    bit_expr = current_expr
                    if "and" in bit_spec:
                        bit_expr = f"({bit_expr} & {int(bit_spec['and'])})"
                    if "or" in bit_spec:
                        bit_expr = f"({bit_expr} | {int(bit_spec['or'])})"
                    if "xor" in bit_spec:
                        # XOR is spelled as (a | b) & ~(a & b)
                        xor_val = int(bit_spec["xor"])
                        bit_expr = f"(({bit_expr} | {xor_val}) & ~(({bit_expr}) & {xor_val}))"
                    set_clauses.append(f"{json_path}, {bit_expr}")
            case "$currentDate":
                # SQL implementation for $currentDate
                for field, type_spec in value.items():
                    json_path = f"'{parse_json_path(field)}'"
                    # Determine type: true defaults to date, { $type: "timestamp" } or { $type: "date" }
                    if (
                        isinstance(type_spec, dict)
                        and type_spec.get("$type") == "timestamp"
                    ):
                        type_value = "timestamp"
                    else:
                        type_value = "date"
                    # Set to current datetime ISO string
                    # (the timestamp form carries a trailing "Z", the date form does not)
                    if type_value == "timestamp":
                        set_clauses.append(
                            f"{json_path}, strftime('%Y-%m-%dT%H:%M:%fZ', 'now')"
                        )
                    else:
                        # For date type, match Python's datetime.now().isoformat() format
                        set_clauses.append(
                            f"{json_path}, strftime('%Y-%m-%dT%H:%M:%f', 'now')"
                        )
            case "$rename":
                # SQL implementation for $rename using single UPDATE
                # Combine json_set and json_remove in a single operation
                for old_field, new_field in value.items():
                    old_json_path = f"'{parse_json_path(old_field)}'"
                    new_json_path = f"'{parse_json_path(new_field)}'"
                    extract_func = _get_json_function(
                        "extract", self._jsonb_supported
                    )
                    # First set the new field, then remove the old field
                    # We need to nest the operations: json_remove(json_set(data, new_path, value), old_path)
                    # This requires special handling - we'll use a combined approach
                    set_clauses.append(
                        f"{new_json_path}, {extract_func}(data, {old_json_path})"
                    )
                    unset_clauses.append(old_json_path)
            case _:
                # For other operations, use the standard approach
                clauses, params = self._build_sql_update_clause(op, value)
                if clauses:
                    set_clauses.extend(clauses)
                    set_params.extend(params)
    # Execute updates using a single consolidated UPDATE statement if possible
    # This significantly reduces disk I/O and transaction overhead
    # Combine all operations into a single nested data update
    # We start with 'data' and wrap it with functions in order
    # all_params must follow the left-to-right order of "?" in the final SQL.
    current_data = "data"
    all_params = []
    # Special handling for $rename: combine json_set and json_remove logic
    # (set must run before remove so the old value can still be extracted)
    has_rename = any(op == "$rename" for op in update_spec.keys())
    if has_rename and set_clauses and unset_clauses:
        set_func = _get_json_function("set", self._jsonb_supported)
        remove_func = _get_json_function("remove", self._jsonb_supported)
        current_data = f"{remove_func}({set_func}({current_data}, {', '.join(set_clauses)}), {', '.join(unset_clauses)})"
        all_params.extend(set_params)
        # Clear them so we don't process them again below
        set_clauses = []
        unset_clauses = []
    # Process remaining clauses in order to build a single nested expression
    # Nesting order, innermost first: remove, insert, replace, set
    if unset_clauses:
        func_name = _get_json_function("remove", self._jsonb_supported)
        current_data = (
            f"{func_name}({current_data}, {', '.join(unset_clauses)})"
        )
        all_params.extend(unset_params)
    if insert_clauses:
        func_name = _get_json_function("insert", self._jsonb_supported)
        current_data = (
            f"{func_name}({current_data}, {', '.join(insert_clauses)})"
        )
        all_params.extend(insert_params)
    if replace_clauses:
        func_name = _get_json_function("replace", self._jsonb_supported)
        current_data = (
            f"{func_name}({current_data}, {', '.join(replace_clauses)})"
        )
        all_params.extend(replace_params)
    if set_clauses:
        func_name = _get_json_function("set", self._jsonb_supported)
        current_data = (
            f"{func_name}({current_data}, {', '.join(set_clauses)})"
        )
        all_params.extend(set_params)
    if current_data != "data":
        cmd = f"UPDATE {quote_table_name(self.collection.name)} SET data = {current_data} WHERE id = ?"
        cursor = self.collection.db.execute(cmd, all_params + [int_doc_id])
        if cursor.rowcount == 0:
            raise RuntimeError(f"No rows updated for doc_id {doc_id}")
    elif (
        not has_rename
    ):  # If has_rename was True, it was already handled or clauses were cleared
        # This check might be redundant if we ensured at least one op, but safe to keep
        # NOTE(review): this branch has no effect; when no clauses were
        # collected (e.g. only $addToSet ran) the method simply proceeds to
        # read the document back.
        pass
    # Fetch updated document
    jsonb = self._jsonb_supported
    cmd = (
        f"SELECT id, {json_data_column(jsonb)} as data "
        f"FROM {quote_table_name(self.collection.name)} WHERE id = ?"
    )
    if row := self.collection.db.execute(cmd, (int_doc_id,)).fetchone():
        return self.collection._load(row[0], row[1])
    raise RuntimeError("Failed to fetch updated document")
[docs]
def _get_document_fields(self, doc_id: Any) -> set[str]:
    """
    Return the top-level field names of a stored document.

    The result is used to choose between the JSON insert and replace
    functions when building updates.

    Args:
        doc_id (Any): The ID of the document to analyze.

    Returns:
        set[str]: The document's top-level keys. Empty when the row does not
            exist, the stored data cannot be parsed, or the parsed value is
            not a dict.
    """
    int_doc_id = self._get_integer_id_for_oid(doc_id)
    # Select the data column in the form matching the JSONB support flag.
    cmd = (
        f"SELECT {json_data_column(self._jsonb_supported)} as data "
        f"FROM {quote_table_name(self.collection.name)} WHERE id = ?"
    )
    row = self.collection.db.execute(cmd, (int_doc_id,)).fetchone()
    if not row:
        return set()
    # Only the parse can fail here, so only the parse is guarded. The
    # failure is deliberately best-effort: log it and report "no fields".
    try:
        doc_data = neosqlite_json_loads(row[0])
    except Exception as exc:
        logger.debug(
            "Failed to parse document for field-name extraction: %s", exc
        )
        return set()
    return set(doc_data.keys()) if isinstance(doc_data, dict) else set()
[docs]
def _build_update_clause(
self,
update: dict[str, Any],
) -> tuple[str, list[Any]] | None:
"""
Build the SQL update clause based on the provided update operations.
Args:
update (dict[str, Any]): A dictionary containing update operations.
Returns:
tuple[str, list[Any]] | None: A tuple containing the SQL update clause
and parameters, or None if no update
clauses are generated.
"""
set_clauses = []
params = []
for op, value in update.items():
match op:
case "$set":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
set_clauses.append(f"{json_path}, ?")
params.append(field_val)
case "$inc":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
set_clauses.append(
f"{json_path}, COALESCE({self._json_function_prefix}_extract(data, {json_path}), 0) + ?"
)
params.append(field_val)
case "$mul":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
set_clauses.append(
f"{json_path}, COALESCE({self._json_function_prefix}_extract(data, {json_path}), 0) * ?"
)
params.append(field_val)
case "$min":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
set_clauses.append(
f"{json_path}, min({self._json_function_prefix}_extract(data, {json_path}), ?)"
)
params.append(field_val)
case "$max":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
set_clauses.append(
f"{json_path}, max({self._json_function_prefix}_extract(data, {json_path}), ?)"
)
params.append(field_val)
case "$unset":
# For $unset, we use json_remove
for field in value:
json_path = f"'{parse_json_path(field)}'"
set_clauses.append(json_path)
# json_remove has a different syntax
if set_clauses:
func_name = _get_json_function(
"remove", self._jsonb_supported
)
return (
f"data = {func_name}(data, {', '.join(set_clauses)})",
params,
)
else:
# No fields to unset
return None
case "$currentDate":
for field, type_spec in value.items():
json_path = f"'{parse_json_path(field)}'"
# Determine type: true defaults to date, { $type: "timestamp" } or { $type: "date" }
if (
isinstance(type_spec, dict)
and type_spec.get("$type") == "timestamp"
):
type_value = "timestamp"
else:
type_value = "date"
# Set to current datetime ISO string to match Python implementation
# Use strftime for consistent ISO format (Python uses isoformat() like '2026-03-15T12:34:56.789012')
if type_value == "timestamp":
set_clauses.append(
f"{json_path}, strftime('%Y-%m-%dT%H:%M:%fZ', 'now')"
)
else:
# For date type, match Python's datetime.now().isoformat() format
set_clauses.append(
f"{json_path}, strftime('%Y-%m-%dT%H:%M:%f', 'now')"
)
case "$pop":
# For $pop, we use json_remove
if not _supports_relative_json_indexing():
return None
for field, pop_direction in value.items():
# 1: remove last, -1: remove first
index_path = (
"[0]" if int(pop_direction) < 0 else "[#-1]"
)
json_path = f"'{parse_json_path(field)}{index_path}'"
set_clauses.append(json_path)
# json_remove has a different syntax
if set_clauses:
func_name = _get_json_function(
"remove", self._jsonb_supported
)
return (
f"data = {func_name}(data, {', '.join(set_clauses)})",
params,
)
else:
return None
case "$push":
# Optimized $push (with $each, optionally with $slice) using [#]
for field, push_value in value.items():
# Handle $each, $position, and $slice
values_to_push = []
slice_value = None
position_value = None
if (
isinstance(push_value, dict)
and "$each" in push_value
):
# Check for $slice and $position
if "$slice" in push_value:
slice_value = push_value["$slice"]
if "$position" in push_value:
position_value = push_value["$position"]
# SQL optimization supports $each, $position, and $slice
values_to_push = push_value["$each"]
if not isinstance(values_to_push, list):
values_to_push = [values_to_push]
else:
values_to_push = [push_value]
# Check for $slice without $each
if (
isinstance(push_value, dict)
and "$slice" in push_value
and "$each" not in push_value
):
slice_value = push_value["$slice"]
# Check for $position without $each
if (
isinstance(push_value, dict)
and "$position" in push_value
and "$each" not in push_value
):
position_value = push_value["$position"]
json_path = f"'{parse_json_path(field)}'"
# Handle $position or $slice: need to reconstruct array
if (
position_value is not None
or slice_value is not None
):
# Convert values to add
converted_values = [
_convert_bytes_to_binary(v)
for v in values_to_push
]
# Check if any values are complex
has_complex = any(
isinstance(v, (dict, list, Binary))
for v in converted_values
)
# Build JSON array of new values
new_values_json = neosqlite_json_dumps(
converted_values
)
# Handle $position: insert at specific position
if position_value is not None:
position = int(position_value)
if position < 0:
position = 0
if slice_value is not None and slice_value == 0:
# Slice 0 means empty array
set_clauses.append(
f"{json_path}, json('[]')"
)
elif slice_value is not None:
# Both $position and $slice
slice_limit = (
slice_value
if slice_value > 0
else 1000000
)
set_clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {position} UNION ALL SELECT value FROM json_each({new_values_json}) UNION ALL SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {slice_limit} OFFSET {position}))"
)
else:
# Only $position, no $slice
set_clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {position} UNION ALL SELECT value FROM json_each({new_values_json}) UNION ALL SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT -1 OFFSET {position}))"
)
elif slice_value is not None:
# Handle $slice only
if slice_value == 0:
# Slice 0 means empty array
set_clauses.append(
f"{json_path}, json('[]')"
)
else:
# Get existing array, concatenate with new values, then slice
set_clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) UNION ALL SELECT value FROM json_each({new_values_json}) LIMIT {slice_value if slice_value > 0 else 1000000}))"
)
else:
# No $position or $slice - just append values
append_path = f"'{parse_json_path(field)}[#]'"
for val in values_to_push:
converted_val = _convert_bytes_to_binary(val)
set_clauses.append(f"{append_path}, ?")
params.append(converted_val)
case "$setOnInsert":
# $setOnInsert only applies on upsert (doc_id == 0)
# For existing documents (doc_id != 0), this is a no-op
# We can safely skip it in the SQL path
pass
case "$rename":
# $rename is handled in _perform_enhanced_sql_update for proper
# set + unset ordering. Return None here to use enhanced path.
return None
case "$pull":
# SQL optimization for $pull: filter array elements using json_each
for field, pull_value in value.items():
json_path = f"'{parse_json_path(field)}'"
converted_val = _convert_bytes_to_binary(pull_value)
if isinstance(converted_val, (dict, list, Binary)):
pull_value_json = neosqlite_json_dumps(
converted_val
)
set_clauses.append(
f"{json_path}, (SELECT json_group_array(json(value)) FROM json_each(json_extract(data, {json_path})) WHERE json(value) != {pull_value_json} OR json(value) IS NULL)"
)
else:
set_clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM json_each(json_extract(data, {json_path})) WHERE value != ?)"
)
params.append(converted_val)
case "$pullAll":
# SQL optimization for $pullAll: filter multiple values from array
for field, pull_values in value.items():
if not isinstance(pull_values, list):
pull_values = [pull_values]
json_path = f"'{parse_json_path(field)}'"
has_complex = any(
isinstance(
_convert_bytes_to_binary(v),
(dict, list, Binary),
)
for v in pull_values
)
if has_complex:
pull_values_json = [
neosqlite_json_dumps(
_convert_bytes_to_binary(v)
)
for v in pull_values
]
conditions = " OR ".join(
[
f"json(value) != {v}"
for v in pull_values_json
]
)
set_clauses.append(
f"{json_path}, (SELECT json_group_array(json(value)) FROM json_each(json_extract(data, {json_path})) WHERE {conditions} OR json(value) IS NULL)"
)
else:
placeholders = ", ".join(["?" for _ in pull_values])
converted_values = [
_convert_bytes_to_binary(v) for v in pull_values
]
set_clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM json_each(json_extract(data, {json_path})) WHERE value NOT IN ({placeholders}))"
)
params.extend(converted_values)
case _:
return None # Fallback for unsupported operators
if not set_clauses:
return None
# For $unset, we already returned above
if "$unset" not in update:
func_name = _get_json_function("set", self._jsonb_supported)
return f"data = {func_name}(data, {', '.join(set_clauses)})", params
else:
# This case should have been handled above
return None
[docs]
def _build_sql_update_clause(
self,
op: str,
value: Any,
) -> tuple[list[str], list[Any]]:
"""
Build SQL update clause for a single operation.
Args:
op (str): The update operation, such as "$set", "$inc", "$mul", etc.
value (Any): The value associated with the update operation.
Returns:
tuple[list[str], list[Any]]: A tuple containing the SQL update clauses
and parameters.
"""
clauses = []
params = []
match op:
case "$set":
for field, field_val in value.items():
# Convert bytes to Binary for proper JSON serialization
converted_val = _convert_bytes_to_binary(field_val)
# If it's a Binary object, serialize it to JSON and use json() function
json_path = f"'{parse_json_path(field)}'"
if isinstance(converted_val, Binary):
clauses.append(f"{json_path}, json(?)")
params.append(neosqlite_json_dumps(converted_val))
else:
clauses.append(f"{json_path}, ?")
params.append(converted_val)
case "$inc":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
# Convert bytes to Binary for proper JSON serialization
converted_val = _convert_bytes_to_binary(field_val)
# If it's a Binary object, serialize it to JSON and use json() function
if isinstance(converted_val, Binary):
clauses.append(
f"{json_path}, COALESCE({self._json_function_prefix}_extract(data, {json_path}), 0) + json(?)"
)
params.append(neosqlite_json_dumps(converted_val))
else:
clauses.append(
f"{json_path}, COALESCE({self._json_function_prefix}_extract(data, {json_path}), 0) + ?"
)
params.append(converted_val)
case "$mul":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
# Convert bytes to Binary for proper JSON serialization
converted_val = _convert_bytes_to_binary(field_val)
# If it's a Binary object, serialize it to JSON and use json() function
if isinstance(converted_val, Binary):
clauses.append(
f"{json_path}, COALESCE({self._json_function_prefix}_extract(data, {json_path}), 0) * json(?)"
)
params.append(neosqlite_json_dumps(converted_val))
else:
clauses.append(
f"{json_path}, COALESCE({self._json_function_prefix}_extract(data, {json_path}), 0) * ?"
)
params.append(converted_val)
case "$min":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
clauses.append(
f"{json_path}, min({self._json_function_prefix}_extract(data, {json_path}), ?)"
)
# Convert bytes to Binary for proper JSON serialization
converted_val = _convert_bytes_to_binary(field_val)
# If it's a Binary object, serialize it to JSON and use json() function
if isinstance(converted_val, Binary):
clauses[-1] = (
f"{json_path}, min({self._json_function_prefix}_extract(data, {json_path}), json(?))"
)
params.append(neosqlite_json_dumps(converted_val))
else:
params.append(converted_val)
case "$max":
for field, field_val in value.items():
json_path = f"'{parse_json_path(field)}'"
clauses.append(
f"{json_path}, max({self._json_function_prefix}_extract(data, {json_path}), ?)"
)
# Convert bytes to Binary for proper JSON serialization
converted_val = _convert_bytes_to_binary(field_val)
# If it's a Binary object, serialize it to JSON and use json() function
if isinstance(converted_val, Binary):
clauses[-1] = (
f"{json_path}, max({self._json_function_prefix}_extract(data, {json_path}), json(?))"
)
params.append(neosqlite_json_dumps(converted_val))
else:
params.append(converted_val)
case "$push":
# Optimized $push (with $each, optionally with $position and/or $slice) using [#]
for field, push_value in value.items():
# Handle $each, $position, and $slice
values_to_push = []
slice_value = None
position_value = None
if isinstance(push_value, dict) and "$each" in push_value:
# Check for $slice and $position
if "$slice" in push_value:
slice_value = push_value["$slice"]
if "$position" in push_value:
position_value = push_value["$position"]
# SQL optimization supports $each, $position, and $slice
values_to_push = push_value["$each"]
if not isinstance(values_to_push, list):
values_to_push = [values_to_push]
else:
values_to_push = [push_value]
# Check for $slice without $each
if (
isinstance(push_value, dict)
and "$slice" in push_value
and "$each" not in push_value
):
slice_value = push_value["$slice"]
# Check for $position without $each
if (
isinstance(push_value, dict)
and "$position" in push_value
and "$each" not in push_value
):
position_value = push_value["$position"]
json_path = f"'{parse_json_path(field)}'"
# Handle $position or $slice: need to reconstruct array
if position_value is not None or slice_value is not None:
# Convert values to add
converted_values = [
_convert_bytes_to_binary(v) for v in values_to_push
]
# Build JSON array of new values
new_values_json = neosqlite_json_dumps(converted_values)
# Handle $position: insert at specific position
if position_value is not None:
position = int(position_value)
if position < 0:
position = 0
if slice_value is not None and slice_value == 0:
# Slice 0 means empty array
clauses.append(f"{json_path}, json('[]')")
elif slice_value is not None:
# Both $position and $slice
slice_limit = (
slice_value if slice_value > 0 else 1000000
)
clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {position} UNION ALL SELECT value FROM json_each({new_values_json}) UNION ALL SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {slice_limit} OFFSET {position}))"
)
else:
# Only $position, no $slice
clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT {position} UNION ALL SELECT value FROM json_each({new_values_json}) UNION ALL SELECT value FROM json_each(json_extract(data, {json_path})) LIMIT -1 OFFSET {position}))"
)
elif slice_value is not None:
# Handle $slice only
if slice_value == 0:
# Slice 0 means empty array
clauses.append(f"{json_path}, json('[]')")
else:
# Get existing array, concatenate with new values, then slice
clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM (SELECT value FROM json_each(json_extract(data, {json_path})) UNION ALL SELECT value FROM json_each({new_values_json}) LIMIT {slice_value if slice_value > 0 else 1000000}))"
)
else:
# No $position or $slice - just append values
append_path = f"'{parse_json_path(field)}[#]'"
for val in values_to_push:
# Convert bytes to Binary for proper JSON serialization
converted_val = _convert_bytes_to_binary(val)
if isinstance(converted_val, Binary):
clauses.append(f"{append_path}, json(?)")
params.append(
neosqlite_json_dumps(converted_val)
)
elif isinstance(converted_val, (dict, list)):
clauses.append(f"{append_path}, json(?)")
params.append(
neosqlite_json_dumps(converted_val)
)
else:
clauses.append(f"{append_path}, ?")
params.append(converted_val)
case "$pop":
# Tier 2: $pop uses json_remove with [0] or [#-1]
if not _supports_relative_json_indexing():
return [], []
for field, pop_direction in value.items():
# 1: remove last, -1: remove first
index_path = "[0]" if int(pop_direction) < 0 else "[#-1]"
json_path = f"'{parse_json_path(field)}{index_path}'"
clauses.append(json_path)
case "$addToSet":
# Tier 2: $addToSet (without $each) can use conditional SQL
for field, val in value.items():
json_path = f"'{parse_json_path(field)}'"
# Convert value for proper parameter handling
converted_val = _convert_bytes_to_binary(val)
if isinstance(converted_val, Binary):
param_value = neosqlite_json_dumps(converted_val)
use_json = True
elif isinstance(converted_val, (dict, list)):
param_value = neosqlite_json_dumps(converted_val)
use_json = True
else:
param_value = converted_val
use_json = False
if not use_json:
# Build SQL that only inserts if value not in array
insert_func = _get_json_function(
"insert", self._jsonb_supported
)
array_path = json_path
append_path = f"'{parse_json_path(field)}[#]'"
# Use a CASE expression to conditionally call json_insert
exists_subquery = f"EXISTS (SELECT 1 FROM json_each(data, {array_path}) WHERE value = ?)"
clauses.append(
f"{array_path}, CASE WHEN {exists_subquery} THEN data ELSE {insert_func}(data, {append_path}, ?) END"
)
params.extend([param_value, param_value])
else:
# Complex values - fall back to Python
return [], []
case "$unset":
# For $unset, we use json_remove
for field in value:
json_path = f"'{parse_json_path(field)}'"
clauses.append(json_path)
case "$pull":
# SQL optimization for $pull: filter array elements using json_each
for field, pull_value in value.items():
json_path = f"'{parse_json_path(field)}'"
converted_val = _convert_bytes_to_binary(pull_value)
if isinstance(converted_val, (dict, list, Binary)):
# For complex values, serialize to JSON for comparison
pull_value_json = neosqlite_json_dumps(converted_val)
clauses.append(
f"{json_path}, (SELECT json_group_array(json(value)) FROM json_each(json_extract(data, {json_path})) WHERE json(value) != {pull_value_json} OR json(value) IS NULL)"
)
else:
# For simple values, compare directly
clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM json_each(json_extract(data, {json_path})) WHERE value != ?)"
)
params.append(converted_val)
case "$pullAll":
# SQL optimization for $pullAll: filter multiple values from array
for field, pull_values in value.items():
if not isinstance(pull_values, list):
pull_values = [pull_values]
json_path = f"'{parse_json_path(field)}'"
# Check if any values are complex (dict, list, Binary)
has_complex = any(
isinstance(
_convert_bytes_to_binary(v), (dict, list, Binary)
)
for v in pull_values
)
if has_complex:
# For complex values, serialize each and build JSON comparison
pull_values_json = [
neosqlite_json_dumps(_convert_bytes_to_binary(v))
for v in pull_values
]
conditions = " OR ".join(
[f"json(value) != {v}" for v in pull_values_json]
)
clauses.append(
f"{json_path}, (SELECT json_group_array(json(value)) FROM json_each(json_extract(data, {json_path})) WHERE {conditions} OR json(value) IS NULL)"
)
else:
# For simple values, use IN clause
placeholders = ", ".join(["?" for _ in pull_values])
converted_values = [
_convert_bytes_to_binary(v) for v in pull_values
]
clauses.append(
f"{json_path}, (SELECT json_group_array(value) FROM json_each(json_extract(data, {json_path})) WHERE value NOT IN ({placeholders}))"
)
params.extend(converted_values)
case "$currentDate":
# SQL implementation for $currentDate
for field, type_spec in value.items():
json_path = f"'{parse_json_path(field)}'"
# Determine type: true defaults to date, { $type: "timestamp" } or { $type: "date" }
if (
isinstance(type_spec, dict)
and type_spec.get("$type") == "timestamp"
):
type_value = "timestamp"
else:
type_value = "date"
# Set to current datetime ISO string to match Python implementation
if type_value == "timestamp":
clauses.append(
f"{json_path}, strftime('%Y-%m-%dT%H:%M:%fZ', 'now')"
)
else:
# For date type, match Python's datetime.now().isoformat() format
clauses.append(
f"{json_path}, strftime('%Y-%m-%dT%H:%M:%f', 'now')"
)
case "$setOnInsert":
# $setOnInsert only applies on upsert (doc_id == 0)
# For existing documents (doc_id != 0), this is a no-op
# We can safely skip it - return empty clauses
pass
case "$rename":
# SQL implementation for $rename
# $rename requires: get value, set at new path, remove old path
# This is handled in _perform_enhanced_sql_update specially
# Return empty to trigger fallback for complex cases
return [], []
return clauses, params
[docs]
def _perform_python_update(
self,
doc_id: Any,
update_spec: dict[str, Any],
original_doc: dict[str, Any],
array_filters: list[dict[str, Any]] | None = None,
query_filter: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], bool]:
"""
Perform update operations using Python-based logic.
Args:
doc_id (Any): The document ID of the document to update (can be ObjectId, int, etc.).
update_spec (dict[str, Any]): A dictionary specifying the update
operations to perform.
original_doc (dict[str, Any]): The original document before applying
the updates.
array_filters (list[dict[str, Any]], optional): Filter documents for array positional operators.
query_filter (dict[str, Any], optional): The query filter for $ operator.
Returns:
tuple[dict[str, Any], bool]: The updated document and whether it was modified.
"""
doc_to_update = deepcopy(original_doc)
for op, value in update_spec.items():
match op:
case "$set":
# Handle positional operators in field paths
for k, v in value.items():
if "$" in k:
# Use positional update
_apply_positional_update(
doc_to_update, k, v, array_filters, query_filter
)
else:
_set_nested_field(doc_to_update, k, v)
case "$unset":
for k in value:
doc_to_update.pop(k, None)
case "$inc":
for k, v in value.items():
# Validate that the field value is numeric before performing operation
current_value = doc_to_update.get(k)
_validate_inc_mul_field_value(k, current_value, "$inc")
doc_to_update[k] = doc_to_update.get(k, 0) + v
case "$push":
for k, v in value.items():
# Check if v is a dict with modifiers ($each, $position, $slice)
if isinstance(v, dict) and "$each" in v:
# Get the array to push to
current_list = doc_to_update.setdefault(k, [])
# Get values to add
values_to_add = v["$each"]
if not isinstance(values_to_add, list):
values_to_add = [values_to_add]
# Handle $position modifier
position = v.get("$position")
if position is not None:
# Insert at specific position
for i, val in enumerate(values_to_add):
current_list.insert(position + i, val)
else:
# Append to end
current_list.extend(values_to_add)
# Handle $slice modifier (after adding values)
slice_val = v.get("$slice")
if slice_val is not None:
if slice_val == 0:
doc_to_update[k] = []
elif slice_val > 0:
# Keep first N elements
doc_to_update[k] = current_list[:slice_val]
else:
# Keep last N elements (negative slice)
doc_to_update[k] = current_list[slice_val:]
else:
# Simple push (no modifiers)
doc_to_update.setdefault(k, []).append(v)
case "$addToSet":
for k, v in value.items():
current_list = doc_to_update.setdefault(k, [])
# Handle $each modifier
values_to_add = []
if isinstance(v, dict) and "$each" in v:
each_values = v["$each"]
if not isinstance(each_values, list):
each_values = [each_values]
values_to_add = each_values
else:
values_to_add = [v]
# Add each value if not already present
for val in values_to_add:
if val not in current_list:
current_list.append(val)
case "$pull":
for k, v in value.items():
if k in doc_to_update:
doc_to_update[k] = [
item for item in doc_to_update[k] if item != v
]
case "$pullAll":
for k, v in value.items():
if k in doc_to_update and isinstance(v, (list, tuple)):
# Only process if the field is a list
if isinstance(doc_to_update[k], list):
# Remove all instances of values in the array
# Use list instead of set to handle unhashable types
values_to_remove = list(v)
new_list = [
item
for item in doc_to_update[k]
if item not in values_to_remove
]
# Only update if the list actually changed
if new_list != doc_to_update[k]:
doc_to_update[k] = new_list
case "$pop":
for k, v in value.items():
if v == 1:
doc_to_update.get(k, []).pop()
elif v == -1:
doc_to_update.get(k, []).pop(0)
case "$bit":
for k, bit_op in value.items():
if not isinstance(bit_op, dict):
raise MalformedQueryException(
"$bit operator requires a dict with 'and', 'or', or 'xor'"
)
# Get current value (default to 0)
current_val = doc_to_update.get(k, 0)
# Apply bitwise operations
if "and" in bit_op:
current_val &= bit_op["and"]
if "or" in bit_op:
current_val |= bit_op["or"]
if "xor" in bit_op:
current_val ^= bit_op["xor"]
doc_to_update[k] = current_val
case "$rename":
for k, v in value.items():
if k in doc_to_update:
doc_to_update[v] = doc_to_update.pop(k)
case "$mul":
for k, v in value.items():
# Validate that the field value is numeric before performing operation
if k in doc_to_update:
_validate_inc_mul_field_value(
k, doc_to_update[k], "$mul"
)
doc_to_update[k] *= v
case "$min":
for k, v in value.items():
if k not in doc_to_update or doc_to_update[k] > v:
doc_to_update[k] = v
case "$max":
for k, v in value.items():
if k not in doc_to_update or doc_to_update[k] < v:
doc_to_update[k] = v
case "$currentDate":
for k, type_spec in value.items():
doc_to_update[k] = datetime.now().isoformat()
case "$setOnInsert":
# Only apply on upsert (doc_id == 0)
if doc_id == 0:
for k, v in value.items():
doc_to_update[k] = v
case _:
raise MalformedQueryException(
f"Update operator '{op}' not supported"
)
# If this is an upsert (doc_id == 0), we don't update the database
# We just return the updated document for insertion by the caller
if doc_id != 0:
# Convert the doc_id to integer ID for internal operations
int_doc_id = self._get_integer_id_for_oid(doc_id)
self.collection.db.execute(
f"UPDATE {quote_table_name(self.collection.name)} SET data = ? WHERE id = ?",
(neosqlite_json_dumps(doc_to_update), int_doc_id),
)
# Check if document was actually modified
was_modified = doc_to_update != original_doc
return doc_to_update, was_modified
[docs]
@staticmethod
def _validate_inc_mul_types_sql(
db: Any,
collection_name: str,
where_clause: str | None,
where_params: list[Any],
update: dict[str, Any],
jsonb_supported: bool,
) -> bool:
"""
Validate that fields in $inc/$mul operations are numeric.
This checks the JSON type of each field being incremented/multiplied
to ensure they're numeric types.
Args:
db: Database connection
collection_name: Name of the collection
where_clause: The translated WHERE clause
where_params: Parameters for the WHERE clause
update: The update operations
jsonb_supported: Whether JSONB is supported
Returns:
True if all fields are numeric or don't exist, False if any field is non-numeric
"""
from .utils import _get_json_function
fields_to_check = []
for op in ("$inc", "$mul"):
if op in update:
fields_to_check.extend(update[op].keys())
if not fields_to_check:
return True
json_func = _get_json_function("type", jsonb_supported)
# Build a single query to check all fields
select_expressions = []
for field in fields_to_check:
# json_type(data, '$.field') or jsonb_type(data, '$.field')
json_path = f"$.{field}"
select_expressions.append(f"{json_func}(data, '{json_path}')")
if not where_clause:
where_clause = "WHERE 1=1"
cmd = f"SELECT {', '.join(select_expressions)} FROM {quote_table_name(collection_name)} {where_clause} LIMIT 1"
try:
cursor = db.execute(cmd, where_params)
row = cursor.fetchone()
if row:
for field_type in row:
if field_type is not None:
if jsonb_supported:
# JSONB type returns 'number' for both int and float
if field_type not in (
"number",
"null",
"integer",
"real",
):
return False
else:
# Standard JSON type returns 'integer' or 'real'
if field_type not in ("null", "integer", "real"):
return False
# If no row matches, the update will be a no-op anyway, so it's safe to use fast path
return True
except Exception as e:
# Fallback to slow path on any SQL error
logger.debug(f"Fast path check failed due to SQL error: {e}")
return False