Using the init plugin
In our example application, we’ve seen that we need to manage the database engine within the scope of the application’s
lifespan, and the session within the scope of a request. This is a common pattern, and the
SQLAlchemyInitPlugin
plugin provides assistance for
this.
In our latest update, we leverage two features of the plugin:
The plugin will automatically create a database engine for us and manage it within the scope of the application’s lifespan.
The plugin will automatically create a database session for us and manage it within the scope of a request.
We access the database session via dependency injection, using the db_session
parameter.
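Here’s the updated code. The same listing is shown three times below; the versions differ only in type-hint syntax (typing.List / Optional, built-in generics such as list[...], and X | None unions):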
1from typing import AsyncGenerator, List, Optional
2
3from sqlalchemy import select
4from sqlalchemy.exc import IntegrityError, NoResultFound
5from sqlalchemy.ext.asyncio import AsyncSession
6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
7
8from litestar import Litestar, get, post, put
9from litestar.contrib.sqlalchemy.plugins import (
10 SQLAlchemyAsyncConfig,
11 SQLAlchemyInitPlugin,
12 SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
class Base(DeclarativeBase):
    """Declarative base for the application's ORM models."""
19
20
class TodoItem(Base):
    """ORM model for a single TODO entry, mapped to the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title is the primary key, so it identifies the item.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Completion flag of the item.
    done: Mapped[bool]
26
27
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Provide the injected database session with a transaction opened on it.

    The session is yielded from inside ``db_session.begin()``, so everything
    the consumer does with it happens within that transaction block.

    Args:
        db_session: Session injected under the ``db_session`` dependency name.

    Yields:
        The same session, with the transaction active.

    Raises:
        ClientException: HTTP 409 carrying the error text when an
            ``IntegrityError`` propagates out of the transaction block.
    """
    try:
        async with db_session.begin():
            yield db_session
    except IntegrityError as err:
        # Translate the database constraint violation into a client-facing conflict.
        conflict = ClientException(status_code=HTTP_409_CONFLICT, detail=str(err))
        raise conflict from err
37
38
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (primary key) of the item to look up.
        session: Session used to execute the query.

    Returns:
        The matching ``TodoItem``.

    Raises:
        NotFoundException: If no item with that title exists.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        # Surface the missing row as an HTTP "not found" error.
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
46
47
async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
    """Return the stored TODO items, optionally filtered by completion state.

    Args:
        done: ``None`` returns every item; ``True`` or ``False`` restricts the
            result to items whose ``done`` column matches.
        session: Session used to execute the query.

    Returns:
        A list of the matching items.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is typed as ``Sequence``; copy it into a real list
    # so the returned value matches the declared return type.
    return list(result.scalars().all())
55
56
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> List[TodoItem]:
    """Handle ``GET /``: list TODO items, optionally filtered by ``done``."""
    todo_items = await get_todo_list(done, transaction)
    return todo_items
60
61
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Handle ``POST /``: add the submitted item to the session and return it."""
    new_item = data
    # The item is only added to the session here; no explicit commit is issued
    # in this handler.
    transaction.add(new_item)
    return new_item
66
67
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Handle ``PUT /{item_title}``: overwrite the stored item with ``data``.

    The item is looked up through ``get_todo_by_title``; its ``title`` and
    ``done`` attributes are then replaced by the submitted values.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
74
75
# Configuration handed to the init plugin: the SQLite database location, the
# models' metadata, and ``create_all=True``.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

# ``transaction`` is registered at application level, matching the
# ``transaction`` parameter declared by each route handler.
app = Litestar(
    [get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    plugins=[
        SQLAlchemySerializationPlugin(),
        SQLAlchemyInitPlugin(db_config),
    ],
)
1from typing import Optional
2from collections.abc import AsyncGenerator
3
4from sqlalchemy import select
5from sqlalchemy.exc import IntegrityError, NoResultFound
6from sqlalchemy.ext.asyncio import AsyncSession
7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
8
9from litestar import Litestar, get, post, put
10from litestar.contrib.sqlalchemy.plugins import (
11 SQLAlchemyAsyncConfig,
12 SQLAlchemyInitPlugin,
13 SQLAlchemySerializationPlugin,
14)
15from litestar.exceptions import ClientException, NotFoundException
16from litestar.status_codes import HTTP_409_CONFLICT
17
18
class Base(DeclarativeBase):
    """Declarative base for the application's ORM models."""
20
21
class TodoItem(Base):
    """ORM model for a single TODO entry, mapped to the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title is the primary key, so it identifies the item.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Completion flag of the item.
    done: Mapped[bool]
27
28
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Provide the injected database session with a transaction opened on it.

    The session is yielded from inside ``db_session.begin()``, so everything
    the consumer does with it happens within that transaction block.

    Args:
        db_session: Session injected under the ``db_session`` dependency name.

    Yields:
        The same session, with the transaction active.

    Raises:
        ClientException: HTTP 409 carrying the error text when an
            ``IntegrityError`` propagates out of the transaction block.
    """
    try:
        async with db_session.begin():
            yield db_session
    except IntegrityError as err:
        # Translate the database constraint violation into a client-facing conflict.
        conflict = ClientException(status_code=HTTP_409_CONFLICT, detail=str(err))
        raise conflict from err
38
39
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (primary key) of the item to look up.
        session: Session used to execute the query.

    Returns:
        The matching ``TodoItem``.

    Raises:
        NotFoundException: If no item with that title exists.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        # Surface the missing row as an HTTP "not found" error.
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
47
48
async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
    """Return the stored TODO items, optionally filtered by completion state.

    Args:
        done: ``None`` returns every item; ``True`` or ``False`` restricts the
            result to items whose ``done`` column matches.
        session: Session used to execute the query.

    Returns:
        A list of the matching items.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is typed as ``Sequence``; copy it into a real list
    # so the returned value matches the declared return type.
    return list(result.scalars().all())
56
57
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> list[TodoItem]:
    """Handle ``GET /``: list TODO items, optionally filtered by ``done``."""
    todo_items = await get_todo_list(done, transaction)
    return todo_items
61
62
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Handle ``POST /``: add the submitted item to the session and return it."""
    new_item = data
    # The item is only added to the session here; no explicit commit is issued
    # in this handler.
    transaction.add(new_item)
    return new_item
67
68
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Handle ``PUT /{item_title}``: overwrite the stored item with ``data``.

    The item is looked up through ``get_todo_by_title``; its ``title`` and
    ``done`` attributes are then replaced by the submitted values.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
75
76
# Configuration handed to the init plugin: the SQLite database location, the
# models' metadata, and ``create_all=True``.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

# ``transaction`` is registered at application level, matching the
# ``transaction`` parameter declared by each route handler.
app = Litestar(
    [get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    plugins=[
        SQLAlchemySerializationPlugin(),
        SQLAlchemyInitPlugin(db_config),
    ],
)
1from collections.abc import AsyncGenerator
2
3from sqlalchemy import select
4from sqlalchemy.exc import IntegrityError, NoResultFound
5from sqlalchemy.ext.asyncio import AsyncSession
6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
7
8from litestar import Litestar, get, post, put
9from litestar.contrib.sqlalchemy.plugins import (
10 SQLAlchemyAsyncConfig,
11 SQLAlchemyInitPlugin,
12 SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
class Base(DeclarativeBase):
    """Declarative base for the application's ORM models."""
19
20
class TodoItem(Base):
    """ORM model for a single TODO entry, mapped to the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title is the primary key, so it identifies the item.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Completion flag of the item.
    done: Mapped[bool]
26
27
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Provide the injected database session with a transaction opened on it.

    The session is yielded from inside ``db_session.begin()``, so everything
    the consumer does with it happens within that transaction block.

    Args:
        db_session: Session injected under the ``db_session`` dependency name.

    Yields:
        The same session, with the transaction active.

    Raises:
        ClientException: HTTP 409 carrying the error text when an
            ``IntegrityError`` propagates out of the transaction block.
    """
    try:
        async with db_session.begin():
            yield db_session
    except IntegrityError as err:
        # Translate the database constraint violation into a client-facing conflict.
        conflict = ClientException(status_code=HTTP_409_CONFLICT, detail=str(err))
        raise conflict from err
37
38
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (primary key) of the item to look up.
        session: Session used to execute the query.

    Returns:
        The matching ``TodoItem``.

    Raises:
        NotFoundException: If no item with that title exists.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        # Surface the missing row as an HTTP "not found" error.
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
46
47
async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
    """Return the stored TODO items, optionally filtered by completion state.

    Args:
        done: ``None`` returns every item; ``True`` or ``False`` restricts the
            result to items whose ``done`` column matches.
        session: Session used to execute the query.

    Returns:
        A list of the matching items.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is typed as ``Sequence``; copy it into a real list
    # so the returned value matches the declared return type.
    return list(result.scalars().all())
55
56
@get("/")
async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
    """Handle ``GET /``: list TODO items, optionally filtered by ``done``."""
    todo_items = await get_todo_list(done, transaction)
    return todo_items
60
61
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Handle ``POST /``: add the submitted item to the session and return it."""
    new_item = data
    # The item is only added to the session here; no explicit commit is issued
    # in this handler.
    transaction.add(new_item)
    return new_item
66
67
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Handle ``PUT /{item_title}``: overwrite the stored item with ``data``.

    The item is looked up through ``get_todo_by_title``; its ``title`` and
    ``done`` attributes are then replaced by the submitted values.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
74
75
# Configuration handed to the init plugin: the SQLite database location, the
# models' metadata, and ``create_all=True``.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

# ``transaction`` is registered at application level, matching the
# ``transaction`` parameter declared by each route handler.
app = Litestar(
    [get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    plugins=[
        SQLAlchemySerializationPlugin(),
        SQLAlchemyInitPlugin(db_config),
    ],
)
The most notable difference is that we no longer need the db_connection()
lifespan context manager - the plugin
handles this for us. It also handles the creation of the tables in our database if we supply our metadata and
set create_all=True
when creating a SQLAlchemyAsyncConfig
instance.
Additionally, we have a new db_session
dependency available to us, which we use in our provide_transaction()
dependency provider, instead of creating our own session.
Next steps
Next up, we’ll make one final change to our application, and then we’ll recap!