312 lines
10 KiB
Python
312 lines
10 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import warnings
|
||
|
from collections import OrderedDict, deque
|
||
|
from dataclasses import dataclass, field
|
||
|
from types import TracebackType
|
||
|
from typing import Generic, NamedTuple, TypeVar
|
||
|
|
||
|
from .. import (
|
||
|
BrokenResourceError,
|
||
|
ClosedResourceError,
|
||
|
EndOfStream,
|
||
|
WouldBlock,
|
||
|
)
|
||
|
from .._core._testing import TaskInfo, get_current_task
|
||
|
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
|
||
|
from ..lowlevel import checkpoint
|
||
|
|
||
|
T_Item = TypeVar("T_Item")
|
||
|
T_co = TypeVar("T_co", covariant=True)
|
||
|
T_contra = TypeVar("T_contra", contravariant=True)
|
||
|
|
||
|
|
||
|
class MemoryObjectStreamStatistics(NamedTuple):
    """Counters describing the state of a memory object stream."""

    #: number of items stored in the buffer
    current_buffer_used: int
    #: maximum number of items that can be stored on this stream (or :data:`math.inf`)
    max_buffer_size: float
    #: number of unclosed clones of the send stream
    open_send_streams: int
    #: number of unclosed clones of the receive stream
    open_receive_streams: int
    #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
    tasks_waiting_send: int
    #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
    tasks_waiting_receive: int
|
||
|
|
||
|
|
||
|
@dataclass(eq=False)
class MemoryObjectItemReceiver(Generic[T_Item]):
    """
    Holder through which a sender hands an item to one waiting receiver.

    ``item`` is excluded from ``__init__`` and has no default, so the attribute does
    not exist on the instance until something assigns it.
    """

    #: filled in at construction time by calling ``get_current_task()``
    task_info: TaskInfo = field(init=False, default_factory=get_current_task)
    #: the delivered item; the attribute is absent until it has been assigned
    item: T_Item = field(init=False)

    def __repr__(self) -> str:
        # The dataclass-generated __repr__ reads ``self.item`` unconditionally and
        # would raise AttributeError for a receiver that has not been given an item
        # yet, so fall back to None when the attribute is missing.
        item = getattr(self, "item", None)
        return f"{self.__class__.__name__}(task_info={self.task_info}, item={item!r})"
|
||
|
|
||
|
|
||
|
@dataclass(eq=False)
class MemoryObjectStreamState(Generic[T_Item]):
    """Bookkeeping shared by the send and receive streams that reference it."""

    #: upper bound for ``len(buffer)``; the only field accepted by ``__init__``
    max_buffer_size: float = field()
    #: items that have been sent but not yet received
    buffer: deque[T_Item] = field(init=False, default_factory=deque)
    #: number of send streams that have not been closed
    open_send_channels: int = field(init=False, default=0)
    #: number of receive streams that have not been closed
    open_receive_channels: int = field(init=False, default=0)
    #: blocked receivers in arrival order, keyed by the event that wakes them
    waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field(
        init=False, default_factory=OrderedDict
    )
    #: items of blocked senders in arrival order, keyed by the event that wakes them
    waiting_senders: OrderedDict[Event, T_Item] = field(
        init=False, default_factory=OrderedDict
    )

    def statistics(self) -> MemoryObjectStreamStatistics:
        """Build a statistics tuple from the current counters and queue lengths."""
        return MemoryObjectStreamStatistics(
            current_buffer_used=len(self.buffer),
            max_buffer_size=self.max_buffer_size,
            open_send_streams=self.open_send_channels,
            open_receive_streams=self.open_receive_channels,
            tasks_waiting_send=len(self.waiting_senders),
            tasks_waiting_receive=len(self.waiting_receivers),
        )
|
||
|
|
||
|
|
||
|
@dataclass(eq=False)
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
    """
    Receiving end of a memory object stream.

    All clones share the same ``_state`` object; each one keeps its own ``_closed``
    flag.
    """

    _state: MemoryObjectStreamState[T_co]
    _closed: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Register this instance as an open receiving end on the shared state.
        self._state.open_receive_channels += 1

    def receive_nowait(self) -> T_co:
        """
        Receive the next item if it can be done without waiting.

        :return: the received item
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed
        :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
            closed from the sending end
        :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
            waiting to send

        """
        if self._closed:
            raise ClosedResourceError

        if self._state.waiting_senders:
            # Get the item from the next sender. It is appended to the END of the
            # buffer so that items already buffered are still delivered first.
            send_event, item = self._state.waiting_senders.popitem(last=False)
            self._state.buffer.append(item)
            send_event.set()

        if self._state.buffer:
            return self._state.buffer.popleft()
        elif not self._state.open_send_channels:
            raise EndOfStream

        raise WouldBlock

    async def receive(self) -> T_co:
        """
        Receive the next item, waiting for one if none is available right away.

        :return: the received item
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed
        :raises ~anyio.EndOfStream: if no item is available and the stream has been
            closed from the sending end, either before the call or while waiting

        """
        await checkpoint()
        try:
            return self.receive_nowait()
        except WouldBlock:
            # Add ourselves in the queue
            receive_event = Event()
            receiver = MemoryObjectItemReceiver[T_co]()
            self._state.waiting_receivers[receive_event] = receiver

            try:
                await receive_event.wait()
            finally:
                # Runs on normal wake-up and when the wait is interrupted by an
                # exception, so no stale entry is left in the queue.
                self._state.waiting_receivers.pop(receive_event, None)

            try:
                return receiver.item
            except AttributeError:
                # Woken up without an item having been assigned: the last send
                # stream was closed. The AttributeError is an implementation detail,
                # so keep it out of the reported exception chain.
                raise EndOfStream from None

    def clone(self) -> MemoryObjectReceiveStream[T_co]:
        """
        Create a clone of this receive stream.

        Each clone can be closed separately. Only when all clones have been closed will
        the receiving end of the memory stream be considered closed by the sending ends.

        :return: the cloned stream
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed

        """
        if self._closed:
            raise ClosedResourceError

        return MemoryObjectReceiveStream(_state=self._state)

    def close(self) -> None:
        """
        Close the stream.

        This works the exact same way as :meth:`aclose`, but is provided as a special
        case for the benefit of synchronous callbacks.

        Calling this on an already closed stream does nothing.

        """
        if not self._closed:
            self._closed = True
            self._state.open_receive_channels -= 1
            if self._state.open_receive_channels == 0:
                # Wake up every blocked sender. Their entries are deliberately left
                # in ``waiting_senders``: a sender that wakes up and still finds its
                # own entry there raises BrokenResourceError.
                send_events = list(self._state.waiting_senders.keys())
                for event in send_events:
                    event.set()

    async def aclose(self) -> None:
        """Close the stream; delegates to :meth:`close`."""
        self.close()

    def statistics(self) -> MemoryObjectStreamStatistics:
        """
        Return statistics about the current state of this stream.

        .. versionadded:: 3.0
        """
        return self._state.statistics()

    def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returns None, so an exception raised in the ``with`` body propagates.
        self.close()

    def __del__(self) -> None:
        # Only warn; the stream is not closed on behalf of the caller here.
        if not self._closed:
            warnings.warn(
                f"Unclosed <{self.__class__.__name__}>",
                ResourceWarning,
                source=self,
            )
|
||
|
|
||
|
|
||
|
@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
    """
    Sending end of a memory object stream.

    All clones share the same ``_state`` object; each one keeps its own ``_closed``
    flag.
    """

    _state: MemoryObjectStreamState[T_contra]
    _closed: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Register this instance as an open sending end on the shared state.
        self._state.open_send_channels += 1

    def send_nowait(self, item: T_contra) -> None:
        """
        Deliver an item right away, or fail if that would require waiting.

        The item goes to the first queued receiver whose task has no pending
        cancellation; if there is none, it is stored in the buffer.

        :param item: the item to send
        :raises ~anyio.ClosedResourceError: if this send stream has been closed
        :raises ~anyio.BrokenResourceError: if the stream has been closed from the
            receiving end
        :raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
            to receive

        """
        if self._closed:
            raise ClosedResourceError

        state = self._state
        if not state.open_receive_channels:
            raise BrokenResourceError

        receivers = state.waiting_receivers
        while receivers:
            wakeup, receiver = receivers.popitem(last=False)
            if receiver.task_info.has_pending_cancellation():
                # This receiver has already been dequeued; try the next one.
                continue

            receiver.item = item
            wakeup.set()
            return

        if len(state.buffer) < state.max_buffer_size:
            state.buffer.append(item)
            return

        raise WouldBlock

    async def send(self, item: T_contra) -> None:
        """
        Deliver an item, waiting if it cannot be delivered right away.

        When the buffer is full, the call blocks until a receiver takes the item or
        the receiving end is closed.

        :param item: the item to send
        :raises ~anyio.ClosedResourceError: if this send stream has been closed
        :raises ~anyio.BrokenResourceError: if the stream has been closed from the
            receiving end

        """
        await checkpoint()
        try:
            self.send_nowait(item)
        except WouldBlock:
            # Park the item in the sender queue until a receiver picks it up.
            blocked = Event()
            senders = self._state.waiting_senders
            senders[blocked] = item
            try:
                await blocked.wait()
            except BaseException:
                # Interrupted while waiting: withdraw the item, then re-raise.
                senders.pop(blocked, None)
                raise

            if blocked in senders:
                # Woken up, but the item is still queued: nobody took it, which
                # happens when the last receive stream is closed.
                del senders[blocked]
                raise BrokenResourceError from None

    def clone(self) -> MemoryObjectSendStream[T_contra]:
        """
        Return a new send stream that shares this stream's state.

        Clones are closed independently of each other. The receiving ends consider the
        sending end closed only after every clone has been closed.

        :return: the cloned stream
        :raises ~anyio.ClosedResourceError: if this send stream has been closed

        """
        if self._closed:
            raise ClosedResourceError

        return MemoryObjectSendStream(_state=self._state)

    def close(self) -> None:
        """
        Close this send stream.

        Behaves exactly like :meth:`aclose`; it exists so that synchronous callbacks
        can close the stream too. Closing an already closed stream does nothing.

        """
        if self._closed:
            return

        self._closed = True
        state = self._state
        state.open_send_channels -= 1
        if state.open_send_channels != 0:
            return

        # That was the last sending end: empty the receiver queue, then wake every
        # receiver that was in it. They find no item and report the end of the stream.
        pending = list(state.waiting_receivers)
        state.waiting_receivers.clear()
        for wakeup in pending:
            wakeup.set()

    async def aclose(self) -> None:
        """Close the stream; delegates to :meth:`close`."""
        self.close()

    def statistics(self) -> MemoryObjectStreamStatistics:
        """
        Report the statistics of the state shared with the other stream ends.

        .. versionadded:: 3.0
        """
        return self._state.statistics()

    def __enter__(self) -> MemoryObjectSendStream[T_contra]:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returns None, so an exception raised in the ``with`` body propagates.
        self.close()

    def __del__(self) -> None:
        # Only warn; the stream is not closed on behalf of the caller here.
        if self._closed:
            return

        warnings.warn(
            f"Unclosed <{self.__class__.__name__}>",
            ResourceWarning,
            source=self,
        )
|