Providing the session with DI
In our original script, we had to repeat the logic that constructs a session instance in every route handler. This is not very DRY ("Don't Repeat Yourself").
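In this section, we'll use dependency injection to centralize the session creation logic and make it available to all handlers. The complete script is listed twice below: the first version uses typing.Optional, while the second uses the Python 3.10+ bool | None union syntax. Otherwise the two versions are identical.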
1from collections.abc import AsyncGenerator
2from contextlib import asynccontextmanager
3from typing import Any, Optional
4
5from sqlalchemy import select
6from sqlalchemy.exc import IntegrityError, NoResultFound
7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
9
10from litestar import Litestar, get, post, put
11from litestar.datastructures import State
12from litestar.exceptions import ClientException, NotFoundException
13from litestar.status_codes import HTTP_409_CONFLICT
14
# Wire format of one serialized TODO item, e.g. {"title": "buy milk", "done": False}.
TodoType = dict[str, Any]
# Wire format of a collection of serialized TODO items.
TodoCollectionType = list[dict[str, Any]]
17
18
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model in this script."""


class TodoItem(Base):
    """One TODO entry, stored as a row of the ``todo_items`` table.

    ``title`` is the primary key, so no two items may share a title.
    """

    __tablename__ = "todo_items"

    title: Mapped[str] = mapped_column(primary_key=True)
    done: Mapped[bool]
27
28
@asynccontextmanager
async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
    """Lifespan hook that owns the SQLAlchemy engine while the app runs.

    Reuses ``app.state.engine`` if one is already set. Otherwise it creates an
    aiosqlite engine for ``todo.sqlite`` and stores it on ``app.state`` so that
    request dependencies can bind sessions to it. On startup it creates every
    table registered on ``Base.metadata``. The engine is disposed on shutdown.
    """
    engine = getattr(app.state, "engine", None)
    if engine is None:
        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
        app.state.engine = engine

    try:
        # Schema creation lives inside the ``try`` so the engine's connection
        # pool is still released if ``create_all`` fails during startup.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield
    finally:
        await engine.dispose()
43
44
# Module-level session factory. It is not bound to an engine here: each
# session receives ``bind=state.engine`` when it is created. With
# ``expire_on_commit=False``, ORM objects keep their loaded attribute values
# after commit, so they remain usable once the transaction has ended.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields a session with an open transaction.

    The transaction commits when the dependency resumes normally and rolls
    back if an exception propagates into it. Any ``IntegrityError`` raised
    inside the transaction, including at commit time, is re-raised as a
    ``ClientException`` carrying HTTP 409 Conflict.
    """
    async with sessionmaker(bind=state.engine) as db_session:
        try:
            async with db_session.begin():
                yield db_session
        except IntegrityError as integrity_error:
            conflict = ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(integrity_error),
            )
            raise conflict from integrity_error
58
59
def serialize_todo(todo: TodoItem) -> TodoType:
    """Convert a ``TodoItem`` row into a plain ``{"title", "done"}`` dict."""
    return dict(title=todo.title, done=todo.done)
62
63
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (primary key) of the item to look up.
        session: Session used to run the query.

    Raises:
        NotFoundException: if no item with that title exists.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
71
72
async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
    """Return TODO items, optionally filtered by their ``done`` flag.

    Args:
        done: ``None`` returns every item; otherwise only items whose ``done``
            column matches this value are returned.
        session: Session used to run the query.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is typed as a generic ``Sequence``; build a real
    # ``list`` so the value matches the declared return type.
    return list(result.scalars())
80
81
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return list(map(serialize_todo, todos))
85
86
@post("/")
async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
    """Stage a new TODO item on the injected session and echo it back.

    The row is written when the ``transaction`` dependency commits after this
    handler returns. A constraint violation at that point (such as a duplicate
    title) is turned into a 409 Conflict by that dependency.
    """
    item = TodoItem(title=data["title"], done=data["done"])
    transaction.add(item)
    return serialize_todo(item)
92
93
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
    """Overwrite the title and ``done`` flag of the item titled ``item_title``.

    A missing item raises ``NotFoundException`` via ``get_todo_by_title``.
    Renaming onto an existing title fails when the transaction commits.
    """
    todo = await get_todo_by_title(item_title, transaction)
    todo.title = data["title"]
    todo.done = data["done"]
    return serialize_todo(todo)
100
101
# Application wiring: every handler may request a ``transaction`` argument,
# and ``db_connection`` manages the engine across the app's lifespan.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
1from collections.abc import AsyncGenerator
2from contextlib import asynccontextmanager
3from typing import Any
4
5from sqlalchemy import select
6from sqlalchemy.exc import IntegrityError, NoResultFound
7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
9
10from litestar import Litestar, get, post, put
11from litestar.datastructures import State
12from litestar.exceptions import ClientException, NotFoundException
13from litestar.status_codes import HTTP_409_CONFLICT
14
# Wire format of one serialized TODO item, e.g. {"title": "buy milk", "done": False}.
TodoType = dict[str, Any]
# Wire format of a collection of serialized TODO items.
TodoCollectionType = list[dict[str, Any]]
17
18
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model in this script."""


class TodoItem(Base):
    """One TODO entry, stored as a row of the ``todo_items`` table.

    ``title`` is the primary key, so no two items may share a title.
    """

    __tablename__ = "todo_items"

    title: Mapped[str] = mapped_column(primary_key=True)
    done: Mapped[bool]
27
28
@asynccontextmanager
async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
    """Lifespan hook that owns the SQLAlchemy engine while the app runs.

    Reuses ``app.state.engine`` if one is already set. Otherwise it creates an
    aiosqlite engine for ``todo.sqlite`` and stores it on ``app.state`` so that
    request dependencies can bind sessions to it. On startup it creates every
    table registered on ``Base.metadata``. The engine is disposed on shutdown.
    """
    engine = getattr(app.state, "engine", None)
    if engine is None:
        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
        app.state.engine = engine

    try:
        # Schema creation lives inside the ``try`` so the engine's connection
        # pool is still released if ``create_all`` fails during startup.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield
    finally:
        await engine.dispose()
43
44
# Module-level session factory. It is not bound to an engine here: each
# session receives ``bind=state.engine`` when it is created. With
# ``expire_on_commit=False``, ORM objects keep their loaded attribute values
# after commit, so they remain usable once the transaction has ended.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields a session with an open transaction.

    The transaction commits when the dependency resumes normally and rolls
    back if an exception propagates into it. Any ``IntegrityError`` raised
    inside the transaction, including at commit time, is re-raised as a
    ``ClientException`` carrying HTTP 409 Conflict.
    """
    async with sessionmaker(bind=state.engine) as db_session:
        try:
            async with db_session.begin():
                yield db_session
        except IntegrityError as integrity_error:
            conflict = ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(integrity_error),
            )
            raise conflict from integrity_error
58
59
def serialize_todo(todo: TodoItem) -> TodoType:
    """Convert a ``TodoItem`` row into a plain ``{"title", "done"}`` dict."""
    return dict(title=todo.title, done=todo.done)
62
63
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (primary key) of the item to look up.
        session: Session used to run the query.

    Raises:
        NotFoundException: if no item with that title exists.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
71
72
async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
    """Return TODO items, optionally filtered by their ``done`` flag.

    Args:
        done: ``None`` returns every item; otherwise only items whose ``done``
            column matches this value are returned.
        session: Session used to run the query.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is typed as a generic ``Sequence``; build a real
    # ``list`` so the value matches the declared return type.
    return list(result.scalars())
80
81
@get("/")
async def get_list(transaction: AsyncSession, done: bool | None = None) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return list(map(serialize_todo, todos))
85
86
@post("/")
async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
    """Stage a new TODO item on the injected session and echo it back.

    The row is written when the ``transaction`` dependency commits after this
    handler returns. A constraint violation at that point (such as a duplicate
    title) is turned into a 409 Conflict by that dependency.
    """
    item = TodoItem(title=data["title"], done=data["done"])
    transaction.add(item)
    return serialize_todo(item)
92
93
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
    """Overwrite the title and ``done`` flag of the item titled ``item_title``.

    A missing item raises ``NotFoundException`` via ``get_todo_by_title``.
    Renaming onto an existing title fails when the transaction commits.
    """
    todo = await get_todo_by_title(item_title, transaction)
    todo.title = data["title"]
    todo.done = data["done"]
    return serialize_todo(todo)
100
101
# Application wiring: every handler may request a ``transaction`` argument,
# and ``db_connection`` manages the engine across the app's lifespan.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
In the previous example, the database session is created within each HTTP route handler function. In this script we use dependency injection to decouple creation of the session from the route handlers.
This script introduces a new async generator function called provide_transaction() that creates a new SQLAlchemy session, begins a transaction, and handles any integrity errors raised within that transaction.
async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields a session with an open transaction.

    The transaction commits when the dependency resumes normally and rolls
    back if an exception propagates into it. Any ``IntegrityError`` raised
    inside the transaction, including at commit time, is re-raised as a
    ``ClientException`` carrying HTTP 409 Conflict.
    """
    async with sessionmaker(bind=state.engine) as db_session:
        try:
            async with db_session.begin():
                yield db_session
        except IntegrityError as integrity_error:
            conflict = ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(integrity_error),
            )
            raise conflict from integrity_error
That function is declared as a dependency of the Litestar application under the name transaction.
# Application wiring: every handler may request a ``transaction`` argument,
# and ``db_connection`` manages the engine across the app's lifespan.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
In the route handlers, the database session is injected by declaring transaction as a function argument. Litestar's dependency injection system provides it automatically at runtime.
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return list(map(serialize_todo, todos))
One final improvement in this script is exception handling. In the previous version, a litestar.exceptions.ClientException is raised inside the add_item() handler if an integrity error occurs while inserting the new TODO item. In our latest revision, we've been able to centralize this handling inside the provide_transaction() function.
async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields a session with an open transaction.

    The transaction commits when the dependency resumes normally and rolls
    back if an exception propagates into it. Any ``IntegrityError`` raised
    inside the transaction, including at commit time, is re-raised as a
    ``ClientException`` carrying HTTP 409 Conflict.
    """
    async with sessionmaker(bind=state.engine) as db_session:
        try:
            async with db_session.begin():
                yield db_session
        except IntegrityError as integrity_error:
            conflict = ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=str(integrity_error),
            )
            raise conflict from integrity_error
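This change broadens the scope of exception handling to any operation that uses the database session, not just the insertion of new items. For example, renaming an item with update_item() to a title that already exists now also produces a 409 Conflict response. One trade-off is that the error detail is now the text of the underlying IntegrityError, rather than the friendlier "TODO ... already exists" message used before.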
Compare handlers before and after DI
Just for fun, let's compare the application handlers before and after we introduced dependency injection for our session object. The first listing below shows the handlers after the change, and the second shows them before:
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return list(map(serialize_todo, todos))
4
5
@post("/")
async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
    """Stage a new TODO item on the injected session and echo it back.

    The row is written when the ``transaction`` dependency commits after this
    handler returns. A constraint violation at that point (such as a duplicate
    title) is turned into a 409 Conflict by that dependency.
    """
    item = TodoItem(title=data["title"], done=data["done"])
    transaction.add(item)
    return serialize_todo(item)
11
12
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
    """Overwrite the title and ``done`` flag of the item titled ``item_title``.

    A missing item raises ``NotFoundException`` via ``get_todo_by_title``.
    Renaming onto an existing title fails when the transaction commits.
    """
    todo = await get_todo_by_title(item_title, transaction)
    todo.title = data["title"]
    todo.done = data["done"]
    return serialize_todo(todo)
19
20
# Application wiring: every handler may request a ``transaction`` argument,
# and ``db_connection`` manages the engine across the app's lifespan.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
@get("/")
async def get_list(state: State, done: Optional[bool] = None) -> TodoCollectionType:
    """List TODO items using a session the handler creates for itself."""
    async with sessionmaker(bind=state.engine) as session:
        todos = await get_todo_list(done, session)
        return list(map(serialize_todo, todos))
5
6
@post("/")
async def add_item(data: TodoType, state: State) -> TodoType:
    """Insert a new TODO item, managing the session and transaction by hand.

    Raises:
        ClientException: 409 Conflict when the insert violates a database
            constraint, such as a duplicate title.
    """
    new_todo = TodoItem(title=data["title"], done=data["done"])
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                session.add(new_todo)
        except IntegrityError as integrity_error:
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=f"TODO {new_todo.title!r} already exists",
            ) from integrity_error

    # expire_on_commit=False on the factory keeps new_todo readable here.
    return serialize_todo(new_todo)
21
22
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoType, state: State) -> TodoType:
    """Overwrite an item's title and ``done`` flag inside a hand-made transaction."""
    async with sessionmaker(bind=state.engine) as session, session.begin():
        todo = await get_todo_by_title(item_title, session)
        todo.title = data["title"]
        todo.done = data["done"]
        return serialize_todo(todo)
Much better!
Next steps
One of the niceties we've lost is the ability to receive and return data to and from our handlers as instances of our data model. In the original TODO application, we modelled our data with Python dataclasses, which Litestar supports natively for (de)serialization. In the next section, we will look at how we can get this functionality back!