"""
Copyright (c) 2018-2022, CodeLV.

Distributed under the terms of the MIT License.

The full license is in the file LICENSE.txt, distributed with this software.

Created on Aug 2, 2018
"""
import asyncio
import datetime
import functools
import logging
import weakref
from decimal import Decimal
from typing import Any
from typing import Callable as CallableType
from typing import ClassVar
from typing import Dict as DictType
from typing import Generic, Iterator
from typing import List as ListType
from typing import Optional, Sequence
from typing import Set as SetType
from typing import Tuple as TupleType
from typing import Type, TypeVar, Union, cast

import sqlalchemy as sa
from atom import api
from atom.api import (
    Atom,
    Bool,
    ContainerList,
    Dict,
    ForwardInstance,
    ForwardSubclass,
    ForwardTyped,
    Instance,
    Int,
    List,
    Member,
    Str,
    Typed,
    Value,
)
from sqlalchemy.engine import ddl
from sqlalchemy.sql import schema
from sqlalchemy.sql.type_api import TypeEngine

from .base import (
    JSONModel,
    JSONSerializer,
    Model,
    ModelManager,
    ModelSerializer,
    ScopeType,
    StateType,
    find_subclasses,
    is_db_field,
    resolve_member_types,
)
# kwargs reserved for sqlalchemy table columns. When one of these keys is
# present in a member's metadata it is forwarded to ``sa.Column`` instead of
# to the column type.
# NOTE(review): the original tuple was lost in extraction; it is rebuilt here
# from the keyword arguments accepted by ``sa.Column`` — confirm upstream.
COLUMN_KWARGS = (
    "autoincrement",
    "comment",
    "default",
    "doc",
    "index",
    "info",
    "key",
    "nullable",
    "onupdate",
    "primary_key",
    "quote",
    "server_default",
    "server_onupdate",
    "system",
    "unique",
)

# Member types that may reference another model (foreign keys)
FK_TYPES = (Instance, Typed, ForwardInstance, ForwardTyped)

# ops that can be used with django-style queries, mapped to the name of the
# sqlalchemy column operator method that implements them
QUERY_OPS = {
    "eq": "__eq__",
    "gt": "__gt__",
    "gte": "__ge__",
    "ge": "__ge__",
    "lt": "__lt__",
    "le": "__le__",
    "lte": "__le__",
    "all": "all_",
    "any": "any_",
    "ne": "__ne__",
    "not": "__ne__",
    "contains": "contains",
    "endswith": "endswith",
    "ilike": "ilike",
    "in": "in_",
    "is": "is_",
    "is_distinct_from": "is_distinct_from",
    "isnot": "isnot",
    "isnot_distinct_from": "isnot_distinct_from",
    "like": "like",
    "match": "match",
    "notilike": "notilike",
    "notlike": "notlike",
    "notin": "notin_",
    "startswith": "startswith",
}

# Fields supported on the django style Meta class of a model.
# NOTE(review): rebuilt from the Meta attributes read in this module;
# "db_table" is assumed to be consumed elsewhere — confirm upstream.
VALID_META_FIELDS = (
    "abstract",
    "composite_indexes",
    "constraints",
    "db_table",
    "triggers",
    "unique_together",
)

# Constraint naming conventions
CONSTRAINT_NAMING_CONVENTIONS = {
    "ix": "ix_%(table_name)s_%(column_0_N_name)s",
    "uq": "uq_%(table_name)s_%(column_0_N_name)s",
    # Using "ck_%(table_name)s_%(constraint_name)s" is preferred but it causes
    # issues using Bool on mysql
    "ck": "ck_%(table_name)s_%(column_0_N_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}

log = logging.getLogger("atomdb.sql")

# A raw SQL string or any executable sqlalchemy construct
QueryType = Union[str, sa.sql.expression.Executable]

T = TypeVar("T", bound="SQLModel")
def find_sql_models() -> Iterator[Type["SQLModel"]]:
    """Finds all non-abstract imported SQLModels by looking up subclasses
    of the SQLModel.

    Yields
    ------
    cls: SQLModel
        Each imported SQLModel subclass whose ``Meta`` (if it has one) is not
        marked ``abstract``.
    """
    for model in find_subclasses(SQLModel):
        meta = getattr(model, "Meta", None)
        # Abstract models never get a table, so leave them out
        if meta is not None and getattr(meta, "abstract", False):
            continue
        yield model
class Relation(ContainerList):
    """A member which serves as a fk relation backref"""

    __slots__ = ("_to",)

    def __init__(self, item: CallableType[[], Type[Model]], default=None):
        super().__init__(ForwardInstance(item), default=default)  # type: ignore
        # Resolved lazily by the ``to`` property (forward references)
        self._to: Optional[Type[Model]] = None

    def resolve(self) -> Type[Model]:
        """Return the related model class (same as ``self.to``)."""
        return self.to

    @property
    def to(self) -> Type[Model]:
        """The related model class.

        It is resolved from the item member on first access and cached.

        Raises
        ------
        TypeError
            If the item type of the relation cannot be resolved.
        """
        to = self._to
        if to is None:
            types = resolve_member_types(self.validate_mode[-1])
            if not types:
                raise TypeError("Relation members must specify types")
            to = self._to = types[0]
        return to
def py_type_to_sql_column(
    model: Type[Model],
    member: Member,
    types: Union[Type, TupleType[Type, ...]],
    **kwargs,
) -> Union[TypeEngine, TupleType[TypeEngine, sa.ForeignKey]]:
    """Convert the python type to an alchemy table column type

    Parameters
    ----------
    model: Type[Model]
        The model which owns the member.
    member: Member
        The member being converted. It is registered as a backref when the
        type is a SQLModel and is used in error messages.
    types: Type or tuple of types
        The resolved python type(s) of the member. Only the first is used.
    kwargs:
        Extra keyword arguments passed to the sqlalchemy type, or to
        ``sa.ForeignKey`` when the type is a SQLModel.

    Returns
    -------
    result: TypeEngine or tuple
        The column type, or a ``(column type, ForeignKey)`` tuple when the
        type is a SQLModel.

    Raises
    ------
    NotImplementedError
        If no column type can be determined for the type.
    """
    if isinstance(types, tuple):
        cls, *_ = types  # Only the first type determines the column
    else:
        cls = types

    if issubclass(cls, JSONModel):
        return sa.JSON(**kwargs)
    elif issubclass(cls, SQLModel):
        name = f"{cls.__model__}.{cls.__pk__}"
        cls.__backrefs__.add((model, member))
        # Determine the type of the foreign key
        column = create_table_column(cls, cls._id)
        return (column.type, sa.ForeignKey(name, **kwargs))

    # Checked in order: the first match wins. bool falls under int, and
    # datetime.datetime must come before datetime.date (it is a subclass).
    simple_types = (
        (str, sa.String),
        (int, sa.Integer),
        (float, sa.Float),
        (dict, sa.JSON),
        # NOTE(review): sa.ARRAY requires an item_type; this only works if
        # one is supplied through the member metadata kwargs.
        ((tuple, list), sa.ARRAY),
        (datetime.datetime, sa.DateTime),
        (datetime.date, sa.Date),
        (datetime.time, sa.Time),
        (datetime.timedelta, sa.Interval),
        ((bytes, bytearray), sa.LargeBinary),
        (Decimal, sa.Numeric),
    )
    for py_types, column_type in simple_types:
        if issubclass(cls, py_types):
            return column_type(**kwargs)

    raise NotImplementedError(
        f"A type for {member.name} of {model} ({cls}) could not be "
        f"determined automatically, please specify it manually by tagging it "
        f"with .tag(column=<sqlalchemy column>) or set `store=False`"
    )
def resolve_member_column(
    model: Type["SQLModel"], field: str, related_clauses: Optional[ListType[str]] = None
) -> sa.Column:
    """Get the sqlalchemy column for the given model and field.

    Parameters
    ----------
    model: atomdb.sql.Model
        The model to lookup
    field: String
        The field name. Fields of related models may be addressed
        django-style by joining member names with ``__``.
    related_clauses: List[str] or None
        If given, the relation paths that must be joined to reach the column
        are appended to this list in place (duplicates are skipped).

    Returns
    -------
    result: sa.Column
        The sqlalchemy column.

    Raises
    ------
    ValueError
        If the model is None, the field is empty, or it cannot be resolved.
    """
    if model is None or not field:
        raise ValueError("Invalid field %s on %s" % (field, model))

    # Walk the relations
    if "__" in field:
        path = field
        *related_parts, field = field.split("__")
        clause = "__".join(related_parts)
        if related_clauses is not None and clause not in related_clauses:
            related_clauses.append(clause)

        # Follow the FK lookups
        # Rename so the original lookup path is retained if an error occurs
        rel_model = model
        for part in related_parts:
            m = rel_model.members().get(part)
            if m is None:
                raise ValueError("Invalid field %s on %s" % (path, model))
            rel_model_types = resolve_member_types(m)
            if rel_model_types is None:
                raise ValueError("Invalid field %s on %s" % (path, model))
            rel_model = rel_model_types[0]
        model = rel_model

    # Lookup the member
    m = model.members().get(field)
    if m is not None:
        if m.metadata:
            # If the field has a different name assigned use that
            field = m.metadata.get("name", field)
        if isinstance(m, Relation):
            # Support looking up columns through a relation by the pk
            model = m.to  # type: ignore

            # Add the through table to the related clauses if needed
            if related_clauses is not None and field not in related_clauses:
                related_clauses.append(field)

            field = model.__pk__

    # Finally get the column from the table
    col = model.objects.table.columns.get(field)
    if col is None:
        raise ValueError("Invalid field %s on %s" % (field, model))
    return col
def resolve_relation(
    model: Type["SQLModel"], field: str
) -> TupleType[Member, Type[Model], Member, sa.Column]:
    """Lookup a Relation.

    Parameters
    ----------
    model: SQLModel
        The model to lookup.
    field: str
        Path to a Relation, Typed, or Instance marked with store=False

    Returns
    -------
    result: tuple[Member, SQLModel, Member, sa.Column]
        A tuple of the related field on the given model, the other model
        it points to, the field on that model that points back to this
        model, and the Column of that field.

    Raises
    ------
    ValueError
        If the field is not a relation or no referring member is found.
    """
    relation = model.members().get(field)
    RelModel: Optional[Type[Model]] = None
    if isinstance(relation, Relation):
        # RelModel has a many to one relation back to model
        RelModel = cast(Relation, relation).to
    elif isinstance(relation, FK_TYPES) and not is_db_field(relation):
        # Note: If is_db_field passes the user should use select_related
        # instead of prefetch related.
        types = resolve_member_types(relation)
        if types and len(types) == 1 and issubclass(types[0], Model):
            # RelModel has a one to one relation back to model
            RelModel = types[0]

    if RelModel is not None:
        m = cast(Member, relation)
        # Find the referring member
        # TODO: This does not support multiple backrefs
        for other_model, referring_member in model.__backrefs__:
            if RelModel is other_model:
                # The column may be renamed through the member metadata
                meta = referring_member.metadata or {}
                name = meta.get("name", referring_member.name)
                rel_col = RelModel.objects.table.c[name]
                return (m, other_model, referring_member, rel_col)

    raise ValueError("Invalid prefetch relation '%s' from %s" % (field, model))
def atom_member_to_sql_column(
    model: Type["SQLModel"], member: Member, **kwargs
) -> Optional[Union[TypeEngine, tuple]]:
    """Convert the atom member type to an sqlalchemy table column type

    See https://docs.sqlalchemy.org/en/latest/core/type_basics.html

    Parameters
    ----------
    model: SQLModel
        The model which owns the member.
    member: Member
        The atom member to convert.
    kwargs:
        Extra keyword arguments forwarded to some of the sqlalchemy types.

    Returns
    -------
    result: TypeEngine, tuple, or None
        The column type, a ``(type, ForeignKey)`` tuple for references to
        other SQLModels, or None for Relation members (which have no column).

    Raises
    ------
    TypeError
        If a container or reference member does not declare its types.
    NotImplementedError
        If no column type can be determined for the member.
    """
    if hasattr(member, "get_column_type"):
        # Allow custom members to define the column type programatically
        return member.get_column_type(model)  # type: ignore
    elif isinstance(member, api.Str):
        return sa.String(**kwargs)
    elif hasattr(api, "Unicode") and isinstance(member, api.Unicode):  # type: ignore
        return sa.Unicode(**kwargs)  # type: ignore
    elif isinstance(member, api.Bool):
        return sa.Boolean()
    elif isinstance(member, api.Int):
        return sa.Integer()
    elif hasattr(api, "Long") and isinstance(member, api.Long):  # type: ignore
        return sa.BigInteger()
    elif isinstance(member, api.Float):
        return sa.Float()
    elif isinstance(member, api.Range):
        # TODO: Add min / max
        return sa.Integer()
    elif isinstance(member, api.FloatRange):
        # TODO: Add min / max
        return sa.Float()
    elif isinstance(member, api.Enum):
        return sa.Enum(*member.items, name=member.name)
    elif hasattr(api, "IntEnum") and isinstance(member, api.IntEnum):  # type: ignore
        return sa.SmallInteger()
    elif isinstance(member, FK_TYPES):
        value_type = resolve_member_types(member)
        if value_type is None:
            raise TypeError("Instance and Typed members must specify types")
        return py_type_to_sql_column(model, member, value_type, **kwargs)
    elif isinstance(member, Relation):
        # Relations are for backrefs; only validate that the type resolves
        item_type = member.validate_mode[-1]
        if item_type is None:
            raise TypeError("Relation members must specify types")
        value_type = resolve_member_types(item_type)
        if value_type is None:
            raise TypeError("Relation members must specify types")
        return None  # Relations are just syntactic sugar
    elif isinstance(member, (api.List, api.ContainerList, api.Tuple)):
        item_type = member.validate_mode[-1]
        if item_type is None:
            raise TypeError("List and Tuple members must specify types")

        # Resolve the item type
        value_type = resolve_member_types(item_type)
        if value_type is None:
            raise TypeError("List and Tuple members must specify types")
        if issubclass(value_type[0], JSONModel):
            return sa.JSON(**kwargs)
        t = py_type_to_sql_column(model, member, value_type, **kwargs)
        if isinstance(t, tuple):
            t = t[0]  # Use only the value type
        return sa.ARRAY(t)
    elif isinstance(member, api.Bytes):
        return sa.LargeBinary(**kwargs)
    elif isinstance(member, api.Dict):
        return sa.JSON(**kwargs)
    raise NotImplementedError(
        f"A column for {member.name} of {model} could not be determined "
        f"automatically, please specify it manually by tagging it "
        f"with .tag(column=<sqlalchemy column>)"
    )
def create_table_column(
    model: Type["SQLModel"], member: Member
) -> Optional[sa.Column]:
    """Converts an Atom member into a sqlalchemy data type.

    Parameters
    ----------
    model: Model
        The model which owns this member
    member: Member
        The atom member

    Returns
    -------
    column: Column or None
        An sqlalchemy column, or None when the member has no column
        (eg a Relation).

    References
    ----------
    1. https://docs.sqlalchemy.org/en/latest/core/types.html
    """
    get_column = getattr(member, "get_column", None)
    if get_column is not None:
        # Allow custom members to define the column programatically
        return get_column(model)

    # Copy the metadata as we modify it
    metadata = member.metadata.copy() if member.metadata else {}

    # If a column is specified use that
    if "column" in metadata:
        return metadata["column"]

    metadata.pop("store", None)
    column_name = metadata.pop("name", member.name)
    column_type = metadata.pop("type", None)

    # Extract column kwargs from member metadata. Whatever remains in the
    # metadata afterwards is passed to the column type instead.
    kwargs = {}
    for k in COLUMN_KWARGS:
        if k in metadata:
            kwargs[k] = metadata.pop(k)

    if column_type is None:
        args = atom_member_to_sql_column(model, member, **metadata)
        if args is None:
            return None
        if not isinstance(args, (tuple, list)):
            args = (args,)
    elif isinstance(column_type, (tuple, list)):
        args = column_type
    else:
        args = (column_type,)
    return sa.Column(column_name, *args, **kwargs)
def create_table(model: Type["SQLModel"], metadata: sa.MetaData) -> sa.Table:
    """Create an sqlalchemy table by inspecting the Model and generating
    a column for each member.

    Parameters
    ----------
    model: SQLModel
        The atom model
    metadata: sa.MetaData
        The metadata the new table is registered with

    Returns
    -------
    table: sa.Table
        The table, including any unique constraints, check constraints,
        composite indexes and triggers declared on the model's ``Meta``.

    Raises
    ------
    NotImplementedError
        If the model is abstract.
    TypeError
        If a ``Meta`` option has an unsupported type.

    References
    ----------
    1. https://docs.sqlalchemy.org/en/latest/core/metadata.html
    """
    meta = getattr(model, "Meta", None)

    # Abstract models have no table. Check before generating any columns so
    # nothing (eg fk backrefs) gets registered for them.
    if meta is not None and getattr(meta, "abstract", False):
        raise NotImplementedError(
            f"Tables cannot be created for abstract models: {model}"
        )

    name = model.__model__
    members = model.members()
    args = []

    # Add columns
    for f in model.__fields__:
        column = create_table_column(model, members[f])
        if column is not None:
            args.append(column)

    # Add table metadata
    if meta is not None:
        # Unique constraints
        unique_together = getattr(meta, "unique_together", None)
        if unique_together is not None:
            if not isinstance(unique_together, (tuple, list)):
                raise TypeError("Meta unique_together must be a tuple or list")
            # A single constraint like ("a", "b") may be given without
            # wrapping it in another list. An empty value adds nothing.
            if unique_together and isinstance(unique_together[0], str):
                unique_together = [unique_together]
            for constraint in unique_together:
                if isinstance(constraint, (tuple, list)):
                    constraint = sa.UniqueConstraint(*constraint)
                args.append(constraint)

        # Check constraints
        constraints = getattr(meta, "constraints", None)
        if constraints is not None:
            if not isinstance(constraints, (tuple, list)):
                raise TypeError("Meta constraints must be a tuple or list")
            args.extend(constraints)

        # Composite indexes
        composite_indexes = getattr(meta, "composite_indexes", None)
        if composite_indexes is not None:
            if not isinstance(composite_indexes, (tuple, list)):
                raise TypeError("Meta composite_indexes must be a tuple or list")
            for index in composite_indexes:
                if not isinstance(index, (tuple, list)):
                    raise TypeError("Index must be a tuple or list")
                # NOTE(review): line restored after extraction loss. Items go
                # straight to sa.Index(name, *expressions), so the first item
                # is the index name (None lets the "ix" convention name it).
                args.append(sa.Index(*index))

    # Create table
    table = sa.Table(name, metadata, *args)

    # Hook up any database triggers defined
    triggers = getattr(meta, "triggers", None)
    if triggers is not None:
        if isinstance(triggers, dict):
            triggers = list(triggers.items())
        elif not isinstance(triggers, (tuple, list)):
            raise TypeError("Meta triggers must be a dict, tuple, or list")
        for event, trigger in triggers:
            # Allow triggers to be a lambda that generates one
            if not isinstance(trigger, sa.schema.DDL) and callable(trigger):
                trigger = trigger()
            sa.event.listen(table, event, trigger)

    return table
class SQLModelSerializer(ModelSerializer):
    """Uses sqlalchemy to lookup the model."""

    def flatten_object(self, obj: Model, scope: ScopeType) -> Any:
        """Serialize a model for entering into the database

        Parameters
        ----------
        obj: Model
            The object to flatten
        scope: Dict
            The scope of references available for circular lookups

        Returns
        -------
        result: Object
            The ``_id`` of SQLModels; any other model is flattened by its
            own serializer.
        """
        if isinstance(obj, SQLModel):
            return obj._id
        return type(obj).serializer.flatten_object(obj, scope)

    async def get_object_state(self, obj, state, scope):
        """Load the object state if needed.

        Since the __model__ is not saved to the db tables with SQL, its
        presence in the state means the row was "probably" already loaded
        by a join. Otherwise the row is fetched by ``_id``.
        """
        ModelType = obj.__class__
        if "__model__" in state:
            return state  # Joined already
        # The query type must be given explicitly; ``query`` only accepts
        # "select", "delete", or "update".
        q = ModelType.objects.query("select", _id=state["_id"])
        return await ModelType.objects.fetchone(q)

    def _default_registry(self):
        """Add all sql and json models to the registry"""
        registry = JSONSerializer.instance().registry.copy()
        registry.update({m.__model__: m for m in find_sql_models()})
        return registry
class SQLModelManager(ModelManager):
    """Manages models via aiopg, aiomysql, or similar libraries supporting
    SQLAlchemy tables. It stores a table for each class and when accessed
    on a Model subclass it returns a table proxy binding.
    """

    #: Constraint naming conventions
    conventions = Dict(default=CONSTRAINT_NAMING_CONVENTIONS)

    #: Metadata
    metadata = Instance(sa.MetaData)

    #: Table proxy cache
    proxies = Dict()

    #: Cache results.
    cache = Bool(True)

    def _default_metadata(self) -> sa.MetaData:
        # Bind the metadata to an SQLBinding so DDL is queued on it
        binding = SQLBinding(manager=self)
        return sa.MetaData(binding, naming_convention=self.conventions)

    def create_tables(self) -> DictType[Type["SQLModel"], sa.Table]:
        """Create sqlalchemy tables for all registered SQLModels

        Returns
        -------
        tables: Dict[Type[SQLModel], sa.Table]
            The table of each non-abstract SQLModel.
        """
        tables = {}
        for cls in find_sql_models():
            table = cls.__table__
            if table is None:
                table = self.create_table_and_restore_fn(cls)
            if not table.metadata.bind:
                # NOTE(review): SQLBinding as visible in this module declares
                # no `table` member — confirm this keyword is accepted.
                table.metadata.bind = SQLBinding(manager=self, table=table)
            tables[cls] = table
        return tables

    def create_table_and_restore_fn(self, cls: Type["SQLModel"]) -> sa.Table:
        """Generate the sqlalchemy table and optimized restore function.

        This is done here to make sure that foreign and forwarded members
        are now resolved.
        """
        assert cls.__table__ is None
        table = cls.__table__ = create_table(cls, self.metadata)
        cls.__generated_restorestate__ = generate_sql_restorestate(cls)
        return table

    def __get__(
        self, obj: T, cls: Optional[Type[T]] = None
    ) -> Union["SQLTableProxy[T]", "SQLModelManager"]:
        """Retrieve the table for the requested object or class."""
        cls = cls or obj.__class__
        if not issubclass(cls, Model):
            return self  # Only return the client when used from a Model
        proxy = self.proxies.get(cls)
        if proxy is None:
            table = cls.__table__
            if table is None:
                table = self.create_table_and_restore_fn(cls)
            proxy = self.proxies[cls] = SQLTableProxy(table=table, model=cls)
        return proxy

    def _default_database(self):
        raise EnvironmentError(
            "No database engine has been set. Use "
            "SQLModelManager.instance().database = <db>"
        )
class ConnectionProxy(Atom):
    """A wrapper for a connection to be used with async with syntax that
    does nothing but pass the existing connection through when entered.
    """

    #: The connection returned from ``__aenter__``
    connection = Value()

    async def __aenter__(self):
        return self.connection

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release: the wrapped connection is left open and any
        # exception propagates (None is returned).
        pass
class SQLTableProxy(Atom, Generic[T]):
    """Binds a model class to its sqlalchemy table.

    Provides query execution helpers. Any other attribute is delegated to a
    new SQLQuerySet.
    """

    #: Table this is a proxy to
    table = Instance(sa.Table, optional=False)

    #: Model which owns the table
    model = ForwardSubclass(lambda: SQLModel)

    #: Cache of pk: obj using weakrefs
    cache = Typed(weakref.WeakValueDictionary, ())

    #: Key used to pull the connection out of filter kwargs
    connection_kwarg = Str("connection")

    #: Key used to pass the force restore option
    restore_kwarg = Str("force_restore")

    @property
    def engine(self):
        """Retrieve the database engine.

        This is the aiomysql or aiopg Engine used to get a connection from
        the connection pool. When the manager's database is a dict, the
        entry for the model's ``__database__`` is used.
        """
        db = self.table.bind.manager.database
        if isinstance(db, dict):
            return db[self.model.__database__]
        return db

    def connection(self, connection=None):
        """Create a new connection or return the given connection as an async
        contextual object.

        Parameters
        ----------
        connection: Database connection or None
            The connection to return

        Returns
        -------
        connection: Database connection
            The database connection or one that may be used with async with
        """
        if connection is None:
            return self.engine.acquire()
        return ConnectionProxy(connection=connection)

    def create_table(self):
        """A wrapper for create which catches the create queries then executes
        them.
        """
        table = self.table
        table.create()  # Queued on the SQLBinding
        return table.bind.wait()

    def drop_table(self):
        """A wrapper for drop which catches the drop queries then executes
        them.
        """
        table = self.table
        table.drop()  # Queued on the SQLBinding
        return table.bind.wait()

    async def execute(self, *args, **kwargs):
        """Execute a query, using the connection passed under
        ``connection_kwarg`` or a new one.
        """
        connection = kwargs.pop(self.connection_kwarg, None)
        async with self.connection(connection) as conn:
            return await conn.execute(*args, **kwargs)

    async def fetchall(self, query: QueryType, connection=None):
        """Fetch all results for the query.

        Parameters
        ----------
        query: String or Query
            The query to execute
        connection: Database connection
            The connection to use or a new one will be created

        Returns
        -------
        rows; List
            List of rows returned, NOT objects
        """
        async with self.connection(connection) as conn:
            r = await conn.execute(query)
            return await r.fetchall()

    async def fetchmany(self, query, size=None, connection=None):
        """Fetch size results for the query.

        Parameters
        ----------
        query: String or Query
            The query to execute
        size: Int or None
            The number of results to fetch
        connection: Database connection
            The connection to use or a new one will be created

        Returns
        -------
        rows: List
            List of rows returned, NOT objects
        """
        async with self.connection(connection) as conn:
            r = await conn.execute(query)
            return await r.fetchmany(size)

    async def fetchone(self, query: QueryType, connection=None):
        """Fetch a single result for the query.

        Parameters
        ----------
        query: String or Query
            The query to execute
        connection: Database connection
            The connection to use or a new one will be created

        Returns
        -------
        rows: Object or None
            The row returned or None
        """
        async with self.connection(connection) as conn:
            r = await conn.execute(query)
            return await r.fetchone()

    async def scalar(self, query: QueryType, connection=None):
        """Fetch the scalar result for the query.

        Parameters
        ----------
        query: String or Query
            The query to execute
        connection: Database connection
            The connection to use or a new one will be created

        Returns
        -------
        result: Object or None
            The first column of the first row or None
        """
        async with self.connection(connection) as conn:
            r = await conn.execute(query)
            return await r.scalar()

    async def get_or_create(self, **filters) -> TupleType[T, bool]:
        """Get or create a model matching the given criteria

        Parameters
        ----------
        filters: Dict
            The filters to use to retrieve the object. Lookups containing
            ``__`` and the connection / force restore options are not used
            when creating the object.

        Returns
        -------
        result: Tuple[Model, Bool]
            A tuple of the object and a bool indicating if it was just created
        """
        obj = await self.get(**filters)
        if obj is not None:
            return (obj, False)
        connection_kwarg = self.connection_kwarg
        connection = filters.get(connection_kwarg)
        # Query options are not model fields so they must not reach the model
        options = (connection_kwarg, self.restore_kwarg)
        state = {
            k: v for k, v in filters.items() if "__" not in k and k not in options
        }
        obj = self.model(**state)
        await obj.save(force_insert=True, connection=connection)
        return (obj, True)

    async def create(self, **state) -> T:
        """Create and save a model with the given state.

        The connection parameter is popped from this state.

        Parameters
        ----------
        state: Dict
            The state to use to initialize the object.

        Returns
        -------
        result: Model
            The newly created and saved object.
        """
        connection = state.pop(self.connection_kwarg, None)
        obj = cast(T, self.model(**state))
        await obj.save(force_insert=True, connection=connection)
        return obj

    def __getattr__(self, name: str):
        """All other fields are delegated to the query set"""
        qs: SQLQuerySet[T] = SQLQuerySet(proxy=self)
        return getattr(qs, name)
class SQLQuerySet(Atom, Generic[T]):
    """Builds and executes queries for a SQLTableProxy.

    Methods that refine the query (filter, order_by, distinct, limit, ...)
    return a modified clone instead of changing this instance.
    """

    #: Proxy
    proxy = Instance(SQLTableProxy, optional=False)

    #: Connection to execute with (None acquires one from the engine)
    connection = Value()

    #: sqlalchemy filter clauses, combined with AND
    filter_clauses = List()

    #: Relation paths (eg "a__b") joined into the query
    related_clauses = List()

    #: Relation fields loaded by ``prefetch`` with separate queries
    prefetch_clauses = List()

    #: Use outer joins for the related clauses
    outer_join = Bool()

    #: Ordering terms
    order_clauses = List()

    #: Distinct terms
    distinct_clauses = List()

    #: Limit of the query (0 means no limit)
    limit_count = Int()

    #: Offset of the query (0 means no offset)
    query_offset = Int()

    #: Passed as ``force`` to ``Model.restore``
    force_restore = Bool()

    def clone(self, **kwargs) -> "SQLQuerySet[T]":
        """Return a copy of this queryset with the given members replaced."""
        state = self.__getstate__()
        state.update(kwargs)
        return self.__class__(**state)

    def query(self, query_type: str = "select", *columns, **kwargs):
        """Build the sqlalchemy query for this queryset.

        Parameters
        ----------
        query_type: str
            One of "select", "delete", or "update".
        columns:
            Columns to select instead of every joined table (select only).
        kwargs:
            Django-style filters applied with ``filter`` first.

        Returns
        -------
        query:
            The sqlalchemy query.

        Raises
        ------
        ValueError
            If the query type is not supported.
        """
        if kwargs:
            return self.filter(**kwargs).query(query_type, *columns)
        p = self.proxy
        from_table = p.table
        tables = [from_table]
        model = p.model
        use_labels = bool(self.related_clauses)
        outer_join = self.outer_join
        for clause in self.related_clauses:
            rel_model = model
            # Walk the fk relations
            for part in clause.split("__"):
                m = rel_model.members().get(part)
                assert m is not None, f"{rel_model} has no field {part}"
                rel_model_types = resolve_member_types(m)
                assert rel_model_types is not None
                rel_model = rel_model_types[0]
                assert issubclass(rel_model, Model)
                table = rel_model.objects.table
                # Keep extending the same join so every related table is in
                # the FROM clause. Clauses may share a prefix (eg "a" and
                # "a__b"), so each table is joined only once.
                if any(t is table for t in tables):
                    continue
                from_table = sa.join(from_table, table, isouter=outer_join)
                tables.append(table)

        if query_type == "select":
            q = sa.select(columns or tables, use_labels=use_labels)
            q = q.select_from(from_table)
        elif query_type == "delete":
            q = sa.delete(from_table)
        elif query_type == "update":
            q = sa.update(from_table)
        else:
            raise ValueError("Unsupported query type")

        if self.distinct_clauses:
            q = q.distinct(*self.distinct_clauses)

        if self.filter_clauses:
            if len(self.filter_clauses) == 1:
                q = q.where(self.filter_clauses[0])
            else:
                q = q.where(sa.and_(*self.filter_clauses))

        if self.order_clauses:
            q = q.order_by(*self.order_clauses)

        if self.limit_count:
            q = q.limit(self.limit_count)

        if self.query_offset:
            q = q.offset(self.query_offset)

        return q

    def order_by(self, *args):
        """Order the query by the given fields.

        Parameters
        ----------
        args: List[str or column]
            Fields to order by. A "-" prefix denotes descending.

        Returns
        -------
        query: SQLQuerySet
            A clone of this queryset with the ordering terms added.
        """
        order_clauses = self.order_clauses[:]
        related_clauses = self.related_clauses[:]
        model = self.proxy.model
        for arg in args:
            if isinstance(arg, str):
                # Convert django-style to sqlalchemy ordering column
                descending = arg.startswith("-")
                field = arg[1:] if descending else arg
                col = resolve_member_column(model, field, related_clauses)
                clause = col.desc() if descending else col.asc()
            else:
                clause = arg
            if clause not in order_clauses:
                order_clauses.append(clause)
        return self.clone(order_clauses=order_clauses, related_clauses=related_clauses)

    def distinct(self, *args):
        """Apply distinct on the given column.

        Parameters
        ----------
        args: List[str or column]
            Fields that must be distinct.

        Returns
        -------
        query: SQLQuerySet
            A clone of this queryset with the distinct terms added.
        """
        distinct_clauses = self.distinct_clauses[:]
        related_clauses = self.related_clauses[:]
        model = self.proxy.model
        for arg in args:
            if isinstance(arg, str):
                # Convert name to sqlalchemy column
                clause = resolve_member_column(model, arg, related_clauses)
            else:
                clause = arg
            if clause not in distinct_clauses:
                distinct_clauses.append(clause)
        return self.clone(
            distinct_clauses=distinct_clauses, related_clauses=related_clauses
        )

    def filter(self, *args, **kwargs: Any):
        """Filter the query by the given parameters. This accepts sqlalchemy
        filters by arguments and django-style parameters as kwargs.

        Parameters
        ----------
        args: List
            List of sqlalchemy filters
        kwargs: Dict[str, object]
            Django style filters to use. The proxy's connection and force
            restore keys set those options instead of filtering.

        Returns
        -------
        query: SQLQuerySet
            A clone of this queryset with the filter terms added.
        """
        p = self.proxy
        model = p.model
        filter_clauses = self.filter_clauses + list(args)
        related_clauses = self.related_clauses[:]

        connection_kwarg, restore_kwarg = p.connection_kwarg, p.restore_kwarg

        # Build the filter operations
        for k, v in kwargs.items():
            if k == connection_kwarg or k == restore_kwarg:
                continue  # Options, not filters
            field, op = k, "eq"
            if "__" in k:
                parts = k.split("__")
                if parts[-1] in QUERY_OPS:
                    op = parts[-1]
                    field = "__".join(parts[:-1])
            col = resolve_member_column(model, field, related_clauses)

            # Support lookups by model
            if isinstance(v, Model):
                v = v.serializer.flatten_object(v, scope={})
            elif op in ("in", "notin"):
                # Flatten lists when using in or notin ops
                v = model.serializer.flatten(v, scope={})

            clause = getattr(col, QUERY_OPS[op])(v)
            filter_clauses.append(clause)

        return self.clone(
            connection=kwargs.get(connection_kwarg, self.connection),
            force_restore=kwargs.get(restore_kwarg, self.force_restore),
            filter_clauses=filter_clauses,
            related_clauses=related_clauses,
        )

    def __getitem__(self, key):
        """Apply a limit and offset using an int index or a slice."""
        if isinstance(key, slice):
            offset = key.start or 0
            # key.start may be None (eg qs[:10])
            limit = key.stop - offset if key.stop else 0
        elif isinstance(key, int):
            limit = 1
            offset = key
        else:
            raise TypeError("Invalid key")
        if offset < 0:
            raise ValueError("Cannot use a negative offset")
        if limit < 0:
            raise ValueError("Cannot use a negative limit")
        return self.clone(limit_count=limit, query_offset=offset)

    def limit(self, limit: int):
        return self.clone(limit_count=limit)

    def offset(self, offset: int):
        return self.clone(query_offset=offset)

    # -------------------------------------------------------------------------
    # Query execution API
    # -------------------------------------------------------------------------
    async def values(
        self,
        *args,
        distinct: bool = False,
        flat: bool = False,
        group_by: Optional[Sequence[Union[str, sa.Column]]] = None,
    ) -> Sequence[Any]:
        """Returns the results as a list of rows instead of models.

        Parameters
        ----------
        args: List[str or column]
            List of columns to select
        distinct: Bool
            Return only distinct rows
        flat: Bool
            Requires exactly one arg and will flatten the result into a single
            list of values.
        group_by: List[str or column]
            Optional Columns to group by

        Returns
        -------
        results: List
            List of results depending on the parameters described above
        """
        if flat and len(args) != 1:
            raise ValueError("Values with flat=True can only have one param")
        model = self.proxy.model

        def to_column(col):
            # Resolve field names; pass sqlalchemy objects through
            if isinstance(col, str):
                return resolve_member_column(model, col)
            return col

        if args:
            q = self.query("select", *[to_column(col) for col in args])
        else:
            q = self.query("select")
        if group_by is not None:
            # group_by takes the clauses as positional arguments
            q = q.group_by(*[to_column(col) for col in group_by])
        if distinct:
            q = q.distinct()
        cursor = await self.proxy.fetchall(q, connection=self.connection)
        if flat:
            return [row[0] for row in cursor]
        return cursor

    async def count(self, *args, **kwargs) -> int:
        """Count the rows matching the query (and optional extra filters)."""
        if args or kwargs:
            return await self.filter(*args, **kwargs).count()
        subq = self.query("select").alias("subquery")
        q = sa.func.count().select().select_from(subq)
        return await self.proxy.scalar(q, connection=self.connection)

    def max(self, *columns):
        return self.aggregate(*columns, func=sa.func.max)

    def min(self, *columns):
        return self.aggregate(*columns, func=sa.func.min)

    def mode(self, *columns):
        return self.aggregate(*columns, func=sa.func.mode)

    def sum(self, *columns):
        return self.aggregate(*columns, func=sa.func.sum)

    def aggregate(self, *args, func=None):
        """Apply ``func`` to each column over the rows matching this query.

        Returns a coroutine resolving to a single row with one value per
        column.
        """
        model = self.proxy.model
        subq = self.query("select").alias("subquery")
        columns = []
        for col in args:
            if isinstance(col, str):
                col = resolve_member_column(model, col)
            # Reference the column through the filtered subquery. Otherwise
            # the base table is added to the FROM clause a second time (a
            # cartesian product with the subquery) and the filters are lost.
            if isinstance(col, sa.sql.expression.ColumnElement):
                sub_col = subq.corresponding_column(col)
                if sub_col is not None:
                    col = sub_col
            columns.append(func(col) if func is not None else col)
        q = sa.select(columns).select_from(subq)
        return self.proxy.fetchone(q, connection=self.connection)

    async def exists(self, *args, **kwargs) -> bool:
        """Return whether any row matches the query."""
        if args or kwargs:
            return await self.filter(*args, **kwargs).exists()
        q = sa.exists(self.query("select")).select()
        return await self.proxy.scalar(q, connection=self.connection)

    async def delete(self, *args, **kwargs):
        """Delete the rows matching the query."""
        if args or kwargs:
            return await self.filter(*args, **kwargs).delete()
        q = self.query("delete")
        return await self.proxy.execute(q, connection=self.connection)

    async def update(self, **values):
        """Perform an update of the given values."""
        # Translate any renamed fields back to the database value
        for py_name, db_name in self.proxy.model.__renamed_fields__.items():
            if py_name in values:
                values[db_name] = values.pop(py_name)
        q = self.query("update").values(**values)
        return await self.proxy.execute(q, connection=self.connection)

    def __await__(self):
        # So await Model.objects.filter() works
        f = asyncio.ensure_future(self.all())
        yield from f
        return f.result()

    async def all(self, *args, **kwargs) -> Sequence[T]:
        """Get all results matching the query.

        Rows are restored with ``force=self.force_restore``.

        Returns
        -------
        results: list[Model]
            The models matching the query
        """
        if args or kwargs:
            return await self.filter(*args, **kwargs).all()
        cache = await self.prefetch()
        q = self.query("select")
        restore = self.proxy.model.restore
        cursor = await self.proxy.fetchall(q, connection=self.connection)
        force = self.force_restore
        return [
            cast(T, await restore(row, force=force, prefetched=cache)) for row in cursor
        ]

    async def get(self, *args, **kwargs) -> Optional[T]:
        """Get the first result matching the query.

        Unlike django this will NOT raise an error if multiple objects would
        be returned or an entry does not exist.

        Returns
        -------
        model: Optional[Model]
            The first entry matching the query, or None
        """
        if args or kwargs:
            return await self.filter(*args, **kwargs).get()
        q = self.query("select")
        row = await self.proxy.fetchone(q, connection=self.connection)
        if row is None:
            return None
        cache = await self.prefetch()
        model = self.proxy.model
        force = self.force_restore
        return cast(T, await model.restore(row, force=force, prefetched=cache))

    async def prefetch(self) -> Optional[DictType[Any, StateType]]:
        """Perform a prefetch lookup and populate the cache.

        Returns
        -------
        cache: Dict or None
            None when there is nothing to prefetch. Otherwise a mapping of
            this model's pk to ``{field: value}``, where the value is a list
            of related objects for Relation fields or a single object for
            one to one fields.
        """
        if not self.prefetch_clauses:
            return None

        cache: DictType[Any, StateType] = {}
        model = self.proxy.model
        # Primary keys of every row matched by this query
        sub_query = self.query("select", model.objects.table.c[model.__pk__])

        # Perform a query for each related field
        for field in self.prefetch_clauses:
            # TODO: This only works with a single relation
            m, RelModel, ref_member, rel_col = resolve_relation(model, field)
            results = await RelModel.objects.filter(
                rel_col.in_(sub_query), connection=self.connection
            )

            # Group the results by the this models pk
            # Eg if Email.attachments is a relation to Attachments
            # This will group by the Email value
            if isinstance(m, Relation):
                # Many related items per pk
                for r in results:
                    pk = ref_member.get_slot(r)._id
                    prefetched_state = cache.setdefault(pk, {})
                    prefetched_state.setdefault(field, []).append(r)
            else:
                # One related item per pk
                for r in results:
                    pk = ref_member.get_slot(r)._id
                    cache.setdefault(pk, {})[field] = r
        return cache
class SQLBinding(Atom):
    """Stand-in engine/connection used as the sqlalchemy metadata bind.

    DDL generated through it is queued by ``execute`` and only run against
    the real database when ``wait`` is awaited.
    """

    #: Model Manager
    manager = Instance(SQLModelManager)

    #: The queue of (statement, multiparams, params) waiting to be executed
    queue = ContainerList()

    engine = property(lambda s: s)

    @property
    def name(self):
        return self.dialect.name

    @property
    def dialect(self):
        """Get the dialect of the database."""
        db = self.manager.database
        if isinstance(db, dict):
            db = db["default"]
        return db.dialect

    def schema_for_object(self, obj):
        return obj.schema

    def contextual_connect(self, **kwargs):
        return self

    def connect(self, **kwargs):
        return self

    def execution_options(self, **kw):
        return self

    def compiler(self, statement, parameters, **kwargs):
        return self.dialect.compiler(statement, parameters, engine=self, **kwargs)

    def create(self, entity, **kwargs):
        kwargs["checkfirst"] = False
        node = ddl.SchemaGenerator(self.dialect, self, **kwargs)
        node.traverse_single(entity)

    def drop(self, entity, **kwargs):
        kwargs["checkfirst"] = False
        node = ddl.SchemaDropper(self.dialect, self, **kwargs)
        node.traverse_single(entity)

    def _run_ddl_visitor(self, visitorcallable, element, connection=None, **kwargs):
        kwargs["checkfirst"] = False
        visitorcallable(self.dialect, self, **kwargs).traverse_single(element)

    def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
        kwargs["checkfirst"] = False
        node = visitorcallable(self.dialect, self, **kwargs)
        node.traverse_single(element)

    def execute(self, object_, *multiparams, **params):
        # Defer execution until wait() is awaited
        self.queue.append((object_, multiparams, params))

    async def wait(self):
        """Execute the queued statements in order on one connection.

        Returns
        -------
        result:
            The result of the last executed statement, or None if the queue
            was empty.
        """
        db = self.manager.database
        engine = db["default"] if isinstance(db, dict) else db
        result = None
        async with engine.acquire() as conn:
            try:
                while self.queue:
                    op, args, kwargs = self.queue.pop(0)
                    result = await conn.execute(op, args)
            finally:
                self.queue = []  # Wipe queue on error
        return result
async def get_cached_model(cls: Type[T], pk: Any, state: StateType) -> Optional[T]:
    """Retrieve a model from the cache using the given pk. If the cached
    object does not exist attempt to restore it from the state otherwise create
    a model that has not been loaded and only contains the id.

    Parameters
    ----------
    cls: Type[SQLModel]
        The class to lookup.
    pk: Any
        The primary key to look for.
    state: StateType
        The state from a join query.

    Returns
    -------
    obj: Optional[SQLModel]
        The model restored from the joined columns in ``state`` when they
        contain a truthy joined pk for ``cls``. Otherwise ``None`` if ``pk``
        is falsy (eg ``None`` or ``0``), else the cached instance or a new
        unloaded instance holding only ``_id``.

    """
    joined_pk = cls.__joined_pk__
    if joined_pk in state and state[joined_pk]:
        return await cls.restore(state)  # Restore from joined row result
    if not pk:
        return None
    cache = cls.objects.cache
    obj = cache.get(pk)
    if obj is not None:
        return obj  # item is already in the cache

    # Create an unloaded model and cache it so later lookups reuse it
    obj = cls.__new__(cls)
    cache[pk] = obj
    obj._id = pk
    return obj
def generate_sql_restorestate(cls: Type["SQLModel"]) -> RestoreStateFn:
    """Generate an optimized restore function for the SQL model. The generated
    function creates "inline" dict key lookups for the table columns that
    may have been joined or renamed. This avoids having to do this at runtime.

    Parameters
    ----------
    cls: Type[SQLModel]
        The model class to generate ``__restorestate__`` for.

    Returns
    -------
    restorestate: RestoreStateFn
        The ``async def __restorestate__(self, state, scope=None)`` function
        built by ``generate_function`` from the generated source.

    Raises
    ------
    ValueError
        If a field name, renamed column, or joined key is not a valid python
        identifier and so cannot safely be embedded in generated code.

    """
    # Every entry after the first is a line of the function body, indented
    # relative to the body; they are joined with "\n    " at the end.
    template = [
        "async def __restorestate__(self, state, scope=None):",
        "if '__model__' in state and state['__model__'] != self.__model__:",
        "    name = state['__model__']",
        "    raise ValueError(",
        "        f'Trying to use {name} state for {self.__model__} object'",
        "    )",
        "scope = scope or {}",
        "if '__ref__' in state and state['__ref__'] is not None:",
        "    scope[state['__ref__']] = self",
    ]

    on_error = cls.__on_error__
    default_unflatten = cls.serializer.unflatten
    setters = []
    excluded = {"__model__", "__ref__", "__restored__"}
    for f, m in cls.members().items():
        if f in excluded:
            continue
        meta = m.metadata or {}
        order = meta.get("setstate_order", 1000)

        # Allow tagging a custom unflatten fn
        unflatten = meta.get("unflatten", default_unflatten)

        setters.append((order, f, m, unflatten))
    # Stable sort: members with equal order keep their declaration order
    setters.sort(key=lambda it: it[0])

    namespace: DictType[str, Any] = {
        "default_unflatten": default_unflatten,
        "get_cached_model": get_cached_model,
    }

    # The state dict may have data from multiple tables that have been joined
    # together. This handles that case.
    table_name = cls.__model__
    for _order, f, m, unflatten in setters:
        col = m.metadata.get("name", f) if m.metadata is not None else f
        k = f"{table_name}_{col}"

        # Since f, col, and k are potentially an untrusted input, make sure they are
        # valid python identifiers to prevent unintended code being generated.
        if not f.isidentifier():
            raise ValueError(f"Field '{f}' cannot be used for code generation")
        # TODO: Do proper column name validation
        if not k.replace(".", "_").isidentifier():
            raise ValueError(f"Key '{k}' cannot be used for code generation")
        if not col.isidentifier():
            raise ValueError(f"Renamed '{col}' cannot be used for code generation")

        # TODO: Is there a better way to check for multiple keys?
        if f in cls.__renamed_fields__:
            # Make sure renamed fields are checked for
            template.append(f"if '{col}' in state or '{k}' in state or '{f}' in state:")
            template.append(
                f"    v = state['{col}' if '{col}' in state else ("
                f"'{k}' if '{k}' in state else '{f}')]"
            )
        elif col in cls.__fields__ and not isinstance(m, Relation):
            # Expression to retrieve the value
            # Always check the joined type first
            template.append(f"if '{f}' in state or '{k}' in state:")
            template.append(f"    v = state['{k}' if '{k}' in state else '{f}']")
        else:
            template.append(f"if '{f}' in state:")
            template.append(f"    v = state['{f}']")

        # Statements that assign ``v`` to the member. Each entry is one line
        # carrying only its indentation relative to the first line, so it can
        # be placed correctly under either wrapper below.
        if unflatten is default_unflatten:
            # If a custom unflatten is not provided use the member type
            # information to pick the most efficient way to restore the value
            RelModel = None
            if isinstance(m, FK_TYPES):
                types = resolve_member_types(m)
                if types and len(types) == 1 and issubclass(types[0], Model):
                    RelModel = types[0]

            if RelModel is not None:
                # TODO: This is fine for Typed members but not Instance..
                # as it may need to be a subclass
                namespace[f"rel_model_{f}"] = RelModel
                if issubclass(RelModel, JSONModel):
                    obj = f"await rel_model_{f}.restore(v)"
                else:
                    obj = f"await get_cached_model(rel_model_{f}, v, state)"

                # Only convert if the object has not already been restored
                expr_lines = [
                    f"if isinstance(v, rel_model_{f}):",
                    f"    self.{f} = v",
                    "else:",
                    f"    self.{f} = {obj}",
                ]
            elif is_primitive_member(m):
                expr_lines = [f"self.{f} = v"]
            else:
                expr_lines = [f"self.{f} = await default_unflatten(v, scope)"]
        else:
            # Use provided unflatten function
            namespace[f"unflatten_{f}"] = unflatten
            if asyncio.iscoroutinefunction(unflatten):
                expr_lines = [f"self.{f} = await unflatten_{f}(v, scope)"]
            else:
                expr_lines = [f"self.{f} = unflatten_{f}(v, scope)"]

        if on_error == "raise":
            template.extend(f"    {line}" for line in expr_lines)
        else:
            if on_error == "log":
                handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
            else:
                handler = "pass"
            template.append("    try:")
            template.extend(f"        {line}" for line in expr_lines)
            template.append("    except Exception as e:")
            template.append(f"        {handler}")

    # Update restored state
    template.append("self.__restored__ = True")
    source = "\n    ".join(template)
    return generate_function(source, namespace, "__restorestate__")
class SQLModel(Model, metaclass=SQLMeta):
    """A model that can be saved and restored to and from a database supported
    by sqlalchemy.

    """

    #: Primary key field name
    __pk__: ClassVar[str]

    #: Table name and primary key
    __joined_pk__: ClassVar[str]

    #: Models which link back to this
    __backrefs__: ClassVar[SetType[TupleType[Type[Model], Member]]]

    #: List of fields which have been tagged with a different column name
    #: Mapping is class attr -> database column name.
    __renamed_fields__: ClassVar[DictType[str, str]]

    #: Set of fields to exclude from the database
    __excluded_fields__: ClassVar[SetType[str]]

    #: Reference to the sqlalchemy table backing this model
    __table__: ClassVar[Optional[sa.Table]]

    #: Database name. If the `database` field of the manager is a dict
    #: This field will be used to determine which engine to use.
    __database__: ClassVar[str] = "default"

    #: Use SQL serializer
    serializer = SQLModelSerializer.instance()

    #: Use SQL object manager
    objects = SQLModelManager.instance()

    #: ID of this object in the database. Subclasses can redefine this as needed
    _id = Typed(int).tag(primary_key=True)

    @classmethod
    async def restore(
        cls: Type[T],
        state: StateType,
        force: Optional[bool] = None,
        prefetched: Optional[DictType[Any, StateType]] = None,
        **kwargs: Any,
    ) -> T:
        """Restore an object from the database using the primary key. Save
        a ref in the table's object cache. If force is True, update
        the cache if it exists.

        Parameters
        ----------
        state: Mapping[str, Any]
            A mapping of field name to value. May contain result of a join (eg
            state of multiple models prefixed with the table name).
        force: Optional[bool]
            Whether to force calling restorestate on an already cached object.
            When None, it defaults to ``not manager.cache``.
        prefetched: Optional[dict]
            A mapping of prefetched related values. If present the object's
            primary key is looked up and the values are added to the state.

        Returns
        -------
        model: SQLModel
            The restored or cached model.

        """
        if cls.__joined_pk__ in state:
            # When sqlalchemy does a join the key will have a prefix
            # of the database name
            pk = state[cls.__joined_pk__]
        else:
            pk = state[cls.__pk__]

        # Note make sure this always occurs to force table creation
        cache = cls.objects.cache

        # Check if this is in the cache (an empty pk is never cached)
        obj = cache.get(pk) if pk is not None else None

        if obj is None:
            # Create and cache it
            obj = cls.__new__(cls)

            # Do not place empty pk in cache
            if pk is not None:
                cache[pk] = obj
            restore = True
        else:
            # Check the default for force reloading
            if force is None:
                force = not cls.objects.table.bind.manager.cache

            # Note that if force is false and the object was restored
            # (ie from another query) the object in the cache is reused
            # and any (potentially new) data in the state is discarded.
            restore = force or not obj.__restored__

        if restore:
            # Merge any prefetched relation members into the restore state
            # so __restorestate__ can find them.
            if prefetched is not None:
                prefetched_state = prefetched.get(pk)
                if prefetched_state is not None:
                    state = dict(state)  # Convert row proxy
                    state.update(prefetched_state)

            # This ideally should only be done if created
            await obj.__restorestate__(state)

        return obj

    async def load(
        self: T,
        connection=None,
        reload: bool = False,
        fields: Optional[Sequence[str]] = None,
    ):
        """Alias to load this object from the database

        Parameters
        ----------
        connection: Connection
            The connection instance to use in a transaction
        reload: Bool
            If True force reloading the state even if the state has
            already been loaded.
        fields: Sequence[str]
            Optional list of field names to load. Use this to refresh
            specific fields from the database.

        """
        skip = self.__restored__ and not reload and not fields
        if skip or not self._id:
            return  # Already loaded or won't do anything
        db = self.objects
        t = db.table
        if fields is not None:
            renamed = self.__renamed_fields__
            # Materialize as a list; sa.select expects a list of columns,
            # not a one-shot generator.
            columns = [t.c[renamed.get(f, f)] for f in fields]
            q = sa.select(columns).select_from(t)
        else:
            q = t.select()
        q = q.where(t.c[self.__pk__] == self._id)
        state = await db.fetchone(q, connection=connection)
        await self.__restorestate__(state)

    async def save(
        self: T,
        force_insert: bool = False,
        force_update: bool = False,
        update_fields: Optional[Sequence[str]] = None,
        connection=None,
    ):
        """Alias to save this object to the database

        Parameters
        ----------
        force_insert: Bool
            Ensure that save performs an insert
        force_update: Bool
            Ensure that save performs an update
        update_fields: Iterable[str]
            If given, only update the given fields
        connection: Connection
            The connection instance to use in a transaction

        Returns
        -------
        result: Value
            Update or save result

        Raises
        ------
        ValueError
            If both force_insert and force_update are set.

        """
        if force_insert and force_update:
            raise ValueError("Cannot use force_insert and force_update together")

        db = self.objects
        state = self.__getstate__()

        # Remove any fields are in the state but should not go into the db
        for f in self.__excluded_fields__:
            state.pop(f, None)

        # Replace any renamed fields
        for py_name, db_name in self.__renamed_fields__.items():
            state[db_name] = state.pop(py_name)

        table = db.table
        async with db.connection(connection) as conn:
            if force_update or (self._id and not force_insert):
                # If update fields was given, only pass those
                if update_fields is not None:
                    # Replace any update fields with the appropriate name
                    renamed = self.__renamed_fields__
                    update_fields = [renamed.get(f, f) for f in update_fields]

                    # Replace update fields with only those given
                    state = {f: state[f] for f in update_fields}

                q = (
                    table.update()
                    .where(table.c[self.__pk__] == self._id)
                    .values(**state)
                )
                r = await conn.execute(q)
                if not r.rowcount:
                    log.warning(
                        f'Did not update "{self}", either no rows with '
                        f"pk={self._id} exist or it has not changed."
                    )
            else:
                if not self._id:
                    # Postgres errors if using None for the pk
                    state.pop(self.__pk__, None)
                q = table.insert().values(**state)
                r = await conn.execute(q)

                # Don't overwrite if force inserting
                if not self._id:
                    if hasattr(r, "lastrowid"):
                        self._id = r.lastrowid  # MySQL
                    else:
                        self._id = await r.scalar()  # Postgres

                # Save a ref to the object in the model cache
                db.cache[self._id] = self
            self.__restored__ = True
            return r

    async def delete(self: T, connection=None):
        """Alias to delete this object in the database

        Does nothing (returns None) if the object has no pk.

        Parameters
        ----------
        connection: Connection
            The connection instance to use in a transaction

        Returns
        -------
        result: Value
            The delete result

        """
        pk = self._id
        if not pk:
            return None
        db = self.objects
        table = db.table  # type: sa.Table
        q = table.delete().where(table.c[self.__pk__] == pk)
        async with db.connection(connection) as conn:
            r = await conn.execute(q)
            if not r.rowcount:
                log.warning(
                    f'Did not delete "{self}", no rows with ' f"pk={self._id} exist."
                )
            # The object may never have been placed in the cache (eg created
            # with an explicit _id), so don't raise after the row is deleted.
            db.cache.pop(pk, None)
            del self._id
            return r