Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/acp/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def send_notification(self, method: str, params: JsonValue | None = None)
async def _receive_loop(self) -> None:
try:
while True:
line = await asyncio.wait_for(self._reader.readline(), timeout=self._receive_timeout)
line = await self._read_line()
if not line:
break
line = line.strip()
Expand All @@ -172,6 +172,24 @@ async def _receive_loop(self) -> None:
raise RequestError.internal_error({"details": "Agent timeout"}) from None
self._disconnect()

async def _read_line(self) -> bytes:
chunks: list[bytes] = []
try:
while True:
try:
line = await self._wait_for_reader(self._reader.readuntil(b"\n"))
except asyncio.LimitOverrunError as exc:
chunks.append(await self._wait_for_reader(self._reader.readexactly(exc.consumed)))
else:
chunks.append(line)
return b"".join(chunks)
except asyncio.IncompleteReadError as exc:
chunks.append(exc.partial)
return b"".join(chunks)

async def _wait_for_reader(self, awaitable: Awaitable[bytes]) -> bytes:
return await asyncio.wait_for(awaitable, timeout=self._receive_timeout)

async def _process_message(self, message: dict[str, Any]) -> None:
method = message.get("method")
has_id = "id" in message
Expand Down
78 changes: 48 additions & 30 deletions tests/real_user/test_stdio_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,46 +48,62 @@ async def test_spawn_stdio_transport_custom_limit_handles_large_line() -> None:
async def test_run_agent_stdio_buffer_limit() -> None:
"""Test that run_agent with different buffer limits can handle appropriately sized messages."""
with tempfile.TemporaryDirectory() as tmpdir:
# Test 1: Small buffer (1KB) fails with large message (70KB)
# Test 1: Small buffer (1KB) reads a large message (70KB) in chunks
small_agent = os.path.join(tmpdir, "small_agent.py")
with open(small_agent, "w") as f:
f.write("""
import asyncio
from acp.core import run_agent
from acp.interfaces import Agent

class TestAgent(Agent):
async def list_capabilities(self):
return {"capabilities": {}}

asyncio.run(run_agent(TestAgent(), stdio_buffer_limit_bytes=1024))
""")

# Send a 70KB message - should fail with 1KB buffer
large_msg = '{"jsonrpc":"2.0","method":"test","params":{"data":"' + "X" * LARGE_LINE_SIZE + '"}}\n'
f.write(
textwrap.dedent(
"""
import asyncio
from acp.core import run_agent
from acp.interfaces import Agent
from acp.schema import InitializeResponse

class TestAgent(Agent):
async def initialize(self, protocol_version, client_capabilities=None, client_info=None, **kwargs):
return InitializeResponse(protocol_version=protocol_version)

asyncio.run(run_agent(TestAgent(), stdio_buffer_limit_bytes=1024))
"""
).strip()
)

# Send a 70KB message - should be read in chunks despite the 1KB buffer
large_msg = (
'{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":1,"_meta":{"data":"'
+ "X" * LARGE_LINE_SIZE
+ '"}}}\n'
)
result = subprocess.run( # noqa: S603
[sys.executable, small_agent], input=large_msg, capture_output=True, text=True, timeout=2
)

# Should have errors in stderr about the buffer limit
assert "Error" in result.stderr or result.returncode != 0, (
f"Expected error with small buffer, got: {result.stderr}"
)
assert result.returncode == 0
assert "LimitOverrunError" not in result.stderr
assert "Separator is found, but chunk is longer than limit" not in result.stderr
assert "oversized JSON-RPC frame" not in result.stderr
assert '"id":1' in result.stdout
assert '"protocolVersion":1' in result.stdout

# Test 2: Large buffer (200KB) succeeds with large message (70KB)
large_agent = os.path.join(tmpdir, "large_agent.py")
with open(large_agent, "w") as f:
f.write(f"""
import asyncio
from acp.core import run_agent
from acp.interfaces import Agent

class TestAgent(Agent):
async def list_capabilities(self):
return {{"capabilities": {{}}}}

asyncio.run(run_agent(TestAgent(), stdio_buffer_limit_bytes={LARGE_LINE_SIZE * 3}))
""")
f.write(
textwrap.dedent(
f"""
import asyncio
from acp.core import run_agent
from acp.interfaces import Agent
from acp.schema import InitializeResponse

class TestAgent(Agent):
async def initialize(self, protocol_version, client_capabilities=None, client_info=None, **kwargs):
return InitializeResponse(protocol_version=protocol_version)

asyncio.run(run_agent(TestAgent(), stdio_buffer_limit_bytes={LARGE_LINE_SIZE * 3}))
"""
).strip()
)

# Same message, but with a buffer 3x the size - should handle it
result = subprocess.run( # noqa: S603
Expand All @@ -98,3 +114,5 @@ async def list_capabilities(self):
# (it may have other errors from invalid JSON-RPC, but not buffer overrun)
if "LimitOverrunError" in result.stderr or "buffer" in result.stderr.lower():
pytest.fail(f"Large buffer still hit limit error: {result.stderr}")
assert '"id":1' in result.stdout
assert '"protocolVersion":1' in result.stdout
123 changes: 123 additions & 0 deletions tests/test_connection_recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

import asyncio
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from acp.connection import Connection
from acp.exceptions import RequestError


async def _noop_handler(method: str, params: Any, is_notification: bool) -> Any:
return None


def _make_connection(
*,
limit: int = 128,
receive_timeout: float | None = None,
) -> tuple[Connection, asyncio.StreamReader]:
reader = asyncio.StreamReader(limit=limit)
transport = MagicMock()
transport.is_closing.return_value = False
protocol = AsyncMock()
writer = asyncio.StreamWriter(transport, protocol, reader, asyncio.get_running_loop())
conn = Connection(_noop_handler, writer, reader, listening=False, receive_timeout=receive_timeout)
return conn, reader


@pytest.mark.asyncio
async def test_receive_loop_handles_oversized_frame(caplog: pytest.LogCaptureFixture) -> None:
conn, reader = _make_connection(limit=128)
processed: list[str] = []

async def tracking_process(message: dict[str, Any]) -> None:
processed.append(message["method"])

conn._process_message = tracking_process # type: ignore[method-assign]
oversized = {"jsonrpc": "2.0", "method": "too-large", "params": {"data": "X" * 256}}
survivor = {"jsonrpc": "2.0", "method": "survivor"}
reader.feed_data(json.dumps(oversized).encode() + b"\n" + json.dumps(survivor).encode() + b"\n")
reader.feed_eof()

with caplog.at_level("WARNING"):
await conn._receive_loop()
await conn.close()

assert processed == ["too-large", "survivor"]
assert "oversized JSON-RPC frame" not in caplog.text


@pytest.mark.asyncio
async def test_receive_loop_handles_consecutive_oversized_frames() -> None:
conn, reader = _make_connection(limit=128)
processed: list[str] = []

async def tracking_process(message: dict[str, Any]) -> None:
processed.append(message["method"])

conn._process_message = tracking_process # type: ignore[method-assign]
for index in range(2):
oversized = {"jsonrpc": "2.0", "method": f"too-large-{index}", "params": {"data": "Y" * 256}}
reader.feed_data(json.dumps(oversized).encode() + b"\n")
survivor = {"jsonrpc": "2.0", "method": "survivor"}
reader.feed_data(json.dumps(survivor).encode() + b"\n")
reader.feed_eof()

await conn._receive_loop()
await conn.close()

assert processed == ["too-large-0", "too-large-1", "survivor"]


@pytest.mark.asyncio
async def test_receive_loop_handles_eof_during_oversized_frame() -> None:
conn, reader = _make_connection(limit=64)
reader.feed_data(b"X" * 256)
reader.feed_eof()

await conn._receive_loop()
await conn.close()

assert conn._disconnected is True


@pytest.mark.asyncio
async def test_receive_loop_keeps_timeout_semantics() -> None:
conn, _reader = _make_connection(receive_timeout=0.01)

with pytest.raises(RequestError) as exc_info:
await conn._receive_loop()
await conn.close()

exc = exc_info.value
assert isinstance(exc, RequestError)
assert str(exc) == "Internal error"
assert exc.data == {"details": "Agent timeout"}


@pytest.mark.asyncio
async def test_receive_loop_keeps_timeout_semantics_while_reading_oversized_frame() -> None:
conn, reader = _make_connection(limit=64, receive_timeout=0.01)
reader.feed_data(b"X" * 256)

with pytest.raises(RequestError) as exc_info:
await conn._receive_loop()
await conn.close()

exc = exc_info.value
assert isinstance(exc, RequestError)
assert exc.data == {"details": "Agent timeout"}


@pytest.mark.asyncio
async def test_receive_loop_does_not_swallow_unrelated_reader_error() -> None:
conn, reader = _make_connection()
reader.set_exception(ValueError("reader failed"))

with pytest.raises(ValueError, match="reader failed"):
await conn._receive_loop()
await conn.close()
Loading