104 lines
2.9 KiB
Python
104 lines
2.9 KiB
Python
|
# SPDX-License-Identifier: MIT
|
||
|
|
||
|
import inspect
|
||
|
import platform
|
||
|
import sys
|
||
|
import threading
|
||
|
|
||
|
from collections.abc import Mapping, Sequence # noqa: F401
|
||
|
from typing import _GenericAlias
|
||
|
|
||
|
|
||
|
# Interpreter implementation flag.
PYPY = "PyPy" == platform.python_implementation()

# Minimum-version feature flags.  ``sys.version_info`` compares as a tuple,
# so checking it against a ``(major, minor)`` pair is a lower-bound test.
PY_3_8_PLUS = sys.version_info >= (3, 8)
PY_3_9_PLUS = sys.version_info >= (3, 9)
PY_3_10_PLUS = sys.version_info >= (3, 10)
PY_3_11_PLUS = sys.version_info >= (3, 11)
PY_3_12_PLUS = sys.version_info >= (3, 12)
PY_3_13_PLUS = sys.version_info >= (3, 13)
PY_3_14_PLUS = sys.version_info >= (3, 14)
|
||
|
|
||
|
|
||
|
# Bind ``Protocol``: taken from ``typing`` on Python 3.8 and later; on older
# interpreters the ``typing_extensions`` backport is tried first and a plain
# ``object`` is used as a stand-in when that package is not installed.
if sys.version_info >= (3, 8):
    from typing import Protocol  # noqa: F401
else:
    try:
        from typing_extensions import Protocol
    except ImportError:  # pragma: no cover
        Protocol = object
|
||
|
|
||
|
# Pick the annotation getter for the running interpreter: on 3.14+ it is
# delegated to ``annotationlib``; otherwise the class namespace is read.
if PY_3_14_PLUS:  # pragma: no cover
    import annotationlib

    _get_annotations = annotationlib.get_annotations

else:

    def _get_annotations(cls):
        """
        Return the annotations stored directly on *cls*.

        Only the ``__dict__`` of *cls* itself is consulted; an empty dict is
        returned when it has no ``__annotations__`` entry.
        """
        own_namespace = cls.__dict__
        return own_namespace.get("__annotations__", {})
|
||
|
|
||
|
|
||
|
class _AnnotationExtractor:
    """
    Look up type annotations on the signature of a callable.

    Both getters answer None when the requested annotation is missing, or
    when the signature could not be obtained in the first place.
    """

    __slots__ = ["sig"]

    def __init__(self, callable):
        # ``sig`` is left as None when introspection of *callable* fails.
        try:
            signature = inspect.signature(callable)
        except (ValueError, TypeError):  # inspect failed
            signature = None
        self.sig = signature

    def get_first_param_type(self):
        """
        Return the annotation of the first parameter.

        None is returned when there is no signature, no parameter, or the
        first parameter carries no annotation.
        """
        if not self.sig:
            return None

        # Only the leading parameter matters, so take it straight off the
        # iterator.
        first = next(iter(self.sig.parameters.values()), None)
        if first is None or first.annotation is inspect.Parameter.empty:
            return None

        return first.annotation

    def get_return_type(self):
        """
        Return the return annotation, or None when there is no signature or
        the signature has no return annotation.
        """
        sig = self.sig
        if not sig:
            return None

        annotation = sig.return_annotation
        return None if annotation is inspect.Signature.empty else annotation
|
||
|
|
||
|
|
||
|
# Thread-local, module-level state that records which attrs instances are
# currently in the middle of being repr'd.
#
# There is no other thread-safe channel for handing that information down the
# call stack, and it is required to keep repr from recursing forever.
# Example: an instance holds a dict which in turn holds that same instance;
# while the dict's repr() call runs, we have to be able to tell that the
# outer instance is already being repr'd.
#
# It is defined here instead of in _make.py so that the functions in _make.py
# carry no direct reference to the thread-local in their globals dict -- such
# a reference breaks cloudpickle.
repr_context = threading.local()
|
||
|
|
||
|
|
||
|
def get_generic_base(cl):
    """
    Return the generic base of *cl* if it is a parametrized generic class
    such as ``A[str]``; return None for anything else.
    """
    # Exact class identity on purpose: only objects whose class is
    # ``_GenericAlias`` itself qualify.
    return cl.__origin__ if cl.__class__ is _GenericAlias else None
|