# Source code for populse_db.storage

import importlib
import pathlib
import types
import typing
from contextlib import contextmanager

from .database import type_to_str
from .storage_api import StorageAPI


[docs] class Storage: default_collection = "_" schema_collection = "_schema" default_field = "_" default_document_id = "_"
[docs] def __init__( self, database_file: str | pathlib.Path, timeout: float | None = 10000, create: bool = False, echo_sql: typing.TextIO | None = None, secret: str | None = None, ): if isinstance(database_file, pathlib.Path): database_file = str(database_file) self.storage_api = StorageAPI( database_file, timeout=timeout, create=create, echo_sql=echo_sql, secret=secret, ) self._read_access_token = None self._write_access_token = None self._current_data_session = None
[docs] def access_token(self, write): if write: if self._write_access_token is None: self._write_access_token = self.storage_api.access_token(write=True) return self._write_access_token if self._read_access_token is None: self._read_access_token = self.storage_api.access_token(write=False) return self._read_access_token
[docs] @contextmanager def data(self, exclusive=None, write=False, create=False): if self._current_data_session is not None: storage_session, is_exclusive, is_write = self._current_data_session if exclusive and is_exclusive is not True: raise RuntimeError( "Impossible to get an exclusive data session because another non exclusive data session exists" ) if write and not is_write: raise RuntimeError( "Impossible to get an write data session because another read data session exists" ) yield storage_session else: connection_id = self.storage_api.connect( self.access_token(write), exclusive=exclusive, write=write, create=create, ) if connection_id is not None: try: storage_session = StorageSession(self.storage_api, connection_id) self._current_data_session = (storage_session, exclusive, write) yield storage_session self._current_data_session = None self.storage_api.disconnect(connection_id, rollback=False) except Exception: self._current_data_session = None self.storage_api.disconnect(connection_id, rollback=True) raise else: raise RuntimeError("Failed to establish a data session.")
[docs] @contextmanager def schema(self): connection_id = self.storage_api.connect( self.access_token(True), exclusive=True, write=True, create=True ) try: yield SchemaSession(self.storage_api, connection_id) self.storage_api.disconnect(connection_id, rollback=False) except Exception: self.storage_api.disconnect(connection_id, rollback=True) raise
[docs] def start_session(self, exclusive=None, write=False, create=False): connection_id = self.storage_api.connect( self.access_token(write), exclusive=exclusive, write=write, create=create ) return StorageSession(self.storage_api, connection_id)
[docs] def end_session(self, storage_session, rollback=False): self.storage_api.disconnect(storage_session._connection_id, rollback=rollback)
class SchemaSession: @classmethod def find_schema(cls, name, version_selection=None): module = importlib.import_module(name) schemas_to_collections = getattr(module, "_schemas_to_collections", None) if schemas_to_collections is None: schemas_to_collections = {} index = 0 for schema_declaration in module.schemas: unknown_items = set(schema_declaration) - {"version", "schema"} if unknown_items: raise ValueError( f"invalid item(s) in schema {module.__name__} (list index={index}): {','.join(unknown_items)}" ) version = schema_declaration.get("version") if version is None: raise ValueError( f"version missing in schema {module.__name__} (list index={index})" ) sversion = version.split(".") if len(sversion) != 3: raise ValueError( f"invalid version in schema {module.__name__} (list index={index}): {version}" ) schema_content = schema_declaration.get("schema") if schema_content is None: raise ValueError( f"schema missing in schema {module.__name__} (list index={index})" ) try: collections = cls._parse_schema_content(schema_content) except Exception as e: raise ValueError( f"invalid schema definition {module.__name__} (list index={index})" ) from e schema_to_collections = { "name": module.__name__, "version": version, "collections": collections, } if version in schemas_to_collections: raise ValueError( f"two schemas with version {version} defined in {module.__name__}" ) schemas_to_collections[version] = schema_to_collections if None not in schemas_to_collections: schemas_to_collections[None] = schema_to_collections short_version = ".".join(sversion[:2]) if short_version not in schemas_to_collections: schemas_to_collections[short_version] = schema_to_collections index += 1 module._schemas_to_collections = schemas_to_collections return schemas_to_collections.get(version_selection) @classmethod def _parse_field(cls, name, definition): error = True if isinstance(definition, type | types.GenericAlias): type_str = type_to_str(definition) kwargs = {} error = False elif 
isinstance(definition, str): type_str = definition kwargs = {} error = False elif isinstance(definition, list) and len(definition) == 2: if isinstance(definition[1], dict): if isinstance(definition[0], type | types.GenericAlias): type_str = type_to_str(definition[0]) kwargs = definition[1] error = False elif isinstance(definition[0], str): type_str = definition[0] kwargs = definition[1] error = False if error: raise ValueError(f'invalid definition for field "{name}"') return [type_str, kwargs] @classmethod def _parse_schema_content(cls, schema_content): if not isinstance(schema_content, dict): raise ValueError("schema must be a dict") collections = {} for k, v in schema_content.items(): if isinstance(v, dict): fields = {kk: cls._parse_field(kk, vv) for kk, vv in v.items()} for name, d in fields.items(): _, kwargs = d if name == Storage.default_field: raise ValueError(f'invalid field name: "{name}') if "primary_key" in kwargs: raise ValueError( f'primary key not allowed for unique document "{k}" in field "{name}"' ) fields[Storage.default_field] = ["str", {"primary_key": True}] collections[k] = fields elif isinstance(v, list) and len(v) == 1 and isinstance(v[0], dict): fields = {kk: cls._parse_field(kk, vv) for kk, vv in v[0].items()} primary = False for name, d in fields.items(): type, kwargs = d if name == Storage.default_field: raise ValueError(f'invalid field name: "{name}') if "primary_key" in kwargs: primary = True if not primary: raise ValueError(f'no primary key defined for collection "{k}"') collections[k] = fields else: type_str, kwargs = cls._parse_field(k, v) fields = collections.setdefault(Storage.default_collection, {}) if Storage.default_field not in fields: fields[Storage.default_field] = ["str", {"primary_key": True}] fields[k] = [type_str, kwargs] return collections def __init__(self, server, connection_id): super().__setattr__("_storage_api", server) super().__setattr__("_connection_id", connection_id) def add_schema(self, name, version=None): 
schema_to_collections = self.find_schema(name, version) if not schema_to_collections: raise ValueError(f"cannot find schema {name} with version {version}") self._storage_api.add_schema_collections( self._connection_id, schema_to_collections ) def add_collection(self, name, primary_key): # Make primary_key json compatible if isinstance(primary_key, dict): primary_key = dict( (k, (v if isinstance(v, str) else type_to_str(v))) for k, v in primary_key.items() ) self._storage_api.add_collection(self._connection_id, name, primary_key) def add_field( self, collection_name, field_name, field_type, description=None, index=False ): if isinstance(field_type, type | types.GenericAlias): field_type = type_to_str(field_type) self._storage_api.add_field( self._connection_id, collection_name, field_name, field_type, description, index, ) def remove_field(self, collection_name, field_name): """ Removes a specified field from a collection in the storage system. Args: collection_name (str): The name of the collection containing the field. field_name (str): The name of the field to be removed. This method delegates the operation to the storage API, ensuring the field is removed from the specified collection within the active connection. 
""" self._storage_api.remove_field(self._connection_id, collection_name, field_name) def clear_database(self, keep_settings=False): return self._storage_api.clear_database(self._connection_id, keep_settings) @contextmanager def data(self): yield StorageSession(self._storage_api, self._connection_id) class StorageSession: def __init__(self, storage_api, connection_id, path=None): super().__setattr__("_storage_api", storage_api) super().__setattr__("_connection_id", connection_id) super().__setattr__("_path", path or []) def __getitem__(self, key): return self.__class__( self._storage_api, self._connection_id, self._path + [key] ) def __getattr__(self, key): return self[key] def __setitem__(self, key, value): self._storage_api.set(self._connection_id, self._path + [key], value) def __setattr__(self, key, value): self[key] = value def __delitem__(self, key): self._storage_api.delete(self._connection_id, self._path + [key]) def __delattr__(self, key): del self[key] def primary_key(self): return self._storage_api.primary_key(self._connection_id, self._path) def set(self, value): self._storage_api.set(self._connection_id, self._path, value) def update(self, value): self._storage_api.update(self._connection_id, self._path, value) def get(self, default=None, fields=None, as_list=False, distinct=False): return self._storage_api.get( self._connection_id, self._path, default=default, fields=fields, as_list=as_list, distinct=distinct, ) def count(self, query=None): return self._storage_api.count(self._connection_id, self._path, query=query) def append(self, value): return self._storage_api.append(self._connection_id, self._path, value) def distinct_values(self, field): return self._storage_api.distinct_values(self._connection_id, self._path, field) def search(self, query=None, fields=None, as_list=None, distinct=False, **kwargs): if kwargs and query: raise ValueError("Cannot combine query and equality research") if kwargs: query = " AND ".join(f'{{{k}}}=="{v}"' for k, v in 
kwargs.items()) if isinstance(fields, tuple): fields = list(fields) return self._storage_api.search( self._connection_id, self._path, query, fields=fields, as_list=as_list, distinct=distinct, ) def search_and_delete(self, query=None, **kwargs): if kwargs and query: raise ValueError("Cannot combine query and equality research") if kwargs: query = " AND ".join(f'{{{k}}}=="{v}"' for k, v in kwargs.items()) return self._storage_api.search_and_delete( self._connection_id, self._path, query ) def has_collection(self, collection): return self._storage_api.has_collection( self._connection_id, self._path, collection ) def collection_names(self): return self._storage_api.collection_names( self._connection_id, self._path, ) def keys(self): return self._storage_api.keys( self._connection_id, self._path, )