plexpy/lib/pydantic/_internal/_generics.py
2024-03-24 17:55:28 -07:00

518 lines
22 KiB
Python

from __future__ import annotations
import sys
import types
import typing
from collections import ChainMap
from contextlib import contextmanager
from contextvars import ContextVar
from types import prepare_class
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
from weakref import WeakValueDictionary
import typing_extensions
from ._core_utils import get_type_ref
from ._forward_ref import PydanticRecursiveRef
from ._typing_extra import TypeVarType, typing_base
from ._utils import all_identical, is_model_class
if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias # type: ignore[attr-defined]
if TYPE_CHECKING:
from ..main import BaseModel
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
# Right now, to handle recursive generics, we some types must remain cached for brief periods without references.
# By chaining the WeakValuesDict with a LimitedDict, we have a way to retain caching for all types with references,
# while also retaining a limited number of types even without references. This is generally enough to build
# specific recursive generic models without losing required items out of the cache.
KT = TypeVar('KT')
VT = TypeVar('VT')
_LIMITED_DICT_SIZE = 100
if TYPE_CHECKING:
class LimitedDict(dict, MutableMapping[KT, VT]):
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
...
else:
class LimitedDict(dict):
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
"""
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
self.size_limit = size_limit
super().__init__()
def __setitem__(self, __key: Any, __value: Any) -> None:
super().__setitem__(__key, __value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for key in to_remove:
del self[key]
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
else:
GenericTypesCache = WeakValueDictionary
if TYPE_CHECKING:
class DeepChainMap(ChainMap[KT, VT]): # type: ignore
...
else:
class DeepChainMap(ChainMap):
"""Variant of ChainMap that allows direct updates to inner scopes.
Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
with some light modifications for this use case.
"""
def clear(self) -> None:
for mapping in self.maps:
mapping.clear()
def __setitem__(self, key: KT, value: VT) -> None:
for mapping in self.maps:
mapping[key] = value
def __delitem__(self, key: KT) -> None:
hit = False
for mapping in self.maps:
if key in mapping:
del mapping[key]
hit = True
if not hit:
raise KeyError(key)
# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
# and discover later on that we need to re-add all this infrastructure...
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
_GENERIC_TYPES_CACHE = GenericTypesCache()
class PydanticGenericMetadata(typing_extensions.TypedDict):
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
def create_generic_submodel(
model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
) -> type[BaseModel]:
"""Dynamically create a submodel of a provided (generic) BaseModel.
This is used when producing concrete parametrizations of generic models. This function
only *creates* the new subclass; the schema/validators/serialization must be updated to
reflect a concrete parametrization elsewhere.
Args:
model_name: The name of the newly created model.
origin: The base class for the new model to inherit from.
args: A tuple of generic metadata arguments.
params: A tuple of generic metadata parameters.
Returns:
The created submodel.
"""
namespace: dict[str, Any] = {'__module__': origin.__module__}
bases = (origin,)
meta, ns, kwds = prepare_class(model_name, bases)
namespace.update(ns)
created_model = meta(
model_name,
bases,
namespace,
__pydantic_generic_metadata__={
'origin': origin,
'args': args,
'parameters': params,
},
__pydantic_reset_parent_namespace__=False,
**kwds,
)
model_module, called_globally = _get_caller_frame_info(depth=3)
if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
reference_module_globals = sys.modules[created_model.__module__].__dict__
while object_by_reference is not created_model:
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
reference_name += '_'
return created_model
def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
"""Used inside a function to check whether it was called globally.
Args:
depth: The depth to get the frame.
Returns:
A tuple contains `module_name` and `called_globally`.
Raises:
RuntimeError: If the function is not called inside a function.
"""
try:
previous_caller_frame = sys._getframe(depth)
except ValueError as e:
raise RuntimeError('This function must be used inside another function') from e
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
DictValues: type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
"""
if isinstance(v, TypeVar):
yield v
elif is_model_class(v):
yield from v.__pydantic_generic_metadata__['parameters']
elif isinstance(v, (DictValues, list)):
for var in v:
yield from iter_contained_typevars(var)
else:
args = get_args(v)
for arg in args:
yield from iter_contained_typevars(arg)
def get_args(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('args')
return typing_extensions.get_args(v)
def get_origin(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('origin')
return typing_extensions.get_origin(v)
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
"""
origin = get_origin(cls)
if origin is None:
return None
if not hasattr(origin, '__parameters__'):
return None
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
# So it is safe to access cls.__args__ and origin.__parameters__
args: tuple[Any, ...] = cls.__args__ # type: ignore
parameters: tuple[TypeVarType, ...] = origin.__parameters__
return dict(zip(parameters, args))
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
with the `replace_types` function.
Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
"""
# TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
# in the __origin__, __args__, and __parameters__ attributes of the model.
generic_metadata = cls.__pydantic_generic_metadata__
origin = generic_metadata['origin']
args = generic_metadata['args']
return dict(zip(iter_contained_typevars(origin), args))
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Args:
type_: The class or generic alias.
type_map: Mapping from `TypeVar` instance to concrete types.
Returns:
A new type representing the basic structure of `type_` with all
`typevar_map` keys recursively replaced.
Example:
```py
from typing import List, Tuple, Union
from pydantic._internal._generics import replace_types
replace_types(Tuple[str, Union[List[str], float]], {str: int})
#> Tuple[int, Union[List[int], float]]
```
"""
if not type_map:
return type_
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
annotated = replace_types(annotated_type, type_map)
for annotation in annotations:
annotated = typing_extensions.Annotated[annotated, annotation]
return annotated
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, typing_base)
and not isinstance(origin_type, typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
# `type` or `collections.abc.Callable` need to be translated.
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
# We also cannot use isinstance() since we have to compare types.
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
return _UnionGenericAlias(origin_type, resolved_type_args)
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
if not origin_type and is_model_class(type_):
parameters = type_.__pydantic_generic_metadata__['parameters']
if not parameters:
return type_
resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
if all_identical(parameters, resolved_type_args):
return type_
return type_[resolved_type_args]
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)):
resolved_list = list(replace_types(element, type_map) for element in type_)
if all_identical(type_, resolved_list):
return type_
return resolved_list
# If all else fails, we try to resolve the type directly and otherwise just
# return the input with no modifications.
return type_map.get(type_, type_)
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
"""Checks if the type, or any of its arbitrary nested args, satisfy
`isinstance(<type>, isinstance_target)`.
"""
if isinstance(type_, isinstance_target):
return True
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
return has_instance_in_type(annotated_type, isinstance_target)
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
return True
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
if any(has_instance_in_type(element, isinstance_target) for element in type_):
return True
return False
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
"""Check the generic model parameters count is equal.
Args:
cls: The generic model.
parameters: A tuple of passed parameters to the generic model.
Raises:
TypeError: If the passed parameters count is not equal to generic model parameters count.
"""
actual = len(parameters)
expected = len(cls.__pydantic_generic_metadata__['parameters'])
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
@contextmanager
def generic_recursion_self_type(
origin: type[BaseModel], args: tuple[Any, ...]
) -> Iterator[PydanticRecursiveRef | None]:
"""This contextmanager should be placed around the recursive calls used to build a generic type,
and accept as arguments the generic origin type and the type arguments being passed to it.
If the same origin and arguments are observed twice, it implies that a self-reference placeholder
can be used while building the core schema, and will produce a schema_ref that will be valid in the
final parent schema.
"""
previously_seen_type_refs = _generic_recursion_cache.get()
if previously_seen_type_refs is None:
previously_seen_type_refs = set()
token = _generic_recursion_cache.set(previously_seen_type_refs)
else:
token = None
try:
type_ref = get_type_ref(origin, args_override=args)
if type_ref in previously_seen_type_refs:
self_type = PydanticRecursiveRef(type_ref=type_ref)
yield self_type
else:
previously_seen_type_refs.add(type_ref)
yield None
finally:
if token:
_generic_recursion_cache.reset(token)
def recursively_defined_type_refs() -> set[str]:
visited = _generic_recursion_cache.get()
if not visited:
return set() # not in a generic recursion, so there are no types
return visited.copy() # don't allow modifications
def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
"""The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
The approach could be modified to not use two different cache keys at different points, but the
_early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
_late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
same after resolving the type arguments will always produce cache hits.
If we wanted to move to only using a single cache key per type, we would either need to always use the
slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
that Model[List[T]][int] is a different type than Model[List[T]][int]. Because we rely on subclass relationships
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
equal.
"""
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
def get_cached_generic_type_late(
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
) -> type[BaseModel] | None:
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
if cached is not None:
set_cached_generic_type(parent, typevar_values, cached, origin, args)
return cached
def set_cached_generic_type(
parent: type[BaseModel],
typevar_values: tuple[Any, ...],
type_: type[BaseModel],
origin: type[BaseModel] | None = None,
args: tuple[Any, ...] | None = None,
) -> None:
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
two different keys.
"""
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
if len(typevar_values) == 1:
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
if origin and args:
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
def _union_orderings_key(typevar_values: Any) -> Any:
"""This is intended to help differentiate between Union types with the same arguments in different order.
Thanks to caching internal to the `typing` module, it is not possible to distinguish between
List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
because `typing` considers Union[int, float] to be equal to Union[float, int].
However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
(See https://github.com/python/cpython/issues/86483 for reference.)
"""
if isinstance(typevar_values, tuple):
args_data = []
for value in typevar_values:
args_data.append(_union_orderings_key(value))
return tuple(args_data)
elif typing_extensions.get_origin(typevar_values) is typing.Union:
return get_args(typevar_values)
else:
return ()
def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for minimal computational overhead during lookups of cached types.
Note that this is overly simplistic, and it's possible that two different cls/typevar_values
inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
would result in the same type.
"""
return cls, typevar_values, _union_orderings_key(typevar_values)
def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for use later in the process of creating a new type, when we have more information
about the exact args that will be passed. If it turns out that a different set of inputs to
__class_getitem__ resulted in the same inputs to the generic type creation process, we can still
return the cached type, and update the cache with the _early_cache_key as well.
"""
# The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
# _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
# whereas this function will always produce a tuple as the first item in the key.
return _union_orderings_key(typevar_values), origin, args