Using the init plugin

In our example application, we’ve seen that we need to manage the database engine within the scope of the application’s lifespan, and the session within the scope of a request. This is a common pattern, and the SQLAlchemyInitPlugin helps us with it.

In our latest update, we leverage two features of the plugin:

  1. The plugin will automatically create a database engine for us and manage it within the scope of the application’s lifespan.

  2. The plugin will automatically create a database session for us and manage it within the scope of a request.

We access the database session via dependency injection, using the db_session parameter.

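Here’s the updated code. It is shown twice: the first version annotates the optional done filter as Optional[bool], and the second uses the Python 3.10+ bool | None union syntax; the two are otherwise identical: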

 1from collections.abc import AsyncGenerator
 2from typing import Optional
 3
 4from sqlalchemy import select
 5from sqlalchemy.exc import IntegrityError, NoResultFound
 6from sqlalchemy.ext.asyncio import AsyncSession
 7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 8
 9from litestar import Litestar, get, post, put
10from litestar.contrib.sqlalchemy.plugins import (
11    SQLAlchemyAsyncConfig,
12    SQLAlchemyInitPlugin,
13    SQLAlchemySerializationPlugin,
14)
15from litestar.exceptions import ClientException, NotFoundException
16from litestar.status_codes import HTTP_409_CONFLICT
17
18
19class Base(DeclarativeBase): ...
20
21
22class TodoItem(Base):
23    __tablename__ = "todo_items"
24
25    title: Mapped[str] = mapped_column(primary_key=True)
26    done: Mapped[bool]
27
28
29async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
30    try:
31        async with db_session.begin():
32            yield db_session
33    except IntegrityError as exc:
34        raise ClientException(
35            status_code=HTTP_409_CONFLICT,
36            detail=str(exc),
37        ) from exc
38
39
40async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
41    query = select(TodoItem).where(TodoItem.title == todo_name)
42    result = await session.execute(query)
43    try:
44        return result.scalar_one()
45    except NoResultFound as e:
46        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
47
48
49async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
50    query = select(TodoItem)
51    if done is not None:
52        query = query.where(TodoItem.done.is_(done))
53
54    result = await session.execute(query)
55    return result.scalars().all()
56
57
58@get("/")
59async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> list[TodoItem]:
60    return await get_todo_list(done, transaction)
61
62
63@post("/")
64async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
65    transaction.add(data)
66    return data
67
68
69@put("/{item_title:str}")
70async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
71    todo_item = await get_todo_by_title(item_title, transaction)
72    todo_item.title = data.title
73    todo_item.done = data.done
74    return todo_item
75
76
77db_config = SQLAlchemyAsyncConfig(
78    connection_string="sqlite+aiosqlite:///todo.sqlite", metadata=Base.metadata, create_all=True
79)
80
81app = Litestar(
82    [get_list, add_item, update_item],
83    dependencies={"transaction": provide_transaction},
84    plugins=[
85        SQLAlchemySerializationPlugin(),
86        SQLAlchemyInitPlugin(db_config),
87    ],
88)
 1from collections.abc import AsyncGenerator
 2
 3from sqlalchemy import select
 4from sqlalchemy.exc import IntegrityError, NoResultFound
 5from sqlalchemy.ext.asyncio import AsyncSession
 6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 7
 8from litestar import Litestar, get, post, put
 9from litestar.contrib.sqlalchemy.plugins import (
10    SQLAlchemyAsyncConfig,
11    SQLAlchemyInitPlugin,
12    SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
18class Base(DeclarativeBase): ...
19
20
21class TodoItem(Base):
22    __tablename__ = "todo_items"
23
24    title: Mapped[str] = mapped_column(primary_key=True)
25    done: Mapped[bool]
26
27
28async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
29    try:
30        async with db_session.begin():
31            yield db_session
32    except IntegrityError as exc:
33        raise ClientException(
34            status_code=HTTP_409_CONFLICT,
35            detail=str(exc),
36        ) from exc
37
38
39async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
40    query = select(TodoItem).where(TodoItem.title == todo_name)
41    result = await session.execute(query)
42    try:
43        return result.scalar_one()
44    except NoResultFound as e:
45        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
46
47
48async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
49    query = select(TodoItem)
50    if done is not None:
51        query = query.where(TodoItem.done.is_(done))
52
53    result = await session.execute(query)
54    return result.scalars().all()
55
56
57@get("/")
58async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
59    return await get_todo_list(done, transaction)
60
61
62@post("/")
63async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
64    transaction.add(data)
65    return data
66
67
68@put("/{item_title:str}")
69async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
70    todo_item = await get_todo_by_title(item_title, transaction)
71    todo_item.title = data.title
72    todo_item.done = data.done
73    return todo_item
74
75
76db_config = SQLAlchemyAsyncConfig(
77    connection_string="sqlite+aiosqlite:///todo.sqlite", metadata=Base.metadata, create_all=True
78)
79
80app = Litestar(
81    [get_list, add_item, update_item],
82    dependencies={"transaction": provide_transaction},
83    plugins=[
84        SQLAlchemySerializationPlugin(),
85        SQLAlchemyInitPlugin(db_config),
86    ],
87)

The most notable difference is that we no longer need the db_connection() lifespan context manager - the plugin handles this for us. It also handles the creation of the tables in our database if we supply our metadata and set create_all=True when creating a SQLAlchemyAsyncConfig instance.

Additionally, we have a new db_session dependency available to us, which we use in our provide_transaction() dependency provider, instead of creating our own session.

Next steps

Next up, we’ll make one final change to our application, and then we’ll recap!