409 lines
14 KiB
Python
409 lines
14 KiB
Python
|
# Little utilities we use internally
|
|||
|
from __future__ import annotations
|
|||
|
|
|||
|
import collections.abc
|
|||
|
import inspect
|
|||
|
import os
|
|||
|
import signal
|
|||
|
import threading
|
|||
|
from abc import ABCMeta
|
|||
|
from functools import update_wrapper
|
|||
|
from typing import (
|
|||
|
TYPE_CHECKING,
|
|||
|
Any,
|
|||
|
Awaitable,
|
|||
|
Callable,
|
|||
|
Generic,
|
|||
|
NoReturn,
|
|||
|
Sequence,
|
|||
|
TypeVar,
|
|||
|
final as std_final,
|
|||
|
)
|
|||
|
|
|||
|
from sniffio import thread_local as sniffio_loop
|
|||
|
|
|||
|
import trio
|
|||
|
|
|||
|
# Type variables shared by annotations throughout this module.
# CallT: any callable; async_wraps' decorator returns its argument with this
# same type.
CallT = TypeVar("CallT", bound=Callable[..., Any])
# T: a generic type, e.g. "the decorated class" in _final_impl.
T = TypeVar("T")
# RetT: the return type of a wrapped callable (generic_function, coroutine_or_error).
RetT = TypeVar("RetT")

if TYPE_CHECKING:
    # Names used only inside annotations. With `from __future__ import
    # annotations` at the top of the file, annotations are not evaluated when
    # the module runs, so these imports are only needed by static type checkers.
    from types import AsyncGeneratorType, TracebackType

    from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack

    # NOTE(review): ArgsT is not referenced in this part of the file; confirm
    # whether another module imports it or it can be dropped.
    ArgsT = ParamSpec("ArgsT")
    # Variadic positional-argument types; used by coroutine_or_error's *args.
    PosArgsT = TypeVarTuple("PosArgsT")
|
|||
|
|
|||
|
|
|||
|
# signal_raise(signum) emulates the C function raise(): it delivers *signum*
# to the running program. Which definition is used is decided at import time:
#   - static type checkers see only the signature stub below;
#   - on Windows (os.name == "nt") it is the C runtime's raise(), via cffi;
#   - everywhere else it is signal.pthread_kill() aimed at the calling thread.
if TYPE_CHECKING:
    # Don't type check the implementation below, pthread_kill does not exist on Windows.
    def signal_raise(signum: int) -> None: ...


# Equivalent to the C function raise(), which Python doesn't wrap
# NOTE(review): newer Python versions may provide signal.raise_signal(); check
# whether it could replace this shim (and whether the remark above is stale).
elif os.name == "nt":
    # On Windows, os.kill exists but is really weird.
    #
    # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver
    # those using GenerateConsoleCtrlEvent. But I found that when I tried
    # to run my test normally, it would freeze waiting... unless I added
    # print statements, in which case the test suddenly worked. So I guess
    # these signals are only delivered if/when you access the console? I
    # don't really know what was going on there. From reading the
    # GenerateConsoleCtrlEvent docs I don't know how it worked at all.
    #
    # I later spent a bunch of time trying to make GenerateConsoleCtrlEvent
    # work for creating synthetic control-C events, and... failed
    # utterly. There are lots of details in the code and comments
    # removed/added at this commit:
    # https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23
    #
    # OTOH, if you pass os.kill any *other* signal number... then CPython
    # just calls TerminateProcess (wtf).
    #
    # So, anyway, os.kill is not so useful for testing purposes. Instead,
    # we use raise():
    #
    # https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx
    #
    # Have to import cffi inside the 'if os.name' block because we don't
    # depend on cffi on non-Windows platforms. (It would be easy to switch
    # this to ctypes though if we ever remove the cffi dependency.)
    #
    # Some more information:
    # https://bugs.python.org/issue26350
    #
    # Anyway, we use this for two things:
    # - redelivering unhandled signals
    # - generating synthetic signals for tests
    # and for both of those purposes, 'raise' works fine.
    import cffi

    _ffi = cffi.FFI()
    _ffi.cdef("int raise(int);")
    _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
    # "raise" is a Python keyword, so the C symbol cannot be written as
    # `_lib.raise`; look it up with getattr instead.
    signal_raise = getattr(_lib, "raise")
else:

    def signal_raise(signum: int) -> None:
        # Target the calling thread (threading.get_ident()) with the signal.
        signal.pthread_kill(threading.get_ident(), signum)
|
|||
|
|
|||
|
|
|||
|
# See: #461 as to why this is needed.
# threading.main_thread() cannot always be trusted: if something edits the
# threading ident cache to swap out the main thread, then
# threading.current_thread() hands back a _DummyThread, the C-c check fails,
# and so on. Installing a signal handler only works from the main thread,
# which gives a test that doesn't depend on a possibly-modified threading.
def is_main_thread() -> bool:
    """Return whether the caller appears to be running in the main thread.

    The probe re-installs SIGINT's current handler, which is a no-op when it
    succeeds. If ``signal.signal`` raises ``ValueError`` or ``TypeError``,
    the answer is ``False``.

    NOTE(review): TypeError presumably covers the case where getsignal()
    returns a non-Python handler that cannot be re-installed; confirm that
    reporting "not main thread" there is intended.
    """
    try:
        current_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, current_handler)
    except (TypeError, ValueError):
        return False
    else:
        return True
|
|||
|
|
|||
|
|
|||
|
######
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns coroutine object.
######
def coroutine_or_error(
    async_fn: Callable[[Unpack[PosArgsT]], Awaitable[RetT]],
    *args: Unpack[PosArgsT],
) -> collections.abc.Coroutine[object, NoReturn, RetT]:
    """Call ``async_fn(*args)`` and return the coroutine object it produces.

    While ``async_fn`` is being called, ``sniffio_loop.name`` is temporarily
    set to ``"trio"`` and restored afterwards, even if the call fails. This
    way a sync function that returns a coroutine still detects a Trio context.

    Raises:
        TypeError: with a descriptive message when

            - *async_fn* is already a coroutine object rather than a function
              (the coroutine is closed first);
            - *async_fn*, or the value it returns, looks like an object from
              asyncio / Twisted / Tornado;
            - the call returns an async generator; or
            - the call returns anything else that is not a coroutine
              (i.e. *async_fn* looks synchronous).

            Any other TypeError raised by the call itself is re-raised
            unchanged.
    """

    def _return_value_looks_like_wrong_library(value: object) -> bool:
        # Heuristic: does *value* look like an object from another async
        # framework rather than a coroutine Trio can run?
        # Returned by legacy @asyncio.coroutine functions, which includes
        # a surprising proportion of asyncio builtins.
        if isinstance(value, collections.abc.Generator):
            return True
        # The protocol for detecting an asyncio Future-like object
        if getattr(value, "_asyncio_future_blocking", None) is not None:
            return True
        # This janky check catches tornado Futures and twisted Deferreds.
        # By the time we're calling this function, we already know
        # something has gone wrong, so a heuristic is pretty safe.
        return value.__class__.__name__ in ("Future", "Deferred")

    # Make sure a sync-fn-that-returns-coroutine still sees itself as being
    # in trio context
    prev_loop, sniffio_loop.name = sniffio_loop.name, "trio"

    try:
        coro = async_fn(*args)

    except TypeError:
        # Give good error for: nursery.start_soon(trio.sleep(1))
        if isinstance(async_fn, collections.abc.Coroutine):
            # explicitly close coroutine to avoid RuntimeWarning
            async_fn.close()

            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                f"a coroutine object {async_fn!r}\n"
                "\n"
                "Probably you did something like:\n"
                "\n"
                f" trio.run({async_fn.__name__}(...)) # incorrect!\n"
                f" nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n"
                "\n"
                "Instead, you want (notice the parentheses!):\n"
                "\n"
                f" trio.run({async_fn.__name__}, ...) # correct!\n"
                f" nursery.start_soon({async_fn.__name__}, ...) # correct!"
            ) from None

        # Give good error for: nursery.start_soon(future)
        if _return_value_looks_like_wrong_library(async_fn):
            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                f"{async_fn!r} – are you trying to use a library written for "
                "asyncio/twisted/tornado or similar? That won't work "
                "without some sort of compatibility shim."
            ) from None

        # Not one of the recognised mistakes: a genuine TypeError from the call.
        raise

    finally:
        # Restore whatever sniffio reported before, on success and failure alike.
        sniffio_loop.name = prev_loop

    # We can't check iscoroutinefunction(async_fn), because that will fail
    # for things like functools.partial objects wrapping an async
    # function. So we have to just call it and then check whether the
    # return value is a coroutine object.
    # Note: will not be necessary on python>=3.8, see https://bugs.python.org/issue34890
    # TODO: python3.7 support is now dropped, so the above can be addressed.
    # NOTE(review): calling first also supports sync functions that return a
    # coroutine (see the sniffio comment above), which an iscoroutinefunction()
    # pre-check would reject — so the TODO may not be actionable as written.
    if not isinstance(coro, collections.abc.Coroutine):
        # Give good error for: nursery.start_soon(func_returning_future)
        if _return_value_looks_like_wrong_library(coro):
            raise TypeError(
                f"Trio got unexpected {coro!r} – are you trying to use a "
                "library written for asyncio/twisted/tornado or similar? "
                "That won't work without some sort of compatibility shim."
            )

        if inspect.isasyncgen(coro):
            raise TypeError(
                "start_soon expected an async function but got an async "
                f"generator {coro!r}"
            )

        # Give good error for: nursery.start_soon(some_sync_fn)
        raise TypeError(
            "Trio expected an async function, but {!r} appears to be "
            "synchronous".format(getattr(async_fn, "__qualname__", async_fn))
        )

    return coro
|
|||
|
|
|||
|
|
|||
|
class ConflictDetector:
    """Guard against two tasks running mutually exclusive code at once.

    Enter it as a plain (synchronous) context manager around code that would
    need a lock if it ever ran concurrently, but which callers are never
    supposed to run concurrently. If a second task enters while the detector
    is already held, it gets ``trio.BusyResourceError`` carrying the message
    given at construction.

    Trio uses this, for example, to make sure two tasks don't call
    ``sendall`` on the same stream simultaneously.
    """

    def __init__(self, msg: str) -> None:
        # Message for the BusyResourceError raised on conflict.
        self._msg = msg
        # True while some task is inside the guarded section.
        self._held = False

    def __enter__(self) -> None:
        if not self._held:
            self._held = True
            return
        raise trio.BusyResourceError(self._msg)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Release unconditionally; returning None lets any exception propagate.
        self._held = False
|
|||
|
|
|||
|
|
|||
|
def async_wraps(
    cls: type[object],
    wrapped_cls: type[object],
    attr_name: str,
) -> Callable[[CallT], CallT]:
    """Return a decorator that presents an async wrapper as ``cls.attr_name``.

    This is the async-wrapper counterpart of ``functools.wraps``. The decorated
    function is renamed to *attr_name* and given the qualified name
    ``<cls.__qualname__>.<attr_name>``. Its docstring becomes a one-line
    pointer to the synchronous ``wrapped_cls.attr_name`` it mirrors. The same
    function object is returned.
    """

    def decorator(fn: CallT) -> CallT:
        fn.__name__ = attr_name
        fn.__qualname__ = f"{cls.__qualname__}.{attr_name}"
        fn.__doc__ = f"Like :meth:`~{wrapped_cls.__module__}.{wrapped_cls.__qualname__}.{attr_name}`, but async."
        return fn

    return decorator
|
|||
|
|
|||
|
|
|||
|
def fixup_module_metadata(
    module_name: str,
    namespace: collections.abc.Mapping[str, object],
) -> None:
    """Make public objects report *module_name* as their home module.

    For every public (non-underscore) entry in *namespace* whose
    ``__module__`` starts with ``"trio."``:

    - ``__module__`` is set to *module_name*;
    - ``__name__`` is reset to the exported name, unless the current name
      contains a ``"."`` (i.e. the object is a module);
    - ``__qualname__``, when present and the name was reset, becomes the
      dotted path from the export name.

    Classes are processed recursively, so their attributes get qualified
    names such as ``"Outer.Inner.method"``. Each object is visited at most
    once.

    Args:
        module_name: the module name to report, e.g. ``"trio"``.
        namespace: mapping from exported name to object.
    """
    seen_ids: set[int] = set()

    def fix_one(qualname: str, name: str, obj: object) -> None:
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and mod.startswith("trio."):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                if hasattr(obj, "__qualname__"):
                    obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # Extend *this* object's qualname rather than the
                    # top-level export name, so members of nested classes get
                    # their full dotted path (e.g. "Outer.Inner.method").
                    fix_one(f"{qualname}.{attr_name}", attr_name, attr_value)

    for objname, obj in namespace.items():
        if not objname.startswith("_"):  # ignore private attributes
            fix_one(objname, objname, obj)
|
|||
|
|
|||
|
|
|||
|
# A ParamSpec would let us type this "properly", but a class base must exist at
# runtime, which would mean importing typing_extensions at runtime. This class
# only matters at runtime (type checkers won't treat it as correct anyway), so
# a plain Generic[RetT] base is good enough.
class generic_function(Generic[RetT]):
    """Wrap a function so it can be subscripted with generic parameters.

    Given::

        @generic_function
        def open_memory_channel(max_buffer_size: int) -> Tuple[
            SendChannel[T], ReceiveChannel[T]
        ]: ...

    the expression ``open_memory_channel[bytes](5)`` becomes legal at
    runtime. Subscripting hands back the very same wrapper, so it behaves
    exactly like ``open_memory_channel(5)``. Type-checking such calls still
    needs a mypy plugin or clever stubs, but they can at least be written.
    """

    def __init__(self, fn: Callable[..., RetT]) -> None:
        # Copy fn's metadata onto the wrapper first, then keep fn for __call__.
        update_wrapper(self, fn)
        self._fn = fn

    def __call__(self, *args: Any, **kwargs: Any) -> RetT:
        target = self._fn
        return target(*args, **kwargs)

    def __getitem__(self, subscript: object) -> Self:
        # The subscript only informs static type checkers; at runtime it is ignored.
        return self
|
|||
|
|
|||
|
|
|||
|
def _init_final_cls(cls: type[object]) -> NoReturn:
    """Raise ``TypeError`` because a ``@final`` class is being subclassed.

    This is installed as the ``__init_subclass__`` classmethod of classes
    decorated with ``@final``. Python calls it bound to the *new subclass*,
    so *cls* is the class being defined, not the final one. The message
    therefore names the ``@final`` ancestor, which is the class that forbids
    subclassing.
    """
    final_cls: type[object] = cls
    # Skip cls itself: the hook is looked up on its bases.
    for base in cls.__mro__[1:]:
        hook = base.__dict__.get("__init_subclass__")
        if isinstance(hook, classmethod) and hook.__func__ is _init_final_cls:
            final_cls = base
            break
    raise TypeError(
        f"{final_cls.__module__}.{final_cls.__qualname__} does not support subclassing"
    )


def _final_impl(decorated: type[T]) -> type[T]:
    """Decorator that makes a class final: defining a subclass raises ``TypeError``.

    Usage::

        @final
        class SomeClass:
            pass

    Any later ``class Sub(SomeClass): ...`` fails at class-creation time. The
    class itself can still be instantiated normally.

    Raises
    ------
    - TypeError if a subclass is created
    """
    # Override the method blindly. We're always going to raise, so it doesn't
    # matter what the original did (if anything).
    decorated.__init_subclass__ = classmethod(_init_final_cls)  # type: ignore[assignment]
    # Apply the typing decorator, in 3.11+ it adds a __final__ marker attribute.
    return std_final(decorated)


if TYPE_CHECKING:
    # Type checkers understand typing.final natively.
    from typing import final
else:
    # At runtime, use the enforcing implementation above.
    final = _final_impl
|
|||
|
|
|||
|
|
|||
|
@final  # No subclassing of NoPublicConstructor itself.
class NoPublicConstructor(ABCMeta):
    """Metaclass that forbids creating instances by calling the class.

    Intended usage::

        @final
        class SomeClass(metaclass=NoPublicConstructor):
            pass

    Calling ``SomeClass()`` raises ``TypeError``. Code inside the
    implementation builds instances with ``SomeClass._create(...)`` instead.
    Always combine this metaclass with ``@final``.

    Raises
    ------
    - TypeError if an instance is created.
    """

    def __call__(cls, *args: object, **kwargs: object) -> None:
        qualified = f"{cls.__module__}.{cls.__qualname__}"
        raise TypeError(f"{qualified} has no public constructor")

    def _create(cls: type[T], *args: object, **kwargs: object) -> T:
        # Bypass the blocking __call__ above and defer to the parent
        # metaclass's __call__ for normal construction.
        instance = super().__call__(*args, **kwargs)  # type: ignore
        return instance
|
|||
|
|
|||
|
|
|||
|
def name_asyncgen(agen: AsyncGeneratorType[object, NoReturn]) -> str:
    """Describe *agen* as ``"<module>.<qualname>"`` of the function that made it.

    - The module part is the ``__name__`` global of the generator's frame.
      It falls back to ``"<co_filename>"`` when the frame is gone or lacks
      that global.
    - The qualname part is ``agen.__qualname__``, falling back to the code
      object's ``co_name``.
    - Objects without ``ag_code`` are described with ``repr()``.
    """
    if not hasattr(agen, "ag_code"):  # pragma: no cover
        return repr(agen)
    code = agen.ag_code

    try:
        module = agen.ag_frame.f_globals["__name__"]
    except (AttributeError, KeyError):
        # ag_frame is unavailable (e.g. None) or has no __name__ global.
        module = f"<{code.co_filename}>"

    try:
        qualname = agen.__qualname__
    except AttributeError:
        qualname = code.co_name

    return f"{module}.{qualname}"
|
|||
|
|
|||
|
|
|||
|
# work around a pyright error
# For static type checkers, `wraps` is declared with a simplified signature
# whose returned decorator gives back the decorated function's own type (Fn).
# At runtime the name is simply functools.wraps, re-exported from this module.
# NOTE(review): the specific pyright error being avoided isn't recorded here.
if TYPE_CHECKING:
    Fn = TypeVar("Fn", bound=Callable[..., object])

    def wraps(
        wrapped: Callable[..., object],
        assigned: Sequence[str] = ...,
        updated: Sequence[str] = ...,
    ) -> Callable[[Fn], Fn]: ...

else:
    from functools import wraps  # noqa: F401 # this is re-exported
|