diff --git a/app/resources/db.py b/app/resources/db.py index c2723d5..99944bc 100644 --- a/app/resources/db.py +++ b/app/resources/db.py @@ -25,16 +25,16 @@ async def close_sa_engine(engine: sa.AsyncEngine) -> None: await engine.dispose() -class CustomAsyncSession(sa.AsyncSession): - async def close(self) -> None: - if isinstance(self.bind, sa.AsyncConnection): - return self.expunge_all() - - return await super().close() - - def create_session(engine: sa.AsyncEngine) -> sa.AsyncSession: - return CustomAsyncSession(engine, expire_on_commit=False, autoflush=False) + # join_transaction_mode is inert in production (the session binds to an engine); when tests bind + # the session to a connection already in a transaction, it makes the session own a savepoint so + # the outer transaction survives commits and the per-test rollback stays clean. + return sa.AsyncSession( + engine, + expire_on_commit=False, + autoflush=False, + join_transaction_mode="create_savepoint", + ) async def close_session(session: sa.AsyncSession) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index f27ba25..1845ba7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,11 +47,15 @@ async def db_session(di_container: modern_di.Container) -> typing.AsyncIterator[ engine = create_sa_engine() connection = await engine.connect() transaction = await connection.begin() - await connection.begin_nested() di_container.override(ioc.Dependencies.database_engine, connection) try: - yield AsyncSession(connection, expire_on_commit=False, autoflush=False) + yield AsyncSession( + connection, + expire_on_commit=False, + autoflush=False, + join_transaction_mode="create_savepoint", + ) finally: if connection.in_transaction(): await transaction.rollback()