Providing the session with DI

In our original script, we had to repeat the logic to construct a session instance in every route handler. This is not very DRY ("Don't Repeat Yourself").

In this section, we’ll use dependency injection to centralize the session creation logic and make it available to all handlers.

  1from contextlib import asynccontextmanager
  2from typing import Any, AsyncGenerator, Dict, List, Optional
  3
  4from sqlalchemy import select
  5from sqlalchemy.exc import IntegrityError, NoResultFound
  6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  8
  9from litestar import Litestar, get, post, put
 10from litestar.datastructures import State
 11from litestar.exceptions import ClientException, NotFoundException
 12from litestar.status_codes import HTTP_409_CONFLICT
 13
 14TodoType = Dict[str, Any]
 15TodoCollectionType = List[TodoType]
 16
 17
 18class Base(DeclarativeBase): ...
 19
 20
 21class TodoItem(Base):
 22    __tablename__ = "todo_items"
 23
 24    title: Mapped[str] = mapped_column(primary_key=True)
 25    done: Mapped[bool]
 26
 27
 28@asynccontextmanager
 29async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 30    engine = getattr(app.state, "engine", None)
 31    if engine is None:
 32        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
 33        app.state.engine = engine
 34
 35    async with engine.begin() as conn:
 36        await conn.run_sync(Base.metadata.create_all)
 37
 38    try:
 39        yield
 40    finally:
 41        await engine.dispose()
 42
 43
 44sessionmaker = async_sessionmaker(expire_on_commit=False)
 45
 46
 47async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
 48    async with sessionmaker(bind=state.engine) as session:
 49        try:
 50            async with session.begin():
 51                yield session
 52        except IntegrityError as exc:
 53            raise ClientException(
 54                status_code=HTTP_409_CONFLICT,
 55                detail=str(exc),
 56            ) from exc
 57
 58
 59def serialize_todo(todo: TodoItem) -> TodoType:
 60    return {"title": todo.title, "done": todo.done}
 61
 62
 63async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
 64    query = select(TodoItem).where(TodoItem.title == todo_name)
 65    result = await session.execute(query)
 66    try:
 67        return result.scalar_one()
 68    except NoResultFound as e:
 69        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 70
 71
 72async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
 73    query = select(TodoItem)
 74    if done is not None:
 75        query = query.where(TodoItem.done.is_(done))
 76
 77    result = await session.execute(query)
 78    return result.scalars().all()
 79
 80
 81@get("/")
 82async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
 83    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 84
 85
 86@post("/")
 87async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
 88    new_todo = TodoItem(title=data["title"], done=data["done"])
 89    transaction.add(new_todo)
 90    return serialize_todo(new_todo)
 91
 92
 93@put("/{item_title:str}")
 94async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
 95    todo_item = await get_todo_by_title(item_title, transaction)
 96    todo_item.title = data["title"]
 97    todo_item.done = data["done"]
 98    return serialize_todo(todo_item)
 99
100
101app = Litestar(
102    [get_list, add_item, update_item],
103    dependencies={"transaction": provide_transaction},
104    lifespan=[db_connection],
105)
  1from contextlib import asynccontextmanager
  2from typing import Any, Dict, List, Optional
  3from collections.abc import AsyncGenerator
  4
  5from sqlalchemy import select
  6from sqlalchemy.exc import IntegrityError, NoResultFound
  7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  9
 10from litestar import Litestar, get, post, put
 11from litestar.datastructures import State
 12from litestar.exceptions import ClientException, NotFoundException
 13from litestar.status_codes import HTTP_409_CONFLICT
 14
 15TodoType = Dict[str, Any]
 16TodoCollectionType = List[TodoType]
 17
 18
 19class Base(DeclarativeBase): ...
 20
 21
 22class TodoItem(Base):
 23    __tablename__ = "todo_items"
 24
 25    title: Mapped[str] = mapped_column(primary_key=True)
 26    done: Mapped[bool]
 27
 28
 29@asynccontextmanager
 30async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 31    engine = getattr(app.state, "engine", None)
 32    if engine is None:
 33        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
 34        app.state.engine = engine
 35
 36    async with engine.begin() as conn:
 37        await conn.run_sync(Base.metadata.create_all)
 38
 39    try:
 40        yield
 41    finally:
 42        await engine.dispose()
 43
 44
 45sessionmaker = async_sessionmaker(expire_on_commit=False)
 46
 47
 48async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
 49    async with sessionmaker(bind=state.engine) as session:
 50        try:
 51            async with session.begin():
 52                yield session
 53        except IntegrityError as exc:
 54            raise ClientException(
 55                status_code=HTTP_409_CONFLICT,
 56                detail=str(exc),
 57            ) from exc
 58
 59
 60def serialize_todo(todo: TodoItem) -> TodoType:
 61    return {"title": todo.title, "done": todo.done}
 62
 63
 64async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
 65    query = select(TodoItem).where(TodoItem.title == todo_name)
 66    result = await session.execute(query)
 67    try:
 68        return result.scalar_one()
 69    except NoResultFound as e:
 70        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 71
 72
 73async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
 74    query = select(TodoItem)
 75    if done is not None:
 76        query = query.where(TodoItem.done.is_(done))
 77
 78    result = await session.execute(query)
 79    return result.scalars().all()
 80
 81
 82@get("/")
 83async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
 84    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 85
 86
 87@post("/")
 88async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
 89    new_todo = TodoItem(title=data["title"], done=data["done"])
 90    transaction.add(new_todo)
 91    return serialize_todo(new_todo)
 92
 93
 94@put("/{item_title:str}")
 95async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
 96    todo_item = await get_todo_by_title(item_title, transaction)
 97    todo_item.title = data["title"]
 98    todo_item.done = data["done"]
 99    return serialize_todo(todo_item)
100
101
102app = Litestar(
103    [get_list, add_item, update_item],
104    dependencies={"transaction": provide_transaction},
105    lifespan=[db_connection],
106)
  1from contextlib import asynccontextmanager
  2from typing import Any
  3from collections.abc import AsyncGenerator
  4
  5from sqlalchemy import select
  6from sqlalchemy.exc import IntegrityError, NoResultFound
  7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
  9
 10from litestar import Litestar, get, post, put
 11from litestar.datastructures import State
 12from litestar.exceptions import ClientException, NotFoundException
 13from litestar.status_codes import HTTP_409_CONFLICT
 14
 15TodoType = dict[str, Any]
 16TodoCollectionType = list[TodoType]
 17
 18
 19class Base(DeclarativeBase): ...
 20
 21
 22class TodoItem(Base):
 23    __tablename__ = "todo_items"
 24
 25    title: Mapped[str] = mapped_column(primary_key=True)
 26    done: Mapped[bool]
 27
 28
 29@asynccontextmanager
 30async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
 31    engine = getattr(app.state, "engine", None)
 32    if engine is None:
 33        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
 34        app.state.engine = engine
 35
 36    async with engine.begin() as conn:
 37        await conn.run_sync(Base.metadata.create_all)
 38
 39    try:
 40        yield
 41    finally:
 42        await engine.dispose()
 43
 44
 45sessionmaker = async_sessionmaker(expire_on_commit=False)
 46
 47
 48async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
 49    async with sessionmaker(bind=state.engine) as session:
 50        try:
 51            async with session.begin():
 52                yield session
 53        except IntegrityError as exc:
 54            raise ClientException(
 55                status_code=HTTP_409_CONFLICT,
 56                detail=str(exc),
 57            ) from exc
 58
 59
 60def serialize_todo(todo: TodoItem) -> TodoType:
 61    return {"title": todo.title, "done": todo.done}
 62
 63
 64async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
 65    query = select(TodoItem).where(TodoItem.title == todo_name)
 66    result = await session.execute(query)
 67    try:
 68        return result.scalar_one()
 69    except NoResultFound as e:
 70        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
 71
 72
 73async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
 74    query = select(TodoItem)
 75    if done is not None:
 76        query = query.where(TodoItem.done.is_(done))
 77
 78    result = await session.execute(query)
 79    return result.scalars().all()
 80
 81
 82@get("/")
 83async def get_list(transaction: AsyncSession, done: bool | None = None) -> TodoCollectionType:
 84    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 85
 86
 87@post("/")
 88async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
 89    new_todo = TodoItem(title=data["title"], done=data["done"])
 90    transaction.add(new_todo)
 91    return serialize_todo(new_todo)
 92
 93
 94@put("/{item_title:str}")
 95async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
 96    todo_item = await get_todo_by_title(item_title, transaction)
 97    todo_item.title = data["title"]
 98    todo_item.done = data["done"]
 99    return serialize_todo(todo_item)
100
101
102app = Litestar(
103    [get_list, add_item, update_item],
104    dependencies={"transaction": provide_transaction},
105    lifespan=[db_connection],
106)

In the previous example, the database session is created within each HTTP route handler function. In this script we use dependency injection to decouple creation of the session from the route handlers.

This script introduces a new async generator function called provide_transaction(). It creates a new SQLAlchemy session, begins a transaction, and handles any integrity errors that might be raised from within the transaction.

1    async with sessionmaker(bind=state.engine) as session:
2        try:
3            async with session.begin():
4                yield session
5        except IntegrityError as exc:
6            raise ClientException(
7                status_code=HTTP_409_CONFLICT,
8                detail=str(exc),
9            ) from exc

That function is registered as a dependency of the Litestar application under the name transaction.

app = Litestar(
    [get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)

In the route handlers, the database session is injected by declaring the transaction name as a function argument. This is automatically provided by Litestar’s dependency injection system at runtime.

@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]

One final improvement in this script is exception handling. In the previous version, the add_item() handler raised a litestar.exceptions.ClientException if an integrity error occurred while inserting the new TODO item. In our latest revision, we have centralized this handling in the provide_transaction() function.

 1async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
 2    async with sessionmaker(bind=state.engine) as session:
 3        try:
 4            async with session.begin():
 5                yield session
 6        except IntegrityError as exc:
 7            raise ClientException(
 8                status_code=HTTP_409_CONFLICT,
 9                detail=str(exc),
10            ) from exc

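This change broadens the scope of exception handling to any operation that uses the database session, not just the insertion of new items. For example, if update_item() renames a TODO to a title that already exists, the client now receives a 409 Conflict response instead of an unhandled error. Note that the response detail is now the text of the underlying IntegrityError, which may include the failing SQL statement, rather than the friendlier "already exists" message used previously. In production, you may prefer a generic message.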

Compare handlers before and after DI

Just for fun, let's compare the application's handlers before and after we introduced dependency injection for our session object. The first listing shows the handlers with dependency injection; the second shows the original versions:

 1@get("/")
 2async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
 3    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
 4
 5
 6@post("/")
 7async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
 8    new_todo = TodoItem(title=data["title"], done=data["done"])
 9    transaction.add(new_todo)
10    return serialize_todo(new_todo)
11
12
13@put("/{item_title:str}")
14async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
15    todo_item = await get_todo_by_title(item_title, transaction)
16    todo_item.title = data["title"]
17    todo_item.done = data["done"]
18    return serialize_todo(todo_item)
19
20
21app = Litestar(
22    [get_list, add_item, update_item],
23    dependencies={"transaction": provide_transaction},
24    lifespan=[db_connection],
25)
 1@get("/")
 2async def get_list(state: State, done: Optional[bool] = None) -> TodoCollectionType:
 3    async with sessionmaker(bind=state.engine) as session:
 4        return [serialize_todo(todo) for todo in await get_todo_list(done, session)]
 5
 6
 7@post("/")
 8async def add_item(data: TodoType, state: State) -> TodoType:
 9    new_todo = TodoItem(title=data["title"], done=data["done"])
10    async with sessionmaker(bind=state.engine) as session:
11        try:
12            async with session.begin():
13                session.add(new_todo)
14        except IntegrityError as e:
15            raise ClientException(
16                status_code=HTTP_409_CONFLICT,
17                detail=f"TODO {new_todo.title!r} already exists",
18            ) from e
19
20    return serialize_todo(new_todo)
21
22
23@put("/{item_title:str}")
24async def update_item(item_title: str, data: TodoType, state: State) -> TodoType:
25    async with sessionmaker(bind=state.engine) as session, session.begin():
26        todo_item = await get_todo_by_title(item_title, session)
27        todo_item.title = data["title"]
28        todo_item.done = data["done"]
29    return serialize_todo(todo_item)
30
31
32app = Litestar([get_list, add_item, update_item], lifespan=[db_connection])

Much better!

Next steps

One of the niceties that we've lost is the ability to receive and return data to/from our handlers as instances of our data model. In the original TODO application, we modelled our data with Python dataclasses, which Litestar natively supports for (de)serialization. In the next section, we will look at how we can get this functionality back!