from __future__ import annotations
import asyncio
import contextvars
from functools import partial
from typing import TYPE_CHECKING, Callable, TypeVar
import sniffio
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from concurrent.futures import ThreadPoolExecutor
import trio
T = TypeVar("T")
P = ParamSpec("P")
__all__ = (
"get_asyncio_executor",
"get_trio_capacity_limiter",
"set_asyncio_executor",
"set_trio_capacity_limiter",
"sync_to_thread",
)
class _State:
EXECUTOR: ThreadPoolExecutor | None = None
LIMITER: trio.CapacityLimiter | None = None
async def _run_sync_asyncio(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
ctx = contextvars.copy_context()
bound_fn = partial(ctx.run, fn, *args, **kwargs)
return await asyncio.get_running_loop().run_in_executor(get_asyncio_executor(), bound_fn) # pyright: ignore
async def _run_sync_trio(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
import trio
return await trio.to_thread.run_sync(partial(fn, *args, **kwargs), limiter=get_trio_capacity_limiter())
[docs]
async def sync_to_thread(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""Run the synchronous callable ``fn`` asynchronously in a worker thread.
When called from asyncio, uses :meth:`asyncio.loop.run_in_executor` to
run the callable. No executor is specified by default so the current loop's executor
is used. A specific executor can be set using
:func:`~litestar.concurrency.set_asyncio_executor`. This does not affect the loop's
default executor.
When called from trio, uses :func:`trio.to_thread.run_sync` to run the callable. No
capacity limiter is specified by default, but one can be set using
:func:`~litestar.concurrency.set_trio_capacity_limiter`. This does not affect trio's
default capacity limiter.
"""
if (library := sniffio.current_async_library()) == "asyncio":
return await _run_sync_asyncio(fn, *args, **kwargs)
if library == "trio":
return await _run_sync_trio(fn, *args, **kwargs)
raise RuntimeError("Unsupported async library or not in async context")
[docs]
def set_asyncio_executor(executor: ThreadPoolExecutor | None) -> None:
"""Set the executor in which synchronous callables will be run within an asyncio
context
"""
try:
sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass
else:
raise RuntimeError("Cannot set executor from running loop")
_State.EXECUTOR = executor
[docs]
def get_asyncio_executor() -> ThreadPoolExecutor | None:
"""Get the executor in which synchronous callables will be run within an asyncio
context
"""
return _State.EXECUTOR
[docs]
def set_trio_capacity_limiter(limiter: trio.CapacityLimiter | None) -> None:
"""Set the capacity limiter used when running synchronous callable within a trio
context
"""
try:
sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass
else:
raise RuntimeError("Cannot set limiter while in async context")
_State.LIMITER = limiter
[docs]
def get_trio_capacity_limiter() -> trio.CapacityLimiter | None:
"""Get the capacity limiter used when running synchronous callable within a trio
context
"""
return _State.LIMITER