Using the init plugin

In our example application, we’ve seen that we need to manage the database engine within the scope of the application’s lifespan, and the session within the scope of a request. This is a common pattern, and the SQLAlchemyInitPlugin provides assistance for this.

In our latest update, we leverage two features of the plugin:

  1. The plugin will automatically create a database engine for us and manage it within the scope of the application’s lifespan.

  2. The plugin will automatically create a database session for us and manage it within the scope of a request.

We access the database session via dependency injection, using the db_session parameter.

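Here’s the updated code. The same example is shown three times below, differing only in type-hint syntax: first with typing.List and typing.Optional, then with the built-in list[...] generic and collections.abc.AsyncGenerator, and finally with the bool | None union syntax. Use the variant that matches your Python version: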

 1from typing import AsyncGenerator, List, Optional
 2
 3from sqlalchemy import select
 4from sqlalchemy.exc import IntegrityError, NoResultFound
 5from sqlalchemy.ext.asyncio import AsyncSession
 6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 7
 8from litestar import Litestar, get, post, put
 9from litestar.contrib.sqlalchemy.plugins import (
10    SQLAlchemyAsyncConfig,
11    SQLAlchemyInitPlugin,
12    SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
18class Base(DeclarativeBase): ...
19
20
21class TodoItem(Base):
22    __tablename__ = "todo_items"
23
24    title: Mapped[str] = mapped_column(primary_key=True)
25    done: Mapped[bool]
26
27
28async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
29    try:
30        async with db_session.begin():
31            yield db_session
32    except IntegrityError as exc:
33        raise ClientException(
34            status_code=HTTP_409_CONFLICT,
35            detail=str(exc),
36        ) from exc
37
38
39async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
40    query = select(TodoItem).where(TodoItem.title == todo_name)
41    result = await session.execute(query)
42    try:
43        return result.scalar_one()
44    except NoResultFound as e:
45        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
46
47
48async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
49    query = select(TodoItem)
50    if done is not None:
51        query = query.where(TodoItem.done.is_(done))
52
53    result = await session.execute(query)
54    return result.scalars().all()
55
56
57@get("/")
58async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> List[TodoItem]:
59    return await get_todo_list(done, transaction)
60
61
62@post("/")
63async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
64    transaction.add(data)
65    return data
66
67
68@put("/{item_title:str}")
69async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
70    todo_item = await get_todo_by_title(item_title, transaction)
71    todo_item.title = data.title
72    todo_item.done = data.done
73    return todo_item
74
75
76db_config = SQLAlchemyAsyncConfig(
77    connection_string="sqlite+aiosqlite:///todo.sqlite", metadata=Base.metadata, create_all=True
78)
79
80app = Litestar(
81    [get_list, add_item, update_item],
82    dependencies={"transaction": provide_transaction},
83    plugins=[
84        SQLAlchemySerializationPlugin(),
85        SQLAlchemyInitPlugin(db_config),
86    ],
87)
 1from typing import Optional
 2from collections.abc import AsyncGenerator
 3
 4from sqlalchemy import select
 5from sqlalchemy.exc import IntegrityError, NoResultFound
 6from sqlalchemy.ext.asyncio import AsyncSession
 7from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 8
 9from litestar import Litestar, get, post, put
10from litestar.contrib.sqlalchemy.plugins import (
11    SQLAlchemyAsyncConfig,
12    SQLAlchemyInitPlugin,
13    SQLAlchemySerializationPlugin,
14)
15from litestar.exceptions import ClientException, NotFoundException
16from litestar.status_codes import HTTP_409_CONFLICT
17
18
19class Base(DeclarativeBase): ...
20
21
22class TodoItem(Base):
23    __tablename__ = "todo_items"
24
25    title: Mapped[str] = mapped_column(primary_key=True)
26    done: Mapped[bool]
27
28
29async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
30    try:
31        async with db_session.begin():
32            yield db_session
33    except IntegrityError as exc:
34        raise ClientException(
35            status_code=HTTP_409_CONFLICT,
36            detail=str(exc),
37        ) from exc
38
39
40async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
41    query = select(TodoItem).where(TodoItem.title == todo_name)
42    result = await session.execute(query)
43    try:
44        return result.scalar_one()
45    except NoResultFound as e:
46        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
47
48
49async def get_todo_list(done: Optional[bool], session: AsyncSession) -> list[TodoItem]:
50    query = select(TodoItem)
51    if done is not None:
52        query = query.where(TodoItem.done.is_(done))
53
54    result = await session.execute(query)
55    return result.scalars().all()
56
57
58@get("/")
59async def get_list(transaction: AsyncSession, done: Optional[bool] = None) -> list[TodoItem]:
60    return await get_todo_list(done, transaction)
61
62
63@post("/")
64async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
65    transaction.add(data)
66    return data
67
68
69@put("/{item_title:str}")
70async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
71    todo_item = await get_todo_by_title(item_title, transaction)
72    todo_item.title = data.title
73    todo_item.done = data.done
74    return todo_item
75
76
77db_config = SQLAlchemyAsyncConfig(
78    connection_string="sqlite+aiosqlite:///todo.sqlite", metadata=Base.metadata, create_all=True
79)
80
81app = Litestar(
82    [get_list, add_item, update_item],
83    dependencies={"transaction": provide_transaction},
84    plugins=[
85        SQLAlchemySerializationPlugin(),
86        SQLAlchemyInitPlugin(db_config),
87    ],
88)
 1from collections.abc import AsyncGenerator
 2
 3from sqlalchemy import select
 4from sqlalchemy.exc import IntegrityError, NoResultFound
 5from sqlalchemy.ext.asyncio import AsyncSession
 6from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 7
 8from litestar import Litestar, get, post, put
 9from litestar.contrib.sqlalchemy.plugins import (
10    SQLAlchemyAsyncConfig,
11    SQLAlchemyInitPlugin,
12    SQLAlchemySerializationPlugin,
13)
14from litestar.exceptions import ClientException, NotFoundException
15from litestar.status_codes import HTTP_409_CONFLICT
16
17
18class Base(DeclarativeBase): ...
19
20
21class TodoItem(Base):
22    __tablename__ = "todo_items"
23
24    title: Mapped[str] = mapped_column(primary_key=True)
25    done: Mapped[bool]
26
27
28async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncSession, None]:
29    try:
30        async with db_session.begin():
31            yield db_session
32    except IntegrityError as exc:
33        raise ClientException(
34            status_code=HTTP_409_CONFLICT,
35            detail=str(exc),
36        ) from exc
37
38
39async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
40    query = select(TodoItem).where(TodoItem.title == todo_name)
41    result = await session.execute(query)
42    try:
43        return result.scalar_one()
44    except NoResultFound as e:
45        raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e
46
47
48async def get_todo_list(done: bool | None, session: AsyncSession) -> list[TodoItem]:
49    query = select(TodoItem)
50    if done is not None:
51        query = query.where(TodoItem.done.is_(done))
52
53    result = await session.execute(query)
54    return result.scalars().all()
55
56
57@get("/")
58async def get_list(transaction: AsyncSession, done: bool | None = None) -> list[TodoItem]:
59    return await get_todo_list(done, transaction)
60
61
62@post("/")
63async def add_item(data: TodoItem, transaction: AsyncSession) -> TodoItem:
64    transaction.add(data)
65    return data
66
67
68@put("/{item_title:str}")
69async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession) -> TodoItem:
70    todo_item = await get_todo_by_title(item_title, transaction)
71    todo_item.title = data.title
72    todo_item.done = data.done
73    return todo_item
74
75
76db_config = SQLAlchemyAsyncConfig(
77    connection_string="sqlite+aiosqlite:///todo.sqlite", metadata=Base.metadata, create_all=True
78)
79
80app = Litestar(
81    [get_list, add_item, update_item],
82    dependencies={"transaction": provide_transaction},
83    plugins=[
84        SQLAlchemySerializationPlugin(),
85        SQLAlchemyInitPlugin(db_config),
86    ],
87)

The most notable difference is that we no longer need the db_connection() lifespan context manager - the plugin handles this for us. It also handles the creation of the tables in our database if we supply our metadata and set create_all=True when creating a SQLAlchemyAsyncConfig instance.

Additionally, we have a new db_session dependency available to us, which we use in our provide_transaction() dependency provider, instead of creating our own session.

Next steps

Next up, we’ll make one final change to our application, and then we’ll recap!