Using the init plugin

In our example application, we’ve seen that we need to manage the database engine within the scope of the application’s lifespan, and the session within the scope of a request. This is a common pattern, and the SQLAlchemyInitPlugin provides assistance for this.

In our latest update, we leverage two features of the plugin:

  1. The plugin will automatically create a database engine for us and manage it within the scope of the application’s lifespan.

  2. The plugin will automatically create a database session for us and manage it within the scope of a request.

We access the database session via dependency injection, using the db_session parameter.

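Here’s the updated code. The three listings below are equivalent; they differ only in typing syntax, targeting Python 3.8+, Python 3.9+ and Python 3.10+ respectively: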

 1from typing import AsyncGenerator, List, Optional
 2
 3from sqlalchemy import select
 4from sqlalchemy.exc import IntegrityError, NoResultFound
 5from sqlalchemy.ext.asyncio import AsyncSession
 6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 7
 8from litestar import Litestar, get, post, put
 9from litestar.contrib.sqlalchemy.plugins import (
10    SQLAlchemyAsyncConfig,
11    SQLAlchemyInitPlugin,
12    SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
class Base(DeclarativeBase):
    """Declarative base class for the ORM models in this example."""
20
21
class TodoItem(Base):
    """A single TODO entry, persisted in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title is the primary key, so each title can exist only once.
    title: Mapped[str] = mapped_column(primary_key=True)
    # An explicit, argument-less mapped_column(); the column type and nullability
    # are still derived from the ``Mapped[bool]`` annotation.
    done: Mapped[bool] = mapped_column()
27
28
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Yield the request's ``db_session`` inside an explicit transaction.

    ``db_session`` is the per-request session provided by ``SQLAlchemyInitPlugin``.
    The transaction opened by ``db_session.begin()`` is committed when the
    ``async with`` block exits normally. An ``IntegrityError`` raised while the
    transaction is open or committing (for example, a duplicate title) is
    converted into a ``ClientException`` with status 409 Conflict.
    """
    try:
        transaction = db_session.begin()
        async with transaction:
            yield db_session
    except IntegrityError as integrity_error:
        # NOTE(review): str(integrity_error) may include the failing SQL and its
        # parameters; consider a generic message outside of tutorials.
        raise ClientException(
            status_code=HTTP_409_CONFLICT,
            detail=str(integrity_error),
        ) from integrity_error
38
39
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (the primary key) of the item to look up.
        session: Session used to execute the query.

    Returns:
        The matching ``TodoItem``.

    Raises:
        NotFoundException: If no item has that title.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
47
48
async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
    """Return TODO items, optionally filtered by completion state.

    Args:
        done: ``None`` returns every item; ``True``/``False`` returns only the
            items whose ``done`` flag matches.
        session: Session used to execute the query.

    Returns:
        A list of the matching ``TodoItem`` objects.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is annotated as returning a ``Sequence`` in SQLAlchemy 2.x;
    # build a real list so the value matches the declared ``List[TodoItem]``.
    return list(result.scalars().all())
56
57
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> List[TodoItem]:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    items = await get_todo_list(done=done, session=transaction)
    return items
61
62
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Stage the posted TODO item for insertion and return it.

    Nothing is written here; the INSERT happens when the ``transaction``
    dependency commits.
    """
    transaction.add(data)  # pending until the surrounding transaction commits
    return data
67
68
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Overwrite the title and ``done`` flag of the item titled ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
75
76
# Database configuration handed to SQLAlchemyInitPlugin. Supplying ``metadata``
# together with ``create_all=True`` lets the plugin create the tables for us.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    # Handlers ask for ``transaction`` to receive the session inside a transaction.
    dependencies={"transaction": provide_transaction},
    plugins=[
        # Lets handlers accept and return SQLAlchemy models such as TodoItem directly.
        SQLAlchemySerializationPlugin(),
        # Manages the engine for the app lifespan and provides ``db_session`` per request.
        SQLAlchemyInitPlugin(db_config),
    ],
)
 1from typing import List, Optional
 2from collections.abc import AsyncGenerator
 3
 4from sqlalchemy import select
 5from sqlalchemy.exc import IntegrityError, NoResultFound
 6from sqlalchemy.ext.asyncio import AsyncSession
 7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 8
 9from litestar import Litestar, get, post, put
10from litestar.contrib.sqlalchemy.plugins import (
11    SQLAlchemyAsyncConfig,
12    SQLAlchemyInitPlugin,
13    SQLAlchemySerializationPlugin,
14)
15from litestar.exceptions import ClientException, NotFoundException
16from litestar.status_codes import HTTP_409_CONFLICT
17
18
class Base(DeclarativeBase):
    """Declarative base class for the ORM models in this example."""
21
22
class TodoItem(Base):
    """A single TODO entry, persisted in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title is the primary key, so each title can exist only once.
    title: Mapped[str] = mapped_column(primary_key=True)
    # An explicit, argument-less mapped_column(); the column type and nullability
    # are still derived from the ``Mapped[bool]`` annotation.
    done: Mapped[bool] = mapped_column()
28
29
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Yield the request's ``db_session`` inside an explicit transaction.

    ``db_session`` is the per-request session provided by ``SQLAlchemyInitPlugin``.
    The transaction opened by ``db_session.begin()`` is committed when the
    ``async with`` block exits normally. An ``IntegrityError`` raised while the
    transaction is open or committing (for example, a duplicate title) is
    converted into a ``ClientException`` with status 409 Conflict.
    """
    try:
        transaction = db_session.begin()
        async with transaction:
            yield db_session
    except IntegrityError as integrity_error:
        # NOTE(review): str(integrity_error) may include the failing SQL and its
        # parameters; consider a generic message outside of tutorials.
        raise ClientException(
            status_code=HTTP_409_CONFLICT,
            detail=str(integrity_error),
        ) from integrity_error
39
40
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (the primary key) of the item to look up.
        session: Session used to execute the query.

    Returns:
        The matching ``TodoItem``.

    Raises:
        NotFoundException: If no item has that title.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
48
49
async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
    """Return TODO items, optionally filtered by completion state.

    Args:
        done: ``None`` returns every item; ``True``/``False`` returns only the
            items whose ``done`` flag matches.
        session: Session used to execute the query.

    Returns:
        A list of the matching ``TodoItem`` objects.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is annotated as returning a ``Sequence`` in SQLAlchemy 2.x;
    # build a real list so the value matches the declared ``List[TodoItem]``.
    return list(result.scalars().all())
57
58
@get("/")
async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> List[TodoItem]:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    items = await get_todo_list(done=done, session=transaction)
    return items
62
63
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Stage the posted TODO item for insertion and return it.

    Nothing is written here; the INSERT happens when the ``transaction``
    dependency commits.
    """
    transaction.add(data)  # pending until the surrounding transaction commits
    return data
68
69
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Overwrite the title and ``done`` flag of the item titled ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
76
77
# Database configuration handed to SQLAlchemyInitPlugin. Supplying ``metadata``
# together with ``create_all=True`` lets the plugin create the tables for us.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    # Handlers ask for ``transaction`` to receive the session inside a transaction.
    dependencies={"transaction": provide_transaction},
    plugins=[
        # Lets handlers accept and return SQLAlchemy models such as TodoItem directly.
        SQLAlchemySerializationPlugin(),
        # Manages the engine for the app lifespan and provides ``db_session`` per request.
        SQLAlchemyInitPlugin(db_config),
    ],
)
 1from collections.abc import AsyncGenerator
 2
 3from sqlalchemy import select
 4from sqlalchemy.exc import IntegrityError, NoResultFound
 5from sqlalchemy.ext.asyncio import AsyncSession
 6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 7
 8from litestar import Litestar, get, post, put
 9from litestar.contrib.sqlalchemy.plugins import (
10    SQLAlchemyAsyncConfig,
11    SQLAlchemyInitPlugin,
12    SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
class Base(DeclarativeBase):
    """Declarative base class for the ORM models in this example."""
20
21
class TodoItem(Base):
    """A single TODO entry, persisted in the ``todo_items`` table."""

    __tablename__ = "todo_items"

    # The title is the primary key, so each title can exist only once.
    title: Mapped[str] = mapped_column(primary_key=True)
    # An explicit, argument-less mapped_column(); the column type and nullability
    # are still derived from the ``Mapped[bool]`` annotation.
    done: Mapped[bool] = mapped_column()
27
28
async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
    """Yield the request's ``db_session`` inside an explicit transaction.

    ``db_session`` is the per-request session provided by ``SQLAlchemyInitPlugin``.
    The transaction opened by ``db_session.begin()`` is committed when the
    ``async with`` block exits normally. An ``IntegrityError`` raised while the
    transaction is open or committing (for example, a duplicate title) is
    converted into a ``ClientException`` with status 409 Conflict.
    """
    try:
        transaction = db_session.begin()
        async with transaction:
            yield db_session
    except IntegrityError as integrity_error:
        # NOTE(review): str(integrity_error) may include the failing SQL and its
        # parameters; consider a generic message outside of tutorials.
        raise ClientException(
            status_code=HTTP_409_CONFLICT,
            detail=str(integrity_error),
        ) from integrity_error
38
39
async def get_todo_by_title(todo_name: str, session: AsyncSession) -> TodoItem:
    """Fetch the TODO item whose title equals ``todo_name``.

    Args:
        todo_name: Title (the primary key) of the item to look up.
        session: Session used to execute the query.

    Returns:
        The matching ``TodoItem``.

    Raises:
        NotFoundException: If no item has that title.
    """
    query = select(TodoItem).where(TodoItem.title == todo_name)
    result = await session.execute(query)
    try:
        return result.scalar_one()
    except NoResultFound as e:
        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
47
48
async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
    """Return TODO items, optionally filtered by completion state.

    Args:
        done: ``None`` returns every item; ``True``/``False`` returns only the
            items whose ``done`` flag matches.
        session: Session used to execute the query.

    Returns:
        A list of the matching ``TodoItem`` objects.
    """
    query = select(TodoItem)
    if done is not None:
        query = query.where(TodoItem.done.is_(done))

    result = await session.execute(query)
    # ``ScalarResult.all()`` is annotated as returning a ``Sequence`` in SQLAlchemy 2.x;
    # build a real list so the value matches the declared ``list[TodoItem]``.
    return list(result.scalars().all())
56
57
@get("/")
async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
    """List TODO items, optionally filtered by the ``done`` query parameter."""
    items = await get_todo_list(done=done, session=transaction)
    return items
61
62
@post("/")
async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Stage the posted TODO item for insertion and return it.

    Nothing is written here; the INSERT happens when the ``transaction``
    dependency commits.
    """
    transaction.add(data)  # pending until the surrounding transaction commits
    return data
67
68
@put("/{item_title:str}")
async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
    """Overwrite the title and ``done`` flag of the item titled ``item_title``.

    ``get_todo_by_title`` raises ``NotFoundException`` when no such item exists.
    """
    existing = await get_todo_by_title(item_title, transaction)
    existing.title, existing.done = data.title, data.done
    return existing
75
76
# Database configuration handed to SQLAlchemyInitPlugin. Supplying ``metadata``
# together with ``create_all=True`` lets the plugin create the tables for us.
db_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///todo.sqlite",
    metadata=Base.metadata,
    create_all=True,
)

app = Litestar(
    route_handlers=[get_list, add_item, update_item],
    # Handlers ask for ``transaction`` to receive the session inside a transaction.
    dependencies={"transaction": provide_transaction},
    plugins=[
        # Lets handlers accept and return SQLAlchemy models such as TodoItem directly.
        SQLAlchemySerializationPlugin(),
        # Manages the engine for the app lifespan and provides ``db_session`` per request.
        SQLAlchemyInitPlugin(db_config),
    ],
)

The most notable difference is that we no longer need the db_connection() lifespan context manager — the plugin handles this for us. It also creates the tables in our database if we supply our metadata and set create_all=True when creating the SQLAlchemyAsyncConfig instance.

Additionally, we have a new db_session dependency available to us, which we use in our provide_transaction() dependency provider, instead of creating our own session.

Next steps

Next up, we’ll make one final change to our application, and then we’ll recap!