Providing the session with DI

In our original script, we had to repeat the logic that constructs a session instance in every route handler. This is not very DRY.

In this section, we’ll use dependency injection to centralize the session creation logic and make it available to all handlers.

  1from contextlib import asynccontextmanager
  2from typing import Any, AsyncGenerator, Dict, List, Optional
  3
  4from sqlalchemy import select
  5from sqlalchemy.exc import IntegrityError, NoResultFound
  6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  8
  9from litestar import Litestar, get, post, put
 10from litestar.datastructures import State
 11from litestar.di import NamedDependency
 12from litestar.exceptions import ClientException, NotFoundException
 13from litestar.params import FromPath, FromQuery
 14from litestar.status_codes import HTTP_409_CONFLICT
 15
# Shape of a single TODO as exchanged with clients: a plain mapping such as
# {"title": ..., "done": ...} (see serialize_todo).
TodoType = Dict[str, Any]
# Response shape for endpoints that return several TODOs.
TodoCollectionType = List[TodoType]


# Declarative base; models deriving from it are registered on Base.metadata,
# which db_connection uses to create the tables at startup.
class Base(DeclarativeBase): ...


class TodoItem(Base):
    """ORM model for one TODO entry, stored in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so every title must be unique.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Non-Optional ``Mapped`` annotation: the column is declared NOT NULL.
    done: Mapped[bool]
 28
 29
 30@asynccontextmanager
 31async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 32    engine = getattr(app.state, "engine", None)
 33    if engine is None:
 34        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
 35        app.state.engine = engine
 36
 37    async with engine.begin() as conn:
 38        await conn.run_sync(Base.metadata.create_all)
 39
 40    try:
 41        yield
 42    finally:
 43        await engine.dispose()
 44
 45
# Session factory with no engine bound yet; provide_transaction binds it to the
# engine on application state for every session it opens. expire_on_commit=False
# keeps loaded attributes readable after the transaction commits.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` with an open transaction to its consumer.

    The session is bound to ``state.engine`` (stored there by db_connection).
    The transaction commits when the ``session.begin()`` block exits normally
    and rolls back if it exits with an exception.

    Args:
        state: Application state carrying the ``engine`` attribute.

    Yields:
        The session, already inside a transaction.

    Raises:
        ClientException: With status 409 when an ``IntegrityError`` is raised
            inside the transaction, including while committing on exit.
    """
    async with sessionmaker(bind=state.engine) as session:
        # The try wraps session.begin() (not just the yield) so that integrity
        # errors raised by the commit on block exit are translated as well.
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # NOTE(review): str(exc) includes SQLAlchemy's rendering of the
            # failing SQL statement and its parameters, which ends up in the
            # client-facing error detail; consider a less revealing message.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(exc),
            ) from exc
 59
 60
 61def serialize_todo(todo: TodoItem) -> TodoType:
 62    return {"title": todo.title, "done": todo.done}
 63
 64
 65async def get_todo_by_title(todo_name: FromPath[str], session: AsyncSession) -> TodoItem:
 66    query = select(TodoItem).where(TodoItem.title == todo_name)
 67    result = await session.execute(query)
 68    try:
 69        return result.scalar_one()
 70    except NoResultFound as e:
 71        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 72
 73
 74async def get_todo_list(done: FromQuery[Optional[bool]], session: AsyncSession) -> List[TodoItem]:
 75    query = select(TodoItem)
 76    if done is not None:
 77        query = query.where(TodoItem.done.is_(done))
 78
 79    result = await session.execute(query)
 80    return result.scalars().all()
 81
 82
 83@get("/")
 84async def get_list(
 85    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
 86) -> TodoCollectionType:
 87    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 88
 89
 90@post("/")
 91async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
 92    new_todo = TodoItem(title=data["title"], done=data["done"])
 93    transaction.add(new_todo)
 94    return serialize_todo(new_todo)
 95
 96
 97@put("/{item_title:str}")
 98async def update_item(
 99    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
100) -> TodoType:
101    todo_item = await get_todo_by_title(item_title, transaction)
102    todo_item.title = data["title"]
103    todo_item.done = data["done"]
104    return serialize_todo(todo_item)
105
106
# Handlers receive the session yielded by provide_transaction through a
# parameter named after the "transaction" dependency key; db_connection manages
# the engine for the lifetime of the application.
app = Litestar(
    [get_list, add_item, update_item],
    lifespan=[db_connection],
    dependencies={"transaction": provide_transaction},
)
  1from contextlib import asynccontextmanager
  2from typing import Any, Optional
  3from collections.abc import AsyncGenerator
  4
  5from sqlalchemy import select
  6from sqlalchemy.exc import IntegrityError, NoResultFound
  7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  9
 10from litestar import Litestar, get, post, put
 11from litestar.datastructures import State
 12from litestar.di import NamedDependency
 13from litestar.exceptions import ClientException, NotFoundException
 14from litestar.params import FromPath, FromQuery
 15from litestar.status_codes import HTTP_409_CONFLICT
 16
# Shape of a single TODO as exchanged with clients: a plain mapping such as
# {"title": ..., "done": ...} (see serialize_todo).
TodoType = dict[str, Any]
# Response shape for endpoints that return several TODOs.
TodoCollectionType = list[TodoType]


# Declarative base; models deriving from it are registered on Base.metadata,
# which db_connection uses to create the tables at startup.
class Base(DeclarativeBase): ...


class TodoItem(Base):
    """ORM model for one TODO entry, stored in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so every title must be unique.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Non-Optional ``Mapped`` annotation: the column is declared NOT NULL.
    done: Mapped[bool]
 29
 30
 31@asynccontextmanager
 32async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 33    engine = getattr(app.state, "engine", None)
 34    if engine is None:
 35        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
 36        app.state.engine = engine
 37
 38    async with engine.begin() as conn:
 39        await conn.run_sync(Base.metadata.create_all)
 40
 41    try:
 42        yield
 43    finally:
 44        await engine.dispose()
 45
 46
# Session factory with no engine bound yet; provide_transaction binds it to the
# engine on application state for every session it opens. expire_on_commit=False
# keeps loaded attributes readable after the transaction commits.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` with an open transaction to its consumer.

    The session is bound to ``state.engine`` (stored there by db_connection).
    The transaction commits when the ``session.begin()`` block exits normally
    and rolls back if it exits with an exception.

    Args:
        state: Application state carrying the ``engine`` attribute.

    Yields:
        The session, already inside a transaction.

    Raises:
        ClientException: With status 409 when an ``IntegrityError`` is raised
            inside the transaction, including while committing on exit.
    """
    async with sessionmaker(bind=state.engine) as session:
        # The try wraps session.begin() (not just the yield) so that integrity
        # errors raised by the commit on block exit are translated as well.
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # NOTE(review): str(exc) includes SQLAlchemy's rendering of the
            # failing SQL statement and its parameters, which ends up in the
            # client-facing error detail; consider a less revealing message.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(exc),
            ) from exc
 60
 61
 62def serialize_todo(todo: TodoItem) -> TodoType:
 63    return {"title": todo.title, "done": todo.done}
 64
 65
 66async def get_todo_by_title(todo_name: FromPath[str], session: AsyncSession) -> TodoItem:
 67    query = select(TodoItem).where(TodoItem.title == todo_name)
 68    result = await session.execute(query)
 69    try:
 70        return result.scalar_one()
 71    except NoResultFound as e:
 72        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 73
 74
 75async def get_todo_list(done: FromQuery[Optional[bool]], session: AsyncSession) -> list[TodoItem]:
 76    query = select(TodoItem)
 77    if done is not None:
 78        query = query.where(TodoItem.done.is_(done))
 79
 80    result = await session.execute(query)
 81    return result.scalars().all()
 82
 83
 84@get("/")
 85async def get_list(
 86    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
 87) -> TodoCollectionType:
 88    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 89
 90
 91@post("/")
 92async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
 93    new_todo = TodoItem(title=data["title"], done=data["done"])
 94    transaction.add(new_todo)
 95    return serialize_todo(new_todo)
 96
 97
 98@put("/{item_title:str}")
 99async def update_item(
100    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
101) -> TodoType:
102    todo_item = await get_todo_by_title(item_title, transaction)
103    todo_item.title = data["title"]
104    todo_item.done = data["done"]
105    return serialize_todo(todo_item)
106
107
# Handlers receive the session yielded by provide_transaction through a
# parameter named after the "transaction" dependency key; db_connection manages
# the engine for the lifetime of the application.
app = Litestar(
    [get_list, add_item, update_item],
    lifespan=[db_connection],
    dependencies={"transaction": provide_transaction},
)
  1from contextlib import asynccontextmanager
  2from typing import Any
  3from collections.abc import AsyncGenerator
  4
  5from sqlalchemy import select
  6from sqlalchemy.exc import IntegrityError, NoResultFound
  7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  9
 10from litestar import Litestar, get, post, put
 11from litestar.datastructures import State
 12from litestar.di import NamedDependency
 13from litestar.exceptions import ClientException, NotFoundException
 14from litestar.params import FromPath, FromQuery
 15from litestar.status_codes import HTTP_409_CONFLICT
 16
# Shape of a single TODO as exchanged with clients: a plain mapping such as
# {"title": ..., "done": ...} (see serialize_todo).
TodoType = dict[str, Any]
# Response shape for endpoints that return several TODOs.
TodoCollectionType = list[TodoType]


# Declarative base; models deriving from it are registered on Base.metadata,
# which db_connection uses to create the tables at startup.
class Base(DeclarativeBase): ...


class TodoItem(Base):
    """ORM model for one TODO entry, stored in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so every title must be unique.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Non-Optional ``Mapped`` annotation: the column is declared NOT NULL.
    done: Mapped[bool]
 29
 30
 31@asynccontextmanager
 32async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 33    engine = getattr(app.state, "engine", None)
 34    if engine is None:
 35        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
 36        app.state.engine = engine
 37
 38    async with engine.begin() as conn:
 39        await conn.run_sync(Base.metadata.create_all)
 40
 41    try:
 42        yield
 43    finally:
 44        await engine.dispose()
 45
 46
# Session factory with no engine bound yet; provide_transaction binds it to the
# engine on application state for every session it opens. expire_on_commit=False
# keeps loaded attributes readable after the transaction commits.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` with an open transaction to its consumer.

    The session is bound to ``state.engine`` (stored there by db_connection).
    The transaction commits when the ``session.begin()`` block exits normally
    and rolls back if it exits with an exception.

    Args:
        state: Application state carrying the ``engine`` attribute.

    Yields:
        The session, already inside a transaction.

    Raises:
        ClientException: With status 409 when an ``IntegrityError`` is raised
            inside the transaction, including while committing on exit.
    """
    async with sessionmaker(bind=state.engine) as session:
        # The try wraps session.begin() (not just the yield) so that integrity
        # errors raised by the commit on block exit are translated as well.
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # NOTE(review): str(exc) includes SQLAlchemy's rendering of the
            # failing SQL statement and its parameters, which ends up in the
            # client-facing error detail; consider a less revealing message.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(exc),
            ) from exc
 60
 61
 62def serialize_todo(todo: TodoItem) -> TodoType:
 63    return {"title": todo.title, "done": todo.done}
 64
 65
 66async def get_todo_by_title(todo_name: FromPath[str], session: AsyncSession) -> TodoItem:
 67    query = select(TodoItem).where(TodoItem.title == todo_name)
 68    result = await session.execute(query)
 69    try:
 70        return result.scalar_one()
 71    except NoResultFound as e:
 72        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 73
 74
 75async def get_todo_list(done: FromQuery[bool | None], session: AsyncSession) -> list[TodoItem]:
 76    query = select(TodoItem)
 77    if done is not None:
 78        query = query.where(TodoItem.done.is_(done))
 79
 80    result = await session.execute(query)
 81    return result.scalars().all()
 82
 83
 84@get("/")
 85async def get_list(
 86    transaction: NamedDependency[AsyncSession], done: FromQuery[bool | None] = None
 87) -> TodoCollectionType:
 88    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 89
 90
 91@post("/")
 92async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
 93    new_todo = TodoItem(title=data["title"], done=data["done"])
 94    transaction.add(new_todo)
 95    return serialize_todo(new_todo)
 96
 97
 98@put("/{item_title:str}")
 99async def update_item(
100    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
101) -> TodoType:
102    todo_item = await get_todo_by_title(item_title, transaction)
103    todo_item.title = data["title"]
104    todo_item.done = data["done"]
105    return serialize_todo(todo_item)
106
107
# Handlers receive the session yielded by provide_transaction through a
# parameter named after the "transaction" dependency key; db_connection manages
# the engine for the lifetime of the application.
app = Litestar(
    [get_list, add_item, update_item],
    lifespan=[db_connection],
    dependencies={"transaction": provide_transaction},
)

In the previous example, the database session is created within each HTTP route handler function. In this script we use dependency injection to decouple creation of the session from the route handlers.

This script introduces a new async generator function called provide_transaction() that creates a new SQLAlchemy session, begins a transaction, and handles any integrity errors that might be raised from within the transaction.

async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` with an open transaction to its consumer.

    The session is bound to ``state.engine``. The transaction commits when the
    ``session.begin()`` block exits normally and rolls back if it exits with
    an exception.

    Raises:
        ClientException: With status 409 when an ``IntegrityError`` is raised
            inside the transaction, including while committing on exit.
    """
    async with sessionmaker(bind=state.engine) as session:
        # The try wraps session.begin() (not just the yield) so that integrity
        # errors raised by the commit on block exit are translated as well.
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # NOTE(review): str(exc) includes SQLAlchemy's rendering of the
            # failing SQL statement and its parameters in the client-facing detail.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(exc),
            ) from exc

That function is declared as a dependency of the Litestar application under the name transaction, by passing dependencies={"transaction": provide_transaction} to the Litestar constructor.

1    todo_item.title = data["title"]
2    todo_item.done = data["done"]
3    return serialize_todo(todo_item)

In the route handlers, the database session is injected by adding a parameter named transaction and marking it with NamedDependency. The marker tells Litestar to supply the value from the transaction dependency at runtime, matching the parameter name against the dependency key.

1@get("/")
2async def get_list(
3    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None

One final improvement in this script is exception handling. In the previous version, a litestar.exceptions.ClientException is raised inside the add_item() handler if there’s an integrity error raised during the insertion of the new TODO item. In our latest revision, we’ve been able to centralize this handling to occur inside the provide_transaction() function.

async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` with an open transaction to its consumer.

    The session is bound to ``state.engine``. The transaction commits when the
    ``session.begin()`` block exits normally and rolls back if it exits with
    an exception.

    Raises:
        ClientException: With status 409 when an ``IntegrityError`` is raised
            inside the transaction, including while committing on exit.
    """
    async with sessionmaker(bind=state.engine) as session:
        # The try wraps session.begin() (not just the yield) so that integrity
        # errors raised by the commit on block exit are translated as well.
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # NOTE(review): str(exc) includes SQLAlchemy's rendering of the
            # failing SQL statement and its parameters in the client-facing detail.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(exc),
            ) from exc

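This change broadens the scope of exception handling to any operation that uses the database session, not just the insertion of new items. Because the try block encloses session.begin(), it also covers errors raised when the transaction is committed after the handler finishes, such as renaming an item to a title that already exists.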

Compare handlers before and after DI

Just for fun, let's compare the sets of application handlers before and after we introduced dependency injection for our session object:

 1@get("/")
 2async def get_list(
 3    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
 4) -> TodoCollectionType:
 5    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 6
 7
 8@post("/")
 9async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
10    new_todo = TodoItem(title=data["title"], done=data["done"])
11    transaction.add(new_todo)
12    return serialize_todo(new_todo)
13
14
@put("/{item_title:str}")
async def update_item(
    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
) -> TodoType:
    """Overwrite the title and done flag of the TODO item named ``item_title``."""
    existing = await get_todo_by_title(item_title, transaction)
    existing.title = data["title"]
    existing.done = data["done"]
    return serialize_todo(existing)
 1@get("/")
 2async def get_list(state: State, done: FromQuery[Optional[bool]] = None) -> TodoCollectionType:
 3    async with sessionmaker(bind=state.engine) as session:
 4        return [serialize_todo(todo) for todo in await get_todo_list(done, session)]
 5
 6
 7@post("/")
 8async def add_item(data: TodoType, state: State) -> TodoType:
 9    new_todo = TodoItem(title=data["title"], done=data["done"])
10    async with sessionmaker(bind=state.engine) as session:
11        try:
12            async with session.begin():
13                session.add(new_todo)
14        except IntegrityError as e:
15            raise ClientException(
16                status_code=HTTP_409_CONFLICT,
17                detail=f"TODO {new_todo.title!r} already exists",
18            ) from e
19
20    return serialize_todo(new_todo)
21
22
@put("/{item_title:str}")
async def update_item(item_title: FromPath[str], data: TodoType, state: State) -> TodoType:
    """Update a TODO item, managing the session and transaction by hand.

    Note that no ``IntegrityError`` handling exists here, so renaming onto an
    existing title is not translated into a client error.
    """
    async with sessionmaker(bind=state.engine) as session:
        async with session.begin():
            existing = await get_todo_by_title(item_title, session)
            existing.title = data["title"]
            existing.done = data["done"]
    return serialize_todo(existing)

Much better!

Next steps

One of the niceties that we’ve lost is the ability to receive and return data to/from our handlers as instances of our data model. In the original TODO application, we modelled our data with Python dataclasses, which Litestar natively supports for (de)serialization. In the next section, we will look at how we can get this functionality back!