From 747343871e4d600cfb74460e0946db0b0c2d5a31 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Fri, 26 Jun 2026 12:25:20 +0300 Subject: [PATCH] refactor: extract test-transaction seam from production session CustomAsyncSession existed only so the per-test rollback worked: its close() became a no-op when bound to a connection. That is a test concern living in the production data-access module. Replace the hack with SQLAlchemy 2.0's join_transaction_mode="create_savepoint". Each session owns its savepoint, so auto_commit releases a savepoint while the fixture's outer transaction survives and is rolled back per test. The kwarg is inert in production (the session binds to an engine, never an in-transaction connection). Production session is now a plain AsyncSession with no test branch; the conftest fixture drops begin_nested() and owns the savepoint lifecycle via the same mode. Co-Authored-By: Claude Opus 4.8 (1M context) --- app/resources/db.py | 18 +++++++++--------- tests/conftest.py | 8 ++++++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/app/resources/db.py b/app/resources/db.py index c2723d5..99944bc 100644 --- a/app/resources/db.py +++ b/app/resources/db.py @@ -25,16 +25,16 @@ async def close_sa_engine(engine: sa.AsyncEngine) -> None: await engine.dispose() -class CustomAsyncSession(sa.AsyncSession): - async def close(self) -> None: - if isinstance(self.bind, sa.AsyncConnection): - return self.expunge_all() - - return await super().close() - - def create_session(engine: sa.AsyncEngine) -> sa.AsyncSession: - return CustomAsyncSession(engine, expire_on_commit=False, autoflush=False) + # join_transaction_mode is inert in production (the session binds to an engine); when tests bind + # the session to a connection already in a transaction, it makes the session own a savepoint so + # the outer transaction survives commits and the per-test rollback stays clean. + return sa.AsyncSession( + engine, + expire_on_commit=False, + autoflush=False, + join_transaction_mode="create_savepoint", + ) async def close_session(session: sa.AsyncSession) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index f27ba25..1845ba7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,11 +47,15 @@ async def db_session(di_container: modern_di.Container) -> typing.AsyncIterator[ engine = create_sa_engine() connection = await engine.connect() transaction = await connection.begin() - await connection.begin_nested() di_container.override(ioc.Dependencies.database_engine, connection) try: - yield AsyncSession(connection, expire_on_commit=False, autoflush=False) + yield AsyncSession( + connection, + expire_on_commit=False, + autoflush=False, + join_transaction_mode="create_savepoint", + ) finally: if connection.in_transaction(): await transaction.rollback()