# plexpy/lib/pydantic/root_model.py
# 2024-03-24 17:55:28 -07:00
#
# 150 lines
# 5.5 KiB
# Python

"""RootModel class and type definitions."""
from __future__ import annotations as _annotations
import typing
from copy import copy, deepcopy
from pydantic_core import PydanticUndefined
from . import PydanticUserError
from ._internal import _model_construction, _repr
from .main import BaseModel, _object_setattr
if typing.TYPE_CHECKING:
    # These names are only needed by static type checkers; at runtime the annotations that use them
    # are never evaluated because of `from __future__ import annotations`.
    from typing import Any

    from typing_extensions import Literal, dataclass_transform

    from .fields import Field as PydanticModelField

    # dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
    # takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
    # on a new metaclass.
    @dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField,))
    class _RootModelMetaclass(_model_construction.ModelMetaclass):
        ...

    # Type variable used to annotate `self` / `cls` in `RootModel` methods so that the return type
    # follows the concrete subclass.
    Model = typing.TypeVar('Model', bound='BaseModel')
else:
    # At runtime no extra metaclass is created: the name is simply an alias of `ModelMetaclass`.
    _RootModelMetaclass = _model_construction.ModelMetaclass

__all__ = ('RootModel',)

# Type variable for the type of the `root` field; `RootModel` is generic over it.
RootModelRootType = typing.TypeVar('RootModelRootType')
class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
    """Usage docs: https://docs.pydantic.dev/2.6/concepts/models/#rootmodel-and-custom-root-types

    A Pydantic `BaseModel` for the root object of the model.

    Attributes:
        root: The root object of the model.
        __pydantic_root_model__: Whether the model is a RootModel.
        __pydantic_private__: Private fields in the model.
        __pydantic_extra__: Extra fields in the model.
    """

    __pydantic_root_model__ = True
    __pydantic_private__ = None
    __pydantic_extra__ = None

    root: RootModelRootType

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that configure `extra`, then defer to the parent hook.

        Args:
            **kwargs: Class keyword arguments, forwarded unchanged to `super().__init_subclass__`.

        Raises:
            PydanticUserError: If `cls.model_config` contains an `'extra'` entry that is not `None`.
        """
        extra = cls.model_config.get('extra')
        if extra is not None:
            raise PydanticUserError(
                "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
            )
        super().__init_subclass__(**kwargs)

    def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None:  # type: ignore
        """Validate the root value and initialise this instance from it.

        Args:
            root: The root value, passed positionally. Defaults to the `PydanticUndefined` sentinel.
            **data: Keyword arguments; when any are given they are collected into a dict that is
                used as the root value instead of `root`.

        Raises:
            ValueError: If both a positional `root` and keyword arguments are supplied.
        """
        __tracebackhide__ = True
        if data:
            if root is not PydanticUndefined:
                raise ValueError(
                    '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
                )
            # Keyword-only call: the keyword arguments themselves become the root value.
            root = data  # type: ignore
        self.__pydantic_validator__.validate_python(root, self_instance=self)

    # Marker attribute set on the function object itself.
    # NOTE(review): presumably read by pydantic internals to recognise the library-provided `__init__` -- confirm.
    __init__.__pydantic_base_init__ = True  # pyright: ignore[reportFunctionMemberAccess]

    @classmethod
    def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model:  # type: ignore
        """Create a new model using the provided root object and update fields set.

        This forwards to the parent `model_construct`, passing the value as the `root` field.

        Args:
            root: The root object of the model.
            _fields_set: The set of fields to be updated.

        Returns:
            The new model.
        """
        return super().model_construct(root=root, _fields_set=_fields_set)

    def __getstate__(self) -> dict[Any, Any]:
        """Return the state to pickle: the instance `__dict__` and the set of fields that were set."""
        return {
            '__dict__': self.__dict__,
            '__pydantic_fields_set__': self.__pydantic_fields_set__,
        }

    def __setstate__(self, state: dict[Any, Any]) -> None:
        """Restore the instance from a mapping shaped like the one `__getstate__` returns.

        Args:
            state: Mapping with the keys `'__pydantic_fields_set__'` and `'__dict__'`.
        """
        # `_object_setattr` is used rather than plain attribute assignment.
        # NOTE(review): presumably to bypass `BaseModel.__setattr__` -- confirm in `.main`.
        _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
        _object_setattr(self, '__dict__', state['__dict__'])

    def __copy__(self: Model) -> Model:
        """Returns a shallow copy of the model."""
        cls = type(self)
        # Allocate without running `__init__`, so the copy is not validated again.
        m = cls.__new__(cls)
        _object_setattr(m, '__dict__', copy(self.__dict__))
        _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return m

    def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
        """Returns a deep copy of the model.

        Args:
            memo: The memo dictionary passed along to `copy.deepcopy` for the instance `__dict__`.
        """
        cls = type(self)
        # Allocate without running `__init__`, so the copy is not validated again.
        m = cls.__new__(cls)
        _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
        # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
        # and attempting a deepcopy would be marginally slower.
        _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return m

    if typing.TYPE_CHECKING:

        def model_dump(  # type: ignore
            self,
            *,
            mode: Literal['json', 'python'] | str = 'python',
            include: Any = None,
            exclude: Any = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool = True,
        ) -> RootModelRootType:
            """This method is included just to get a more accurate return type for type checkers.
            It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.

            See the documentation of `BaseModel.model_dump` for more details about the arguments.
            """
            ...

    def __eq__(self, other: Any) -> bool:
        """Compare with another root model.

        Returns `NotImplemented` when `other` is not a `RootModel`. Otherwise the two are equal only if
        the annotations of their `root` fields are equal and the parent `__eq__` also reports equality.
        """
        if not isinstance(other, RootModel):
            return NotImplemented
        return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)

    def __repr_args__(self) -> _repr.ReprArgs:
        """Yield the single `('root', value)` pair used to build the repr."""
        yield 'root', self.root