Providing the session with DI
In our original script, we had to repeat the logic to construct a session instance for every request type. This is not very DRY (Don't Repeat Yourself).
In this section, we’ll use dependency injection to centralize the session creation logic and make it available to all handlers.
1from contextlib import asynccontextmanager
2from typing import Any, AsyncGenerator, Dict, List, Optional
3
4from sqlalchemy import select
5from sqlalchemy.exc import IntegrityError, NoResultFound
6from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
8
9from litestar import Litestar, get, post, put
10from litestar.datastructures import State
11from litestar.exceptions import ClientException, NotFoundException
12from litestar.status_codes import HTTP_409_CONFLICT
13
14TodoType = Dict[str, Any]
15TodoCollectionType = List[TodoType]
16
17
18class Base(DeclarativeBase): ...
19
20
21class TodoItem(Base):
22 __tablename__ = "todo_items"
23
24 title: Mapped[str] = mapped_column(primary_key=True)
25 done: Mapped[bool]
26
27
28@asynccontextmanager
29async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
30 engine = getattr(app.state, "engine", None)
31 if engine is None:
32 engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
33 app.state.engine = engine
34
35 async with engine.begin() as conn:
36 await conn.run_sync(Base.metadata.create_all)
37
38 try:
39 yield
40 finally:
41 await engine.dispose()
42
43
44sessionmaker = async_sessionmaker(expire_on_commit=False)
45
46
47async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
48 async with sessionmaker(bind=state.engine) as session:
49 try:
50 async with session.begin():
51 yield session
52 except IntegrityError as exc:
53 raise ClientException(
54 status_code=HTTP_409_CONFLICT,
55 detail=str(exc),
56 ) from exc
57
58
59def serialize_todo(todo: TodoItem) -> TodoType:
60 return {"title": todo.title, "done": todo.done}
61
62
63async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
64 query = select(TodoItem).where(TodoItem.title == todo_name)
65 result = await session.execute(query)
66 try:
67 return result.scalar_one()
68 except NoResultFound as e:
69 raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
70
71
72async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
73 query = select(TodoItem)
74 if done is not None:
75 query = query.where(TodoItem.done.is_(done))
76
77 result = await session.execute(query)
78 return result.scalars().all()
79
80
81@get("/")
82async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
83 return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
84
85
86@post("/")
87async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
88 new_todo = TodoItem(title=data["title"], done=data["done"])
89 transaction.add(new_todo)
90 return serialize_todo(new_todo)
91
92
93@put("/{item_title:str}")
94async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
95 todo_item = await get_todo_by_title(item_title, transaction)
96 todo_item.title = data["title"]
97 todo_item.done = data["done"]
98 return serialize_todo(todo_item)
99
100
101app = Litestar(
102 [get_list, add_item, update_item],
103 dependencies={"transaction": provide_transaction},
104 lifespan=[db_connection],
105)
1from contextlib import asynccontextmanager
2from typing import Any, Optional
3from collections.abc import AsyncGenerator
4
5from sqlalchemy import select
6from sqlalchemy.exc import IntegrityError, NoResultFound
7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
9
10from litestar import Litestar, get, post, put
11from litestar.datastructures import State
12from litestar.exceptions import ClientException, NotFoundException
13from litestar.status_codes import HTTP_409_CONFLICT
14
15TodoType = dict[str, Any]
16TodoCollectionType = list[TodoType]
17
18
19class Base(DeclarativeBase): ...
20
21
22class TodoItem(Base):
23 __tablename__ = "todo_items"
24
25 title: Mapped[str] = mapped_column(primary_key=True)
26 done: Mapped[bool]
27
28
29@asynccontextmanager
30async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
31 engine = getattr(app.state, "engine", None)
32 if engine is None:
33 engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
34 app.state.engine = engine
35
36 async with engine.begin() as conn:
37 await conn.run_sync(Base.metadata.create_all)
38
39 try:
40 yield
41 finally:
42 await engine.dispose()
43
44
45sessionmaker = async_sessionmaker(expire_on_commit=False)
46
47
48async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
49 async with sessionmaker(bind=state.engine) as session:
50 try:
51 async with session.begin():
52 yield session
53 except IntegrityError as exc:
54 raise ClientException(
55 status_code=HTTP_409_CONFLICT,
56 detail=str(exc),
57 ) from exc
58
59
60def serialize_todo(todo: TodoItem) -> TodoType:
61 return {"title": todo.title, "done": todo.done}
62
63
64async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
65 query = select(TodoItem).where(TodoItem.title == todo_name)
66 result = await session.execute(query)
67 try:
68 return result.scalar_one()
69 except NoResultFound as e:
70 raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
71
72
73async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
74 query = select(TodoItem)
75 if done is not None:
76 query = query.where(TodoItem.done.is_(done))
77
78 result = await session.execute(query)
79 return result.scalars().all()
80
81
82@get("/")
83async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
84 return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
85
86
87@post("/")
88async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
89 new_todo = TodoItem(title=data["title"], done=data["done"])
90 transaction.add(new_todo)
91 return serialize_todo(new_todo)
92
93
94@put("/{item_title:str}")
95async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
96 todo_item = await get_todo_by_title(item_title, transaction)
97 todo_item.title = data["title"]
98 todo_item.done = data["done"]
99 return serialize_todo(todo_item)
100
101
102app = Litestar(
103 [get_list, add_item, update_item],
104 dependencies={"transaction": provide_transaction},
105 lifespan=[db_connection],
106)
1from contextlib import asynccontextmanager
2from typing import Any
3from collections.abc import AsyncGenerator
4
5from sqlalchemy import select
6from sqlalchemy.exc import IntegrityError, NoResultFound
7from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
8from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
9
10from litestar import Litestar, get, post, put
11from litestar.datastructures import State
12from litestar.exceptions import ClientException, NotFoundException
13from litestar.status_codes import HTTP_409_CONFLICT
14
15TodoType = dict[str, Any]
16TodoCollectionType = list[TodoType]
17
18
19class Base(DeclarativeBase): ...
20
21
22class TodoItem(Base):
23 __tablename__ = "todo_items"
24
25 title: Mapped[str] = mapped_column(primary_key=True)
26 done: Mapped[bool]
27
28
29@asynccontextmanager
30async def db_connection(app: Litestar) -> AsyncGenerator[None, None]:
31 engine = getattr(app.state, "engine", None)
32 if engine is None:
33 engine = create_async_engine("sqlite+aiosqlite:///todo.sqlite")
34 app.state.engine = engine
35
36 async with engine.begin() as conn:
37 await conn.run_sync(Base.metadata.create_all)
38
39 try:
40 yield
41 finally:
42 await engine.dispose()
43
44
45sessionmaker = async_sessionmaker(expire_on_commit=False)
46
47
48async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
49 async with sessionmaker(bind=state.engine) as session:
50 try:
51 async with session.begin():
52 yield session
53 except IntegrityError as exc:
54 raise ClientException(
55 status_code=HTTP_409_CONFLICT,
56 detail=str(exc),
57 ) from exc
58
59
60def serialize_todo(todo: TodoItem) -> TodoType:
61 return {"title": todo.title, "done": todo.done}
62
63
64async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
65 query = select(TodoItem).where(TodoItem.title == todo_name)
66 result = await session.execute(query)
67 try:
68 return result.scalar_one()
69 except NoResultFound as e:
70 raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
71
72
73async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
74 query = select(TodoItem)
75 if done is not None:
76 query = query.where(TodoItem.done.is_(done))
77
78 result = await session.execute(query)
79 return result.scalars().all()
80
81
82@get("/")
83async def get_list(transaction: AsyncSession, done: bool | None = None) -> TodoCollectionType:
84 return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
85
86
87@post("/")
88async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
89 new_todo = TodoItem(title=data["title"], done=data["done"])
90 transaction.add(new_todo)
91 return serialize_todo(new_todo)
92
93
94@put("/{item_title:str}")
95async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
96 todo_item = await get_todo_by_title(item_title, transaction)
97 todo_item.title = data["title"]
98 todo_item.done = data["done"]
99 return serialize_todo(todo_item)
100
101
102app = Litestar(
103 [get_list, add_item, update_item],
104 dependencies={"transaction": provide_transaction},
105 lifespan=[db_connection],
106)
In the previous example, the database session is created within each HTTP route handler function. In this script we use dependency injection to decouple creation of the session from the route handlers.
This script introduces a new async generator function called provide_transaction() that creates a new SQLAlchemy session, begins a transaction, and handles any integrity errors that might be raised from within the transaction.
1 async with sessionmaker(bind=state.engine) as session:
2 try:
3 async with session.begin():
4 yield session
5 except IntegrityError as exc:
6 raise ClientException(
7 status_code=HTTP_409_CONFLICT,
8 detail=str(exc),
9 ) from exc
That function is declared as a dependency of the Litestar application, under the name transaction.
1app = Litestar(
2 [get_list, add_item, update_item],
3 dependencies={"transaction": provide_transaction},
4 lifespan=[db_connection],
5)
In the route handlers, the database session is injected by declaring the name transaction as a function argument. The session is then provided automatically by Litestar’s dependency injection system at runtime.
1@get("/")
2async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
3 return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
One final improvement in this script is exception handling. In the previous version, a litestar.exceptions.ClientException is raised inside the add_item() handler if an integrity error occurs while inserting the new TODO item. In our latest revision, we’ve been able to centralize this handling inside the provide_transaction() function.
1async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None]:
2 async with sessionmaker(bind=state.engine) as session:
3 try:
4 async with session.begin():
5 yield session
6 except IntegrityError as exc:
7 raise ClientException(
8 status_code=HTTP_409_CONFLICT,
9 detail=str(exc),
10 ) from exc
This change broadens the scope of exception handling to any operation that uses the database session, not just the insertion of new items.
Compare handlers before and after DI
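Just for fun, let's compare the sets of application handlers before and after we introduced dependency injection for our session object. The first listing below is the version that uses dependency injection; the second is the original, in which each handler creates its own session: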
1@get("/")
2async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> TodoCollectionType:
3 return [serialize_todo(todo) for todo in await get_todo_list(done, transaction)]
4
5
6@post("/")
7async def add_item(data: TodoType, transaction: AsyncSession) -> TodoType:
8 new_todo = TodoItem(title=data["title"], done=data["done"])
9 transaction.add(new_todo)
10 return serialize_todo(new_todo)
11
12
13@put("/{item_title:str}")
14async def update_item(item_title: str, data: TodoType, transaction: AsyncSession) -> TodoType:
15 todo_item = await get_todo_by_title(item_title, transaction)
16 todo_item.title = data["title"]
17 todo_item.done = data["done"]
18 return serialize_todo(todo_item)
19
20
21app = Litestar(
22 [get_list, add_item, update_item],
23 dependencies={"transaction": provide_transaction},
24 lifespan=[db_connection],
25)
1@get("/")
2async def get_list(state: State, done: Optional[bool] = None) -> TodoCollectionType:
3 async with sessionmaker(bind=state.engine) as session:
4 return [serialize_todo(todo) for todo in await get_todo_list(done, session)]
5
6
7@post("/")
8async def add_item(data: TodoType, state: State) -> TodoType:
9 new_todo = TodoItem(title=data["title"], done=data["done"])
10 async with sessionmaker(bind=state.engine) as session:
11 try:
12 async with session.begin():
13 session.add(new_todo)
14 except IntegrityError as e:
15 raise ClientException(
16 status_code=HTTP_409_CONFLICT,
17 detail=f"TODO {new_todo.title!r} already exists",
18 ) from e
19
20 return serialize_todo(new_todo)
21
22
23@put("/{item_title:str}")
24async def update_item(item_title: str, data: TodoType, state: State) -> TodoType:
25 async with sessionmaker(bind=state.engine) as session, session.begin():
26 todo_item = await get_todo_by_title(item_title, session)
27 todo_item.title = data["title"]
28 todo_item.done = data["done"]
29 return serialize_todo(todo_item)
30
31
32app = Litestar([get_list, add_item, update_item], lifespan=[db_connection])
Much better!
Next steps
One of the niceties that we’ve lost is the ability to receive and return data to/from our handlers as instances of our data model. In the original TODO application, we modelled our data with Python dataclasses, which are natively supported for (de)serialization by Litestar. In the next section, we will look at how we can get this functionality back!