Using the init plugin

In our example application, we’ve seen that we need to manage the database engine within the scope of the application’s lifespan, and the session within the scope of a request. This is a common pattern, and the SQLAlchemyInitPlugin helps with exactly this.

In our latest update, we leverage two features of the plugin:

  1. The plugin will automatically create a database engine for us and manage it within the scope of the application’s lifespan.

  2. The plugin will automatically create a database session for us and manage it within the scope of a request.

We access the database session via dependency injection, using the db_session parameter.

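Here’s the updated code. Two variants are shown; they differ only in how the return type of provide_transaction() is annotated. The second uses the shorter AsyncGenerator[AsyncSession] form available on newer Python versions: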

 1from collections.abc import AsyncGenerator
 2
 3from advanced_alchemy.extensions.litestar import (
 4    SQLAlchemyAsyncConfig,
 5    SQLAlchemyInitPlugin,
 6    SQLAlchemySerializationPlugin,
 7)
 8from sqlalchemy import select
 9from sqlalchemy.exc import IntegrityError, NoResultFound
10from sqlalchemy.ext.asyncio import AsyncSession
11from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
12
13from litestar import Litestar, get, post, put
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
class Base(DeclarativeBase):
    """Declarative base for the ORM models; its ``metadata`` is handed to the plugin config."""
19
20
class TodoItem(Base):
    """A single TODO entry, persisted in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so two items cannot share a title.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Column type and NOT NULL are derived from the ``Mapped[bool]`` annotation.
    done: Mapped[bool] = mapped_column()
26
27
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Yield the plugin-provided session inside a single database transaction.

    ``db_session`` is the per-request session injected by ``SQLAlchemyInitPlugin``.
    The transaction opened with ``db_session.begin()`` is committed when the
    ``async with`` block exits normally and rolled back if an exception is raised
    inside it.

    Raises:
        ClientException: with status 409 when the database reports an integrity
            violation (for example, adding a second item with an existing title,
            since ``title`` is the primary key).
    """
    try:
        async with db_session.begin():
            yield db_session
    except IntegrityError as exc:
        # Report only the driver's own message (``exc.orig``). ``str(exc)`` would
        # also echo the full SQL statement and its bound parameters to the client.
        driver_error = exc.orig if exc.orig is not None else exc
        raise ClientException(
            status_code=HTTP_409_CONFLICT,
            detail=str(driver_error),
        ) from exc
37
38
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the single TODO item whose title equals *todo_name*.

    Raises:
        NotFoundException: when no row matches *todo_name*.
    """
    statement = select(TodoItem).where(TodoItem.title == todo_name)
    rows = await session.execute(statement)
    try:
        item = rows.scalar_one()
    except NoResultFound as missing:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from missing
    return item
46
47
async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
    """Return every TODO item, or only those whose ``done`` flag equals *done*.

    Passing ``None`` for *done* disables the filter.
    """
    base = select(TodoItem)
    statement = base if done is None else base.where(TodoItem.done.is_(done))

    rows = await session.execute(statement)
    return list(rows.scalars().all())
55
56
@get("/")
async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
    """List TODO items, optionally restricted to a given ``done`` state.

    When ``done`` is omitted, all items are returned.
    """
    items = await get_todo_list(done, session=transaction)
    return items
60
61
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Stage *data* for insertion and echo it back.

    The row is written when the transaction opened by ``provide_transaction``
    commits, not inside this handler.
    """
    new_item = data
    transaction.add(new_item)
    return new_item
66
67
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Overwrite the title and done flag of the item currently titled *item_title*.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
74
75
# Database configuration consumed by SQLAlchemyInitPlugin; create_all=True asks
# it to create the tables declared on Base.metadata.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

# SQLAlchemyInitPlugin manages the engine for the application's lifespan and a
# per-request session (the ``db_session`` dependency used by provide_transaction).
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    plugins=[
        SQLAlchemySerializationPlugin(),
        SQLAlchemyInitPlugin(db_config),
    ],
)
 1from collections.abc import AsyncGenerator
 2
 3from advanced_alchemy.extensions.litestar import (
 4    SQLAlchemyAsyncConfig,
 5    SQLAlchemyInitPlugin,
 6    SQLAlchemySerializationPlugin,
 7)
 8from sqlalchemy import select
 9from sqlalchemy.exc import IntegrityError, NoResultFound
10from sqlalchemy.ext.asyncio import AsyncSession
11from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
12
13from litestar import Litestar, get, post, put
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
class Base(DeclarativeBase):
    """Declarative base for the ORM models; its ``metadata`` is handed to the plugin config."""
19
20
class TodoItem(Base):
    """A single TODO entry, persisted in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title doubles as the primary key, so two items cannot share a title.
    title: Mapped[str] = mapped_column(primary_key=True)
    # Column type and NOT NULL are derived from the ``Mapped[bool]`` annotation.
    done: Mapped[bool] = mapped_column()
26
27
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession]:
    """Yield the plugin-provided session inside a single database transaction.

    ``db_session`` is the per-request session injected by ``SQLAlchemyInitPlugin``.
    The transaction opened with ``db_session.begin()`` is committed when the
    ``async with`` block exits normally and rolled back if an exception is raised
    inside it.

    Raises:
        ClientException: with status 409 when the database reports an integrity
            violation (for example, adding a second item with an existing title,
            since ``title`` is the primary key).
    """
    try:
        async with db_session.begin():
            yield db_session
    except IntegrityError as exc:
        # Report only the driver's own message (``exc.orig``). ``str(exc)`` would
        # also echo the full SQL statement and its bound parameters to the client.
        driver_error = exc.orig if exc.orig is not None else exc
        raise ClientException(
            status_code=HTTP_409_CONFLICT,
            detail=str(driver_error),
        ) from exc
37
38
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the single TODO item whose title equals *todo_name*.

    Raises:
        NotFoundException: when no row matches *todo_name*.
    """
    statement = select(TodoItem).where(TodoItem.title == todo_name)
    rows = await session.execute(statement)
    try:
        item = rows.scalar_one()
    except NoResultFound as missing:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from missing
    return item
46
47
async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
    """Return every TODO item, or only those whose ``done`` flag equals *done*.

    Passing ``None`` for *done* disables the filter.
    """
    base = select(TodoItem)
    statement = base if done is None else base.where(TodoItem.done.is_(done))

    rows = await session.execute(statement)
    return list(rows.scalars().all())
55
56
@get("/")
async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
    """List TODO items, optionally restricted to a given ``done`` state.

    When ``done`` is omitted, all items are returned.
    """
    items = await get_todo_list(done, session=transaction)
    return items
60
61
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Stage *data* for insertion and echo it back.

    The row is written when the transaction opened by ``provide_transaction``
    commits, not inside this handler.
    """
    new_item = data
    transaction.add(new_item)
    return new_item
66
67
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Overwrite the title and done flag of the item currently titled *item_title*.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
74
75
# Database configuration consumed by SQLAlchemyInitPlugin; create_all=True asks
# it to create the tables declared on Base.metadata.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

# SQLAlchemyInitPlugin manages the engine for the application's lifespan and a
# per-request session (the ``db_session`` dependency used by provide_transaction).
app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    dependencies={"transaction": provide_transaction},
    plugins=[
        SQLAlchemySerializationPlugin(),
        SQLAlchemyInitPlugin(db_config),
    ],
)

The most notable difference is that we no longer need the db_connection() lifespan context manager — the plugin handles this for us. It also handles the creation of the tables in our database if we supply our metadata and set create_all=True when creating a SQLAlchemyAsyncConfig instance.

Additionally, we have a new db_session dependency available to us, which we use in our provide_transaction() dependency provider, instead of creating our own session.

Next steps

Next up, we’ll make one final change to our application, and then we’ll recap!