Providing the session with DI
In our original script, we had to repeat the logic that constructs a session instance in every route handler. This is not very DRY (Don't Repeat Yourself).
In this section, we’ll use dependency injection to centralize the session creation logic and make it available to all handlers.
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from litestar import Litestar, get, post, put
from litestar.datastructures import State
from litestar.di import NamedDependency
from litestar.exceptions import ClientException, NotFoundException
from litestar.params import FromPath, FromQuery
from litestar.status_codes import HTTP_409_CONFLICT

# Plain-dict form of one TODO item, e.g. {"title": ..., "done": ...}.
TodoType = Dict[str, Any]
# Several serialized TODO items, as returned by the list endpoint.
TodoCollectionType = List[TodoType]


class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""


class TodoItem(Base):
    """ORM model for a single TODO entry, stored in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so titles must be unique.
    title: Mapped[str] = mapped_column(primary_key=True)
    done: Mapped[bool]


@asynccontextmanager
async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
    """Own the database engine for the lifetime of the application.

    Reuses an engine already stored on ``app.state.engine``; otherwise creates
    an aiosqlite engine for ``todo.sqlite`` and stores it there. On startup the
    tables declared on ``Base`` are created; on shutdown the engine is disposed.
    """
    engine = getattr(app.state, "engine", None)
    if engine is None:
        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
        app.state.engine = engine

    try:
        # Schema creation sits inside the try so the engine is still disposed
        # if it fails, instead of leaking its connection pool.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield
    finally:
        await engine.dispose()


# expire_on_commit=False keeps loaded attributes usable after the commit
# instead of expiring them.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session with an open transaction for the current request.

    The transaction commits when the generator resumes normally after ``yield``
    and rolls back if an exception is thrown back in at the ``yield``.

    Raises:
        ClientException: with status 409 when the transaction fails with an
            ``IntegrityError`` (for example, a duplicate title).
    """
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # Do not echo str(exc) to the client: SQLAlchemy includes the SQL
            # statement and its bound parameters in that text.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail="The request conflicts with existing data",
            ) from exc


def serialize_todo(todo: TodoItem) -> TodoType:
    """Convert a ``TodoItem`` into the plain-dict form returned by the handlers."""
    return dict(title=todo.title, done=todo.done)


async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    This is a plain helper that handlers call directly, so it takes an ordinary
    ``str`` rather than a request-parameter marker.

    Raises:
        NotFoundException: if no item has that title.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
    """Return all TODO items, optionally filtered by completion state.

    Args:
        done: when not ``None``, only items whose ``done`` flag matches are returned.
        session: session used to run the query.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ScalarResult.all() is typed as a Sequence; build a real list so the
    # annotated return type holds.
    return list(result.scalars().all())


@get("/")
async def get_list(
    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return [serialize_todo(todo) for todo in todos]


@post("/")
async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
    """Stage a new TODO item built from the request body in the injected session."""
    title, done = data["title"], data["done"]
    new_todo = TodoItem(title=title, done=done)
    transaction.add(new_todo)
    return serialize_todo(new_todo)


@put("/{item_title:str}")
async def update_item(
    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
) -> TodoType:
    """Overwrite the title and ``done`` flag of the item named ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    item = await get_todo_by_title(item_title, transaction)
    item.title = data["title"]
    item.done = data["done"]
    return serialize_todo(item)


# "transaction" is the dependency key that the handlers' NamedDependency
# parameters are matched against by name.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from litestar import Litestar, get, post, put
from litestar.datastructures import State
from litestar.di import NamedDependency
from litestar.exceptions import ClientException, NotFoundException
from litestar.params import FromPath, FromQuery
from litestar.status_codes import HTTP_409_CONFLICT

# Plain-dict form of one TODO item, e.g. {"title": ..., "done": ...}.
TodoType = dict[str, Any]
# Several serialized TODO items, as returned by the list endpoint.
TodoCollectionType = list[TodoType]


class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""


class TodoItem(Base):
    """ORM model for a single TODO entry, stored in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so titles must be unique.
    title: Mapped[str] = mapped_column(primary_key=True)
    done: Mapped[bool]


@asynccontextmanager
async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
    """Own the database engine for the lifetime of the application.

    Reuses an engine already stored on ``app.state.engine``; otherwise creates
    an aiosqlite engine for ``todo.sqlite`` and stores it there. On startup the
    tables declared on ``Base`` are created; on shutdown the engine is disposed.
    """
    engine = getattr(app.state, "engine", None)
    if engine is None:
        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
        app.state.engine = engine

    try:
        # Schema creation sits inside the try so the engine is still disposed
        # if it fails, instead of leaking its connection pool.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield
    finally:
        await engine.dispose()


# expire_on_commit=False keeps loaded attributes usable after the commit
# instead of expiring them.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session with an open transaction for the current request.

    The transaction commits when the generator resumes normally after ``yield``
    and rolls back if an exception is thrown back in at the ``yield``.

    Raises:
        ClientException: with status 409 when the transaction fails with an
            ``IntegrityError`` (for example, a duplicate title).
    """
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # Do not echo str(exc) to the client: SQLAlchemy includes the SQL
            # statement and its bound parameters in that text.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail="The request conflicts with existing data",
            ) from exc


def serialize_todo(todo: TodoItem) -> TodoType:
    """Convert a ``TodoItem`` into the plain-dict form returned by the handlers."""
    return dict(title=todo.title, done=todo.done)


async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    This is a plain helper that handlers call directly, so it takes an ordinary
    ``str`` rather than a request-parameter marker.

    Raises:
        NotFoundException: if no item has that title.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
    """Return all TODO items, optionally filtered by completion state.

    Args:
        done: when not ``None``, only items whose ``done`` flag matches are returned.
        session: session used to run the query.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ScalarResult.all() is typed as a Sequence; build a real list so the
    # annotated return type holds.
    return list(result.scalars().all())


@get("/")
async def get_list(
    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return [serialize_todo(todo) for todo in todos]


@post("/")
async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
    """Stage a new TODO item built from the request body in the injected session."""
    title, done = data["title"], data["done"]
    new_todo = TodoItem(title=title, done=done)
    transaction.add(new_todo)
    return serialize_todo(new_todo)


@put("/{item_title:str}")
async def update_item(
    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
) -> TodoType:
    """Overwrite the title and ``done`` flag of the item named ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    item = await get_todo_by_title(item_title, transaction)
    item.title = data["title"]
    item.done = data["done"]
    return serialize_todo(item)


# "transaction" is the dependency key that the handlers' NamedDependency
# parameters are matched against by name.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from litestar import Litestar, get, post, put
from litestar.datastructures import State
from litestar.di import NamedDependency
from litestar.exceptions import ClientException, NotFoundException
from litestar.params import FromPath, FromQuery
from litestar.status_codes import HTTP_409_CONFLICT

# Plain-dict form of one TODO item, e.g. {"title": ..., "done": ...}.
TodoType = dict[str, Any]
# Several serialized TODO items, as returned by the list endpoint.
TodoCollectionType = list[TodoType]


class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""


class TodoItem(Base):
    """ORM model for a single TODO entry, stored in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so titles must be unique.
    title: Mapped[str] = mapped_column(primary_key=True)
    done: Mapped[bool]


@asynccontextmanager
async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
    """Own the database engine for the lifetime of the application.

    Reuses an engine already stored on ``app.state.engine``; otherwise creates
    an aiosqlite engine for ``todo.sqlite`` and stores it there. On startup the
    tables declared on ``Base`` are created; on shutdown the engine is disposed.
    """
    engine = getattr(app.state, "engine", None)
    if engine is None:
        engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
        app.state.engine = engine

    try:
        # Schema creation sits inside the try so the engine is still disposed
        # if it fails, instead of leaking its connection pool.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield
    finally:
        await engine.dispose()


# expire_on_commit=False keeps loaded attributes usable after the commit
# instead of expiring them.
sessionmaker = async_sessionmaker(expire_on_commit=False)


async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session with an open transaction for the current request.

    The transaction commits when the generator resumes normally after ``yield``
    and rolls back if an exception is thrown back in at the ``yield``.

    Raises:
        ClientException: with status 409 when the transaction fails with an
            ``IntegrityError`` (for example, a duplicate title).
    """
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # Do not echo str(exc) to the client: SQLAlchemy includes the SQL
            # statement and its bound parameters in that text.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail="The request conflicts with existing data",
            ) from exc


def serialize_todo(todo: TodoItem) -> TodoType:
    """Convert a ``TodoItem`` into the plain-dict form returned by the handlers."""
    return dict(title=todo.title, done=todo.done)


async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    This is a plain helper that handlers call directly, so it takes an ordinary
    ``str`` rather than a request-parameter marker.

    Raises:
        NotFoundException: if no item has that title.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
    """Return all TODO items, optionally filtered by completion state.

    Args:
        done: when not ``None``, only items whose ``done`` flag matches are returned.
        session: session used to run the query.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ScalarResult.all() is typed as a Sequence; build a real list so the
    # annotated return type holds.
    return list(result.scalars().all())


@get("/")
async def get_list(
    transaction: NamedDependency[AsyncSession], done: FromQuery[bool | None] = None
) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return [serialize_todo(todo) for todo in todos]


@post("/")
async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
    """Stage a new TODO item built from the request body in the injected session."""
    title, done = data["title"], data["done"]
    new_todo = TodoItem(title=title, done=done)
    transaction.add(new_todo)
    return serialize_todo(new_todo)


@put("/{item_title:str}")
async def update_item(
    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
) -> TodoType:
    """Overwrite the title and ``done`` flag of the item named ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    item = await get_todo_by_title(item_title, transaction)
    item.title = data["title"]
    item.done = data["done"]
    return serialize_todo(item)


# "transaction" is the dependency key that the handlers' NamedDependency
# parameters are matched against by name.
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
In the previous example, the database session is created within each HTTP route handler function. In this script we use dependency injection to decouple creation of the session from the route handlers.
This script introduces a new async generator function called provide_transaction() that creates a new SQLAlchemy
session, begins a transaction, and handles any integrity errors that might be raised from within the transaction.
async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session with an open transaction for the current request.

    The transaction commits when the generator resumes normally after ``yield``
    and rolls back if an exception is thrown back in at the ``yield``.

    Raises:
        ClientException: with status 409 when the transaction fails with an
            ``IntegrityError`` (for example, a duplicate title).
    """
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # Do not echo str(exc) to the client: SQLAlchemy includes the SQL
            # statement and its bound parameters in that text.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail="The request conflicts with existing data",
            ) from exc
That function is declared as a dependency of the Litestar application under the name transaction.
# The excerpt should show the dependency declaration the prose describes,
# not the tail of update_item().
app = Litestar(
    [get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    lifespan=[db_connection],
)
In the route handlers, the database session is injected by adding a parameter named transaction and marking it with
NamedDependency. The marker tells Litestar to supply the value from the transaction dependency
at runtime, matching the parameter name against the dependency key.
@get("/")
async def get_list(
    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
) -> TodoCollectionType:
    """List TODO items using the session injected as ``transaction``."""
    return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
One final improvement in this script is exception handling. In the previous version, a
litestar.exceptions.ClientException is raised inside the add_item() handler if there’s an integrity error
raised during the insertion of the new TODO item. In our latest revision, we’ve been able to centralize this handling
to occur inside the provide_transaction() function.
async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
    """Yield a session with an open transaction for the current request.

    The transaction commits when the generator resumes normally after ``yield``
    and rolls back if an exception is thrown back in at the ``yield``.

    Raises:
        ClientException: with status 409 when the transaction fails with an
            ``IntegrityError`` (for example, a duplicate title).
    """
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                yield session
        except IntegrityError as exc:
            # Do not echo str(exc) to the client: SQLAlchemy includes the SQL
            # statement and its bound parameters in that text.
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail="The request conflicts with existing data",
            ) from exc
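This change broadens the scope of exception handling to any operation that uses the database session, not just the insertion of new items. For example, renaming an item in update_item() to a title that already exists now also results in a 409 Conflict response.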
Compare handlers before and after DI
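Just for fun, let's compare the application handlers before and after we introduced dependency injection for our session object. The first listing below is the new version, which uses dependency injection; the second is the previous version: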
@get("/")
async def get_list(
    transaction: NamedDependency[AsyncSession], done: FromQuery[Optional[bool]] = None
) -> TodoCollectionType:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    todos = await get_todo_list(done, transaction)
    return [serialize_todo(todo) for todo in todos]


@post("/")
async def add_item(data: TodoType, transaction: NamedDependency[AsyncSession]) -> TodoType:
    """Stage a new TODO item built from the request body in the injected session."""
    title, done = data["title"], data["done"]
    new_todo = TodoItem(title=title, done=done)
    transaction.add(new_todo)
    return serialize_todo(new_todo)


@put("/{item_title:str}")
async def update_item(
    item_title: FromPath[str], data: TodoType, transaction: NamedDependency[AsyncSession]
) -> TodoType:
    """Overwrite the title and ``done`` flag of the item named ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    item = await get_todo_by_title(item_title, transaction)
    item.title = data["title"]
    item.done = data["done"]
    return serialize_todo(item)
@get("/")
async def get_list(state: State, done: FromQuery[Optional[bool]] = None) -> TodoCollectionType:
    """List TODO items using a session opened by the handler itself."""
    async with sessionmaker(bind=state.engine) as session:
        todos = await get_todo_list(done, session)
        # Serialize while the session is still open.
        return [serialize_todo(todo) for todo in todos]


@post("/")
async def add_item(data: TodoType, state: State) -> TodoType:
    """Insert a new TODO item in its own transaction.

    A duplicate title makes the commit fail with ``IntegrityError``, which is
    reported to the client as HTTP 409.
    """
    new_todo = TodoItem(title=data["title"], done=data["done"])
    async with sessionmaker(bind=state.engine) as session:
        try:
            async with session.begin():
                session.add(new_todo)
        except IntegrityError as exc:
            raise ClientException(
                status_code=HTTP_409_CONFLICT,
                detail=f"TODO {new_todo.title!r} already exists",
            ) from exc

    return serialize_todo(new_todo)


@put("/{item_title:str}")
async def update_item(item_title: FromPath[str], data: TodoType, state: State) -> TodoType:
    """Overwrite the title and ``done`` flag of an item inside one transaction."""
    async with sessionmaker(bind=state.engine) as session:
        async with session.begin():
            item = await get_todo_by_title(item_title, session)
            item.title = data["title"]
            item.done = data["done"]
            return serialize_todo(item)
Much better!
Next steps
One of the niceties that we’ve lost is the ability to receive and return data to/from our handlers as instances of our data model. In the original TODO application, we modelled our data with Python dataclasses, which Litestar natively supports for (de)serialization. In the next section, we will look at how we can get this functionality back!