"""
Copyright (c) 2018-2022, CodeLV.
Distributed under the terms of the MIT License.
The full license is in the file LICENSE.text, distributed with this software.
Created on Jun 12, 2018
"""
import asyncio
import logging
from base64 import b64decode, b64encode
from collections.abc import MutableMapping
from datetime import date, datetime, time
from decimal import Decimal
from pprint import pformat
from random import getrandbits
from typing import Any, Callable, ClassVar
from typing import Dict as DictType
from typing import List as ListType
from typing import Optional
from typing import Tuple as TupleType
from typing import Type, TypeVar
from uuid import UUID
from atom.api import (
    Atom,
    AtomMeta,
    Bool,
    Coerced,
    Dict,
    Float,
    Instance,
    Int,
    List,
    Member,
    Property,
    Set,
    Str,
    Tuple,
    Typed,
    Value,
    set_default,
)
from bytecode import Bytecode, Instr, Label
#: Module logger used to report restore errors
log = logging.getLogger("atomdb")

#: Generic type variable
T = TypeVar("T")
#: Type variable bound to Model subclasses
M = TypeVar("M", bound="Model")

#: Objects seen during (de)serialization, keyed by their ``__ref__`` value
ScopeType = DictType[int, Any]
#: Serialized model state, keyed by field name
StateType = DictType[str, Any]

#: Signature of a generated ``__getstate__`` function
GetStateFn = Callable[[M, Optional[ScopeType]], StateType]
#: Signature of a generated ``__restorestate__`` function (note: the generated
#: function is actually a coroutine function)
RestoreStateFn = Callable[[M, StateType, Optional[ScopeType]], None]
def find_subclasses(cls: Type[T]) -> ListType[Type[T]]:
    """Find all direct and indirect subclasses of the given class.

    Classes are listed depth-first in the order reported by
    ``__subclasses__()``. A class reachable through more than one base
    (diamond inheritance) is listed only once, at its first occurrence.

    Parameters
    ----------
    cls: Type
        The class whose subclasses should be collected.

    Returns
    -------
    classes: List[Type]
        Every subclass of ``cls`` (``cls`` itself is not included).
    """
    classes: ListType[Type[T]] = []
    seen = set()

    def visit(parent):
        for subclass in parent.__subclasses__():
            if subclass in seen:
                # Already collected together with its whole subtree
                continue
            seen.add(subclass)
            classes.append(subclass)
            visit(subclass)

    visit(cls)
    return classes
def is_db_field(m: Member) -> bool:
    """Check if the member should be saved into the database.

    An explicit ``store`` tag always decides. Otherwise a member is saved
    when it is not a Property and its name does not start with an
    underscore.

    Parameters
    ----------
    m: Member
        The atom member to check.

    Returns
    -------
    result: bool
        Whether the member should be saved into the database.
    """
    metadata = m.metadata or {}
    # An explicit store tag wins, even for private members and properties
    if "store" in metadata:
        return metadata["store"]
    # Properties are only stored when explicitly tagged with store=True.
    # Checking this independently of other metadata means a Property tagged
    # with something else (eg flatten=...) is still excluded.
    if isinstance(m, Property):
        return False
    return not m.name.startswith("_")
def is_primitive_member(m: Member) -> Optional[bool]:
    """Determine whether a member's value can be stored as-is, without
    passing it through ``flatten``.

    Parameters
    ----------
    m: Member
        The atom member to inspect.

    Returns
    -------
    result: Optional[bool]
        True when the member only holds int/float/bool/str values, False when
        it may hold anything else, and None when a forwarded member it depends
        on cannot be resolved yet.
    """
    # Scalar members always hold builtin values
    if isinstance(m, (Bool, Str, Int, Float)):
        return True

    # Forwarded members can only be inspected once their target is available
    if hasattr(m, "resolve"):
        return None

    if not isinstance(m, (Tuple, Set, List, Typed, Instance, Dict, Coerced)):
        return False

    try:
        member_types = resolve_member_types(m, resolve=False)
    except UnresolvableError:
        return None

    # None means the member performs no type validation at all
    if member_types is None:
        return False

    primitives = (int, float, bool, str)
    return bool(member_types) and all(t in primitives for t in member_types)
def resolve_member_types(
    member: Member, resolve: bool = True
) -> Optional[TupleType[type, ...]]:
    """Determine the validation types specified on a member.

    Parameters
    ----------
    member: Member
        The member to retrieve the type from
    resolve: bool
        Whether to resolve "Forward" members (members with a ``resolve``
        method).

    Returns
    -------
    types: Optional[Tuple[Model|Member|type, ..]]
        The member types. If types is `None` then the member (or a member
        nested within it) does not do any type validation.

    Raises
    ------
    UnresolvableError
        If `resolve=False` and the member has a nested forwarded member this
        will raise an UnresolvableError with the unresolved member.
    """
    # TODO: This should really use the validate mode...
    if hasattr(member, "resolve"):
        if not resolve:
            raise UnresolvableError(member)  # Do not resolve now
        types = member.resolve()  # type: ignore
    elif isinstance(member, Coerced):
        # Coerced keeps its type(s) in the first item of the validate argument
        types = member.validate_mode[-1][0]
    else:
        types = member.validate_mode[-1]
    if types is None:
        return None
    if isinstance(types, tuple):
        # Dict may have a member in the types list, so walk the types
        # and resolve all of those.
        resolved: ListType[type] = []
        for t in types:
            if t is None:
                # An unspecified slot (presumably eg a Dict without a key or
                # value type) accepts anything, so the member as a whole is
                # unvalidated. Never return None inside the tuple.
                return None
            if isinstance(t, Member):
                r = resolve_member_types(t, resolve)
                if r is None:
                    # TODO: Think about whether this is correct to bail out here
                    return None
                resolved.extend(r)
            else:
                resolved.append(t)
        return tuple(resolved)
    if isinstance(types, Member):
        # Follow the chain. For example if the member is defined
        # as `List(Tuple(float)))` lookup the types of the nested Tuple().
        return resolve_member_types(types, resolve)
    if isinstance(types, str):
        return None  # Custom validation method
    return (types,)
class UnresolvableError(Exception):
    """Raised by `resolve_member_types` when it meets a forwarded member but
    was asked not to resolve it.

    Attributes
    ----------
    member: Member
        The forwarded member that could not be resolved.
    """

    def __init__(self, member):
        message = f"Cannot resolve {member}"
        super().__init__(message)
        self.member = member
class ModelSerializer(Atom):
    """Handles serializing and deserializing of Model subclasses. It
    will automatically save and restore references where present.
    """

    #: Hold one instance per subclass for easy reuse
    _instances: ClassVar[DictType[Type["ModelSerializer"], "ModelSerializer"]] = {}

    #: Store all registered models
    registry = Dict()

    #: Mapping of type name to coercer function
    coercers = Dict(
        default={
            "datetime.date": lambda s: date(**s),
            "datetime.datetime": lambda s: datetime(**s),
            "datetime.time": lambda s: time(**s),
            "bytes": lambda s: b64decode(s["bytes"]),
            "decimal": lambda s: Decimal(s["value"]),
            "uuid": lambda s: UUID(s["id"]),
        }
    )

    @classmethod
    def instance(cls: Type["ModelSerializer"]) -> "ModelSerializer":
        """Return the shared instance of this serializer class, creating it
        on first use.
        """
        if cls not in ModelSerializer._instances:
            ModelSerializer._instances[cls] = cls()
        return ModelSerializer._instances[cls]

    def flatten(self, v: Any, scope: Optional[ScopeType] = None) -> Any:
        """Convert Model objects to a dict

        Parameters
        ----------
        v: Object
            The object to flatten
        scope: Dict
            The scope of references available for circular lookups. It is
            created when None and shared with every nested call.

        Returns
        -------
        result: Object
            The flattened object
        """
        # Only create a scope when none was given. `scope or {}` would replace
        # a caller-supplied empty dict, so sibling items (eg in a top-level
        # list) would each get a private scope and never share references.
        if scope is None:
            scope = {}
        flatten = self.flatten
        if isinstance(v, Model):
            # Models handle circular references through the scope
            return v.serializer.flatten_object(v, scope)
        elif isinstance(v, (list, tuple, set)):
            return [flatten(item, scope) for item in v]
        elif isinstance(v, (dict, MutableMapping)):
            return {k: flatten(item, scope) for k, item in v.items()}
        # TODO: Handle other object types
        return v

    def flatten_object(self, obj: "Model", scope: ScopeType) -> Any:
        """Serialize a model for entering into the database

        Parameters
        ----------
        obj: Model
            The object to flatten
        scope: Dict
            The scope of references available for circular lookups

        Returns
        -------
        result: Object
            The flattened object
        """
        raise NotImplementedError

    async def unflatten(self, v: Any, scope: Optional[ScopeType] = None) -> Any:
        """Convert dict or list to Models

        The given state is not modified.

        Parameters
        ----------
        v: Dict or List
            The object(s) to unflatten
        scope: Dict
            The scope of references available for circular lookups. It is
            created when None and shared with every nested call.

        Returns
        -------
        result: Object
            The unflattened object
        """
        # Share one scope across the whole structure so references between
        # sibling items (eg in a top-level list) can be resolved.
        if scope is None:
            scope = {}
        if isinstance(v, dict):
            # Circular reference
            if scope and "__ref__" in v:
                ref = v["__ref__"]
                if ref in scope:
                    return scope[ref]
            # Create the object
            if "__model__" in v:
                cls = self.registry[v["__model__"]]
                return await cls.serializer.unflatten_object(cls, v, scope)
            # Convert py types
            if "__py__" in v:
                py_type = v["__py__"]
                # Copy without the marker instead of popping it so the
                # caller's state is left untouched (and can be reused).
                v = {k: i for k, i in v.items() if k != "__py__"}
                coercer = self.coercers.get(py_type)
                if coercer:
                    return coercer(v)
            unflatten = self.unflatten
            return {k: await unflatten(i, scope) for k, i in v.items()}
        elif isinstance(v, (list, tuple)):
            unflatten = self.unflatten
            return [await unflatten(item, scope) for item in v]
        return v

    async def unflatten_object(
        self, cls: Type["Model"], state: StateType, scope: Optional[ScopeType]
    ) -> Optional["Model"]:
        """Restore the object for the given class, state, and scope.
        If a reference is given the scope should be updated with the newly
        created object using the given ref.

        Parameters
        ----------
        cls: Class
            The type of object expected
        state: Dict
            The state of the object to restore
        scope: Dict or None
            The scope of references available for circular lookups

        Returns
        -------
        result: object or None
            A the newly created object (or an existing object if using a cache)
            or None if this object does not exist in the database.
        """
        _id = state.get("_id")
        # Get the object for this id, retrieve from cache if needed
        obj, created = await self.get_or_create(cls, state, scope)
        # Lookup the object if needed
        if created and _id is not None:
            # If a new object was created lookup the state for that object
            state = await self.get_object_state(obj, state, scope)
            if state is None:
                return None
        # Child objects may have circular references to this object
        # so we must update the scope with this reference to handle this
        # before restoring any children. An empty scope must be updated too,
        # and a missing/None ref must not be registered.
        ref = state.get("__ref__")
        if scope is not None and ref is not None:
            scope[ref] = obj
        # If not restoring from cache update the state
        if created:
            await obj.__restorestate__(state, scope)
        return obj

    async def get_or_create(
        self, cls: Type["Model"], state: Any, scope: ScopeType
    ) -> TupleType["Model", bool]:
        """Get a cached object for this _id or create a new one. Subclasses
        should override this as needed to provide object caching if desired.

        Parameters
        ----------
        cls: Class
            The type of object expected
        state: Dict
            Unflattened state of object to restore
        scope: Dict
            Scope of objects available when flattened

        Returns
        -------
        result: Tuple[object, bool]
            A tuple of the object and a flag stating if it was created or not.
        """
        return (cls.__new__(cls), True)

    async def get_object_state(self, obj: "Model", state: Any, scope: ScopeType) -> Any:
        """Lookup the state needed to restore the given object id and class.

        Parameters
        ----------
        obj: Model
            The object created by `get_or_create`
        state: Dict
            Unflattened state of object to restore
        scope: Dict
            Scope of objects available when flattened

        Returns
        -------
        result: Any
            The model state needed to restore this object
        """
        raise NotImplementedError
class ModelManager(Atom):
    """A descriptor so you can use this somewhat like Django's models.
    Assuming you're using motor.

    Examples
    --------
    MyModel.objects.find_one({'_id': 'someid'})
    """

    #: Stores instances of each class so we can easily reuse them if desired
    _instances: ClassVar[DictType[Type["ModelManager"], "ModelManager"]] = {}

    @classmethod
    def instance(cls) -> "ModelManager":
        """Return the shared instance of this manager class, creating it on
        first use.
        """
        if cls not in ModelManager._instances:
            ModelManager._instances[cls] = cls()
        return ModelManager._instances[cls]

    #: Used to access the database
    database = Value()

    def _default_database(self) -> Any:
        # Subclasses must provide the database handle
        raise NotImplementedError

    def __get__(self, obj: T, cls: Optional[Type[T]] = None):
        """Handle objects from the class that owns the manager. Subclasses
        should override this as needed.
        """
        raise NotImplementedError
def generate_getstate(cls: Type["Model"]) -> GetStateFn:
    """Generate an optimized __getstate__ function for the given model.

    The generated function registers the object in the scope under its
    ``__ref__``. It then builds a dict with one entry per name in
    ``cls.__fields__``, plus ``__model__`` and ``__ref__``. ``_id`` is also
    included when the model has that member and its value is truthy.

    Parameters
    ----------
    cls: Type[Model]
        The class to generate a getstate function for.

    Returns
    -------
    result: GetStateFn
        A function optimized to generate the state for the given model class.

    Raises
    ------
    ValueError
        If a field name is not a valid Python identifier.
    """
    template = [
        "def __getstate__(self, scope=None):",
        # Only create a scope when none was given; `scope or {}` would replace
        # a caller-supplied empty dict and hide the collected references.
        "if scope is None:",
        "    scope = {}",
        "scope[self.__ref__] = self",
        "state = {",
    ]
    default_flatten = cls.serializer.flatten
    members = cls.members()
    namespace = {
        "default_flatten": default_flatten,
    }
    for f in cls.__fields__:
        # Since f is potentially an untrusted input, make sure it is a valid
        # python identifier to prevent unintended code being generated.
        if not f.isidentifier():
            raise ValueError(f"Field '{f}' cannot be used for code generation")
        m = members[f]
        meta = m.metadata or {}
        flatten = meta.get("flatten", default_flatten)
        if flatten is default_flatten:
            # Primitive values are stored directly. None (not resolvable yet)
            # is falsy and so uses the default flatten.
            if is_primitive_member(m):
                expr = f"self.{f}"
            else:
                expr = f"default_flatten(self.{f}, scope)"
        else:
            namespace[f"flatten_{f}"] = flatten
            expr = f"flatten_{f}(self.{f}, scope)"
        template.append(f'    "{f}": {expr},')
    template.append('    "__model__": self.__model__,')
    template.append('    "__ref__": self.__ref__,')
    template.append("}")
    if "_id" in members:
        # Only include _id when it is truthy
        template.append("if self._id:")
        template.append('    state["_id"] = self._id')
    template.append("return state")
    # Every line after the signature becomes part of the function body
    source = "\n    ".join(template)
    return generate_function(source, namespace, "__getstate__")
def generate_restorestate(cls: Type["Model"]) -> RestoreStateFn:
    """Generate an optimized __restorestate__ function for the given model.

    The generated coroutine function works in these steps:

    - Reject state whose ``__model__`` does not match.
    - Register the object in the scope under the state's ``__ref__``.
    - Assign each member present in the state, in ``setstate_order`` order
      (default 1000).
    - Set ``__restored__``.

    Assignment errors are handled according to ``cls.__on_error__``:
    "raise" propagates them, "log" calls ``__log_restore_error__``, and any
    other value ignores them.

    Parameters
    ----------
    cls: Type[Model]
        The class to generate a restorestate function for.

    Returns
    -------
    result: RestoreStateFn
        A function optimized to restore the state for the given model class.

    Raises
    ------
    ValueError
        If a member name is not a valid Python identifier.
    """
    # Python must do some caching because using key in state and state[key]
    # seems to be faster than using get
    template = [
        "async def __restorestate__(self, state, scope=None):",
        "if '__model__' in state and state['__model__'] != self.__model__:",
        "    name = state['__model__']",
        "    raise ValueError(",
        "        f'Trying to use {name} state for {self.__model__} object'",
        "    )",
        "if '__ref__' in state and state['__ref__'] is not None:",
        # Only create a scope when none was given; `scope or {}` would replace
        # a caller-supplied empty dict so the registration would be lost.
        "    if scope is None:",
        "        scope = {}",
        "    scope[state['__ref__']] = self",
    ]
    default_unflatten = cls.serializer.unflatten
    members = cls.members()
    # Members that are never assigned from the state
    excluded = (
        "__ref__",
        "__restored__",
    )
    setters = []
    for f, m in members.items():
        if f in excluded:
            continue
        meta = m.metadata or {}
        order = meta.get("setstate_order", 1000)
        # Allow tagging a custom unflatten fn
        unflatten = meta.get("unflatten", default_unflatten)
        setters.append((order, f, m, unflatten))
    # Stable sort: equal orders keep the order returned by members()
    setters.sort(key=lambda it: it[0])
    on_error = cls.__on_error__
    namespace: DictType[str, Any] = {
        "default_unflatten": default_unflatten,
    }
    for order, f, m, unflatten in setters:
        # Since f is potentially an untrusted input, make sure it is a valid
        # python identifier to prevent unintended code being generated.
        if not f.isidentifier():
            raise ValueError(f"Field '{f}' cannot be used for code generation")
        template.append(f"if '{f}' in state:")
        # Determine the expression to unflatten the value
        if unflatten is default_unflatten:
            RelModel = None
            # If the member is typed we can shortcut looking up the __model__
            # type from the state and restore it directly.
            # Note that this does not work for instances.
            if isinstance(m, Typed):
                try:
                    types = resolve_member_types(m, resolve=False)
                except UnresolvableError:
                    # Forwarded members (those with a `resolve` method) cannot
                    # be inspected yet; use the generic unflatten path, which
                    # finds the model from the state's __model__ key.
                    types = None
                if (
                    types
                    and len(types) == 1
                    and isinstance(types[0], type)
                    and issubclass(types[0], Model)
                ):
                    RelModel = types[0]
            if RelModel is not None:
                namespace[f"rel_model_{f}"] = RelModel
                # A None value (eg an unset member) is not a model state, so
                # assign it directly instead of trying to restore it.
                expr = (
                    f"(await rel_model_{f}.restore(state['{f}']) "
                    f"if state['{f}'] is not None else None)"
                )
            elif is_primitive_member(m):
                # Direct assignment
                expr = f"state['{f}']"
            else:
                # Default unflatten
                expr = f"await default_unflatten(state['{f}'], scope)"
        else:
            namespace[f"unflatten_{f}"] = unflatten
            if asyncio.iscoroutinefunction(unflatten):
                expr = f"await unflatten_{f}(state['{f}'], scope)"
            else:
                expr = f"unflatten_{f}(state['{f}'], scope)"
        # Do the assignment
        if on_error == "raise":
            template.append(f"    self.{f} = {expr}")
        else:
            if on_error == "log":
                handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
            else:
                handler = "pass"
            template.extend(
                [
                    "    try:",
                    f"        self.{f} = {expr}",
                    "    except Exception as e:",
                    f"        {handler}",
                ]
            )
    # Update restored state
    template.append("self.__restored__ = True")
    # Every line after the signature becomes part of the function body
    source = "\n    ".join(template)
    return generate_function(source, namespace, "__restorestate__")
def generate_function(
    source: str, namespace: DictType[str, Any], fn_name: str, optimize: bool = True
) -> Callable[..., Any]:
    """Generate an optimized function

    Parameters
    ----------
    source: str
        The function source code. It must start with ``def <fn_name>`` or
        ``async def <fn_name>``.
    namespace: dict
        Namespace available to the function (used as its globals).
    fn_name: str
        The name of the generated function.
    optimize: bool
        If True, rewrite ``LOAD_GLOBAL`` instructions that refer to names in
        ``namespace`` into ``LOAD_CONST`` instructions.

    Returns
    -------
    fn: function
        The function generated. The source is stored as ``fn.__source__``.

    Raises
    ------
    RuntimeError
        If the source does not define ``fn_name`` or fails to compile.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not source.startswith((f"def {fn_name}", f"async def {fn_name}")):
        raise RuntimeError(
            f"Could not generate code: source must define '{fn_name}':\n{source}"
        )
    try:
        code = compile(source, __name__, "exec", optimize=1)
    except Exception as e:
        raise RuntimeError(f"Could not generate code: {e}:\n{source}") from e
    result: DictType[str, Any] = {}
    exec(code, namespace, result)
    # The prefix check also accepts eg "def foo" for fn_name "f"
    if fn_name not in result:
        raise RuntimeError(
            f"Could not generate code: '{fn_name}' was not defined:\n{source}"
        )
    fn = result[fn_name]
    fn.__source__ = source
    if optimize:
        # Optimize global access by turning namespace lookups into constants
        bc = Bytecode.from_code(fn.__code__)
        for i, inst in enumerate(bc):
            # Skip labels and other pseudo-instructions (eg line numbers or
            # try-block markers) which have no opcode name
            if not isinstance(inst, Instr):
                continue
            if inst.name == "LOAD_GLOBAL" and inst.arg in namespace:
                bc[i] = Instr("LOAD_CONST", namespace[inst.arg])
        fn.__code__ = bc.to_code()
    return fn
class Model(Atom, metaclass=ModelMeta):
    """An atom model that can be serialized and deserialized to and from
    a database.
    """

    # --------------------------------------------------------------------------
    # Class attributes
    # --------------------------------------------------------------------------
    __slots__ = "__weakref__"

    #: List of database field member names
    __fields__: ClassVar[ListType[str]]

    #: Table name used when saving into the database
    __model__: ClassVar[str]

    #: How the generated restore function handles an error while setting a
    #: member: "log" (default) reports it via __log_restore_error__, "raise"
    #: propagates it, and any other value (eg "ignore") skips it silently.
    __on_error__: ClassVar[str] = "log"  # "ignore" or "raise"

    # --------------------------------------------------------------------------
    # Internal model members
    # --------------------------------------------------------------------------
    #: A random 32-bit ID used to handle cyclical serialization and
    #: deserialization
    __ref__ = Int(factory=lambda: getrandbits(32))

    #: Flag to indicate if this model has been restored or saved
    __restored__ = Bool().tag(store=False)

    # --------------------------------------------------------------------------
    # Serialization API
    # --------------------------------------------------------------------------
    #: Handles encoding and decoding. Subclasses should redefine this to a
    #: subclass of ModelSerializer
    serializer: ClassVar[ModelSerializer] = ModelSerializer.instance()

    #: Optimized serialize functions. These are generated by the metaclass.
    __generated_getstate__: ClassVar[GetStateFn]
    __generated_restorestate__: ClassVar[RestoreStateFn]

    def __getstate__(self, scope: Optional[ScopeType] = None) -> StateType:
        """Get the serialized model state. By default this delegates to an
        optimized function generated by the ModelMeta class.

        Parameters
        ----------
        scope: Optional[ScopeType]
            The scope to lookup circular references.

        Returns
        -------
        state: StateType
            The state of the object.
        """
        return self.__generated_getstate__(scope)

    async def __restorestate__(
        self, state: StateType, scope: Optional[ScopeType] = None
    ):
        """Restore an object from the a state from the database. This is
        async as it will lookup any referenced objects from the DB.

        State is restored by calling setattr(k, v) for every item in the state
        that has an associated atom member.  Members can be tagged with a
        `setstate_order=<number>` to define the order of setattr calls. Errors
        from setattr are handled according to `__on_error__` (logged by
        default).

        Parameters
        ----------
        state: Dict
            A dictionary of state keys and values
        scope: Dict or None
            A namespace to use to resolve any possible circular references.
            The __ref__ value is used as the keys.
        """
        await self.__generated_restorestate__(state, scope)  # type: ignore

    def __log_restore_error__(
        self, e: Exception, k: str, state: StateType, scope: Optional[ScopeType]
    ):
        """Log details when restoring a member fails. This typically only will
        occur if the state has data from an old model after a schema change.
        """
        obj = state.get(k)
        log.warning(
            f"Error loading state:"
            f"{self.__model__}.{k} = {pformat(obj)}:"
            f"\nRef: {self.__ref__}"
            f"\nScope: {pformat(scope)}"
            f"\nState: {pformat(state)}"
            f"\n{e}"
        )

    # --------------------------------------------------------------------------
    # Database API
    # --------------------------------------------------------------------------
    #: Handles database access. Subclasses should redefine this.
    objects: ClassVar[ModelManager] = ModelManager()

    @classmethod
    async def restore(cls: Type[M], state: StateType, **kwargs: Any) -> M:
        """Restore an object from the database state.

        Extra keyword arguments are accepted (and ignored here) so subclasses
        can extend the signature.
        """
        obj = cls.__new__(cls)
        await obj.__restorestate__(state)
        return obj

    async def load(self):
        """Alias to load this object from the database"""
        raise NotImplementedError

    async def save(self):
        """Alias to save this object to the database"""
        raise NotImplementedError

    async def delete(self):
        """Alias to delete this object in the database"""
        raise NotImplementedError
class JSONSerializer(ModelSerializer):
    """Serializer producing JSON-compatible state.

    Models become dicts tagged with ``__model__``/``__ref__``. A few Python
    value types become dicts tagged with ``__py__``, which are restored by
    ``ModelSerializer.coercers``.
    """

    def flatten(self, v: Any, scope: Optional[ScopeType] = None):
        """Flatten date, datetime, time, decimal, bytes, and UUID values as a
        dict with a __py__ field and arguments to reconstruct it. Also see the
        coercers. Everything else is handled by ModelSerializer.flatten.

        Timezone information on datetimes/times is currently not stored.
        """
        if isinstance(v, (date, datetime, time)):
            # This is inefficient space wise but still allows queries
            s: DictType[str, Any] = {
                "__py__": f"{v.__class__.__module__}.{v.__class__.__name__}"
            }
            if isinstance(v, (date, datetime)):
                s.update({"year": v.year, "month": v.month, "day": v.day})
            if isinstance(v, (time, datetime)):
                s.update(
                    {
                        "hour": v.hour,
                        "minute": v.minute,
                        "second": v.second,
                        "microsecond": v.microsecond,
                        # TODO: Timezones
                    }
                )
            return s
        if isinstance(v, bytes):
            return {"__py__": "bytes", "bytes": b64encode(v).decode()}
        if isinstance(v, Decimal):
            return {"__py__": "decimal", "value": str(v)}
        if isinstance(v, UUID):
            return {"__py__": "uuid", "id": str(v)}
        return super().flatten(v, scope)

    def flatten_object(self, obj: Model, scope: ScopeType) -> DictType[str, Any]:
        """Flatten to just json but add in keys to know how to restore it.

        An object already present in the scope is emitted as a reference dict
        containing only ``__ref__`` and ``__model__``.
        """
        ref = obj.__ref__
        if ref in scope:
            return {"__ref__": ref, "__model__": obj.__model__}
        else:
            scope[ref] = obj
        return obj.__getstate__(scope)

    async def get_object_state(self, obj: Any, state: StateType, scope: ScopeType):
        """State should be contained in the dict"""
        return state

    def _default_registry(self) -> DictType[str, Type[Model]]:
        # Include JSONModel itself (not just its subclasses) so state produced
        # by a plain JSONModel instance can be unflattened. Subclasses come
        # later so they win if one reuses the same __model__ name.
        models = [JSONModel, *find_subclasses(JSONModel)]
        return {m.__model__: m for m in models}
class JSONModel(Model):
    """Model that serializes to plain JSON-compatible data, handy for
    embedding inside other models.
    """

    #: JSON models carry their full state inline (JSONSerializer returns the
    #: given state unchanged from get_object_state), so instances default to
    #: being flagged as restored.
    __restored__ = set_default(True)  # type: ignore

    #: Flatten and restore using the shared JSON serializer
    serializer = JSONSerializer.instance()