93 lines
3.3 KiB
Python
93 lines
3.3 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import ast
|
||
|
import contextlib
|
||
|
import inspect
|
||
|
import sys
|
||
|
import types
|
||
|
import warnings
|
||
|
from code import InteractiveConsole
|
||
|
|
||
|
import outcome
|
||
|
|
||
|
import trio
|
||
|
import trio.lowlevel
|
||
|
from trio._util import final
|
||
|
|
||
|
|
||
|
@final
class TrioInteractiveConsole(InteractiveConsole):
    """An ``InteractiveConsole`` whose input is executed through Trio.

    The console's compiler is switched to accept top-level ``await``, and
    every compiled input is handed over to Trio via ``trio.from_thread``
    instead of being ``exec``-ed directly in the calling thread.
    """

    # code.InteractiveInterpreter declares ``locals`` as Mapping[str, Any],
    # but types.FunctionType wants a real dict for its globals argument, so
    # the annotation is narrowed on this subclass.
    locals: dict[str, object]

    def __init__(self, repl_locals: dict[str, object] | None = None):
        """Create the console and enable top-level ``await`` in its compiler.

        Args:
            repl_locals: Namespace the entered code runs in; forwarded to
                ``InteractiveConsole`` as ``locals``.
        """
        super().__init__(locals=repl_locals)
        compiler = self.compile.compiler
        compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT

    def runcode(self, code: types.CodeType) -> None:
        """Execute one compiled input via Trio and report any failure.

        The code object is wrapped in a function whose globals are
        ``self.locals``. Coroutine functions (input containing a top-level
        ``await``) go through ``trio.from_thread.run``; everything else goes
        through ``trio.from_thread.run_sync``. The outcome is captured rather
        than raised so the error can be inspected here.

        Args:
            code: Code object produced by the console's compiler.

        Raises:
            SystemExit: Re-raised when the executed code raised it directly,
                which ends the REPL.
        """
        snippet = types.FunctionType(code, self.locals)
        if inspect.iscoroutinefunction(snippet):
            captured = trio.from_thread.run(outcome.acapture, snippet)
        else:
            captured = trio.from_thread.run_sync(outcome.capture, snippet)

        # Successful execution: nothing to report.
        if not isinstance(captured, outcome.Error):
            return

        exc = captured.error

        # A bare SystemExit quits the REPL. A SystemExit buried inside a
        # BaseExceptionGroup is most likely an error in the user's code
        # rather than an attempt to quit, so groups are deliberately not
        # searched; they fall through and are printed like any other error.
        if isinstance(exc, SystemExit):
            raise exc

        # Inlined replacement for self.showtraceback() that works from
        # outcome.Error.error directly, which gives clean tracebacks.
        # Consequently, overriding self.showtraceback has no effect.
        exc_type = type(exc)
        tb = exc.__traceback__
        sys.last_type = exc_type
        sys.last_value = exc
        sys.last_traceback = tb
        # see https://docs.python.org/3/library/sys.html#sys.last_exc
        if sys.version_info >= (3, 12):
            sys.last_exc = exc

        # sys.excepthook is always used here, unlike other implementations,
        # so overriding self.write has no effect on tracebacks either.
        sys.excepthook(exc_type, exc, tb)
|
||
|
|
||
|
|
||
|
async def run_repl(console: TrioInteractiveConsole) -> None:
    """Run ``console.interact`` in a worker thread until the session ends.

    The banner announces the interpreter version and platform, reminds the
    user that ``await`` can be used directly, and ends with a fake prompt
    line showing ``import trio``.

    Args:
        console: The console whose ``interact`` loop is run.
    """
    # Use the interpreter's primary prompt if one is set, else the default.
    prompt = getattr(sys, "ps1", ">>> ")
    banner = "".join(
        (
            f"trio REPL {sys.version} on {sys.platform}\n",
            'Use "await" directly instead of "trio.run()".\n',
            'Type "help", "copyright", "credits" or "license" for more information.\n',
            f"{prompt}import trio",
        )
    )
    try:
        await trio.to_thread.run_sync(console.interact, banner)
    finally:
        # Runs whether interact() returned or raised: from here on, ignore
        # RuntimeWarnings whose message matches "coroutine ... was never
        # awaited".
        # NOTE(review): presumably this silences warnings for coroutines the
        # user created in the session but never awaited — confirm.
        warnings.filterwarnings(
            action="ignore",
            message=r"^coroutine .* was never awaited$",
            category=RuntimeWarning,
        )
|
||
|
|
||
|
|
||
|
def main(original_locals: dict[str, object]) -> None:
    """Build the REPL namespace and run the Trio REPL until it exits.

    Args:
        original_locals: Namespace to copy the module-identity dunder entries
            from. Every one of the copied keys must be present.

    Raises:
        KeyError: If ``original_locals`` lacks one of the copied keys.
    """
    # Imported only for its side effect (hence the unused-import noqa);
    # a missing readline module is tolerated.
    with contextlib.suppress(ImportError):
        import readline  # noqa: F401

    inherited_keys = (
        "__name__",
        "__package__",
        "__loader__",
        "__spec__",
        "__builtins__",
        "__file__",
    )

    # The session starts with ``trio`` already bound, matching the
    # "import trio" line shown at the end of the banner.
    repl_locals: dict[str, object] = {"trio": trio}
    repl_locals.update((key, original_locals[key]) for key in inherited_keys)

    trio.run(run_repl, TrioInteractiveConsole(repl_locals))
|