webdriver_template/telecli/lib/python3.11/site-packages/trio/testing/_sequencer.py

87 lines
2.7 KiB
Python
Raw Normal View History

2024-08-10 14:48:21 +03:00
from __future__ import annotations
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
import attrs
from .. import Event, _core, _util
if TYPE_CHECKING:
from collections.abc import AsyncIterator
@_util.final
@attrs.define(eq=False, slots=False)
class Sequencer:
    """Force code running in separate tasks to execute in a fixed linear order.

    Calling an instance with an integer position returns an async context
    manager. The position says where that block belongs in the overall
    sequence: block 0 may enter immediately, and block N is held back until
    block N-1 has exited (whether it exited normally or by raising).

    Each position may be claimed only once. If a task is cancelled while
    waiting for its turn, the sequencer becomes broken: every pending waiter
    is woken and fails with :exc:`RuntimeError`, and so does any later
    attempt to enter a block.

    Example:
      An extremely elaborate way to print the numbers 0-5, in order::

         async def worker1(seq):
             async with seq(0):
                 print(0)
             async with seq(4):
                 print(4)

         async def worker2(seq):
             async with seq(2):
                 print(2)
             async with seq(5):
                 print(5)

         async def worker3(seq):
             async with seq(1):
                 print(1)
             async with seq(3):
                 print(3)

         async def main():
            seq = trio.testing.Sequencer()
            async with trio.open_nursery() as nursery:
                nursery.start_soon(worker1, seq)
                nursery.start_soon(worker2, seq)
                nursery.start_soon(worker3, seq)

    """

    # One Event per position; the Event for N is set when block N-1 exits.
    # Entries are created lazily on first lookup via the defaultdict.
    _sequence_points: defaultdict[int, Event] = attrs.field(
        factory=lambda: defaultdict(Event), init=False
    )
    # Positions that some caller has already entered (or is waiting on).
    _claimed: set[int] = attrs.field(factory=set, init=False)
    # Set once a waiter is cancelled; afterwards every entry attempt fails.
    _broken: bool = attrs.field(default=False, init=False)

    @asynccontextmanager
    async def __call__(self, position: int) -> AsyncIterator[None]:
        """Run the ``async with`` body as block number *position*.

        Raises:
          RuntimeError: if *position* was already claimed, if the sequence
            is (or becomes) broken, or if this task is cancelled while
            waiting for its turn.
        """
        if position in self._claimed:
            raise RuntimeError(f"Attempted to reuse sequence point {position}")
        if self._broken:
            raise RuntimeError("sequence broken!")
        self._claimed.add(position)

        # Block 0 has no predecessor, so it never waits.
        if position != 0:
            await self._wait_for_turn(position)

        try:
            yield
        finally:
            # Hand off to the next block even if this body raised.
            self._sequence_points[position + 1].set()

    async def _wait_for_turn(self, position: int) -> None:
        """Block until the predecessor of *position* has exited.

        A cancellation during the wait breaks the whole sequence: every
        sequence point created so far is released, so the other waiters
        wake up and fail instead of hanging forever.
        """
        try:
            await self._sequence_points[position].wait()
        except _core.Cancelled:
            self._broken = True
            for pending in self._sequence_points.values():
                pending.set()
            raise RuntimeError(
                "Sequencer wait cancelled -- sequence broken"
            ) from None
        # We may have been woken because another waiter broke the sequence,
        # not because our predecessor finished.
        if self._broken:
            raise RuntimeError("sequence broken!")