diff --git a/packages/gooddata-eval/src/gooddata_eval/core/chat/sse_client.py b/packages/gooddata-eval/src/gooddata_eval/core/chat/sse_client.py
index 1d0ca6292..9d0056c4c 100644
--- a/packages/gooddata-eval/src/gooddata_eval/core/chat/sse_client.py
+++ b/packages/gooddata-eval/src/gooddata_eval/core/chat/sse_client.py
@@ -14,15 +14,120 @@
 """
 
 import json
+import logging
+import os
+import time
 from dataclasses import dataclass, field
-from typing import Any, Iterable
+from typing import Any, Callable, Iterable, TypeVar
 
 import httpx
 
 from gooddata_eval.core.models import ChatResult, DatasetItem
 
+_log = logging.getLogger(__name__)
+
 SSE_DATA_PREFIX = "data: "
 
+_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 502, 503, 504})
+_METADATA_SYNC_MARKER = "METADATA_SYNC_IN_PROGRESS"
+
+
+class ChatError(RuntimeError):
+    """Non-retryable error reported by the chat SSE stream.
+
+    ``status_code`` and ``detail`` hold the values passed by the raiser (``None`` when not given).
+    """
+
+    def __init__(self, message: str, *, status_code: int | None = None, detail: str | None = None) -> None:
+        super().__init__(message)
+        self.status_code = status_code
+        self.detail = detail
+
+
+class TransientChatError(ChatError):
+    """Retryable transient error: gen-ai temporarily unavailable or still syncing metadata."""
+
+
+def _int_env(name: str, default: int) -> int:
+    """Read an int from the environment, falling back to ``default`` when unset or blank.
+
+    Raises:
+        ValueError: the variable is set to a non-integer; the message names the variable.
+    """
+    raw = (os.getenv(name) or "").strip()  # whitespace-only counts as blank
+    if not raw:
+        return default
+    try:
+        return int(raw)
+    except ValueError as err:
+        raise ValueError(f"{name} must be an integer, got {raw!r}") from err
+
+
+def _float_env(name: str, default: float) -> float:
+    """Read a float from the environment, falling back to ``default`` when unset or blank.
+
+    Raises:
+        ValueError: the variable is set to a non-number; the message names the variable.
+    """
+    raw = (os.getenv(name) or "").strip()  # whitespace-only counts as blank
+    if not raw:
+        return default
+    try:
+        return float(raw)
+    except ValueError as err:
+        raise ValueError(f"{name} must be a number, got {raw!r}") from err
+
+
+# Retry budget. Defaults give a ~2 min worst-case cap per send (5/10/20/40/60s);
+# overridable via env so CI can retune without cutting a new gooddata-eval release.
+_MAX_RETRIES = _int_env("GOODDATA_EVAL_CHAT_MAX_RETRIES", 5)
+_INITIAL_BACKOFF_S = _float_env("GOODDATA_EVAL_CHAT_INITIAL_BACKOFF_S", 5.0)
+_BACKOFF_FACTOR = _float_env("GOODDATA_EVAL_CHAT_BACKOFF_FACTOR", 2.0)
+_MAX_BACKOFF_S = _float_env("GOODDATA_EVAL_CHAT_MAX_BACKOFF_S", 60.0)
+
+T = TypeVar("T")
+
+
+def _is_retryable_exc(exc: Exception) -> bool:
+    """Return True for ``TransientChatError`` and for ``httpx.HTTPStatusError`` with a retryable status."""
+    if isinstance(exc, TransientChatError):
+        return True
+    if isinstance(exc, httpx.HTTPStatusError):
+        return exc.response.status_code in _RETRYABLE_STATUS_CODES
+    return False
+
+
+def _retry_transient(operation: Callable[[], T], *, is_retryable: Callable[[Exception], bool]) -> T:
+    """Run ``operation``; retry retryable failures with bounded exponential backoff.
+
+    Makes one initial attempt plus up to ``_MAX_RETRIES`` retries. The wait starts at
+    ``_INITIAL_BACKOFF_S``, is multiplied by ``_BACKOFF_FACTOR`` after each failure and is
+    capped at ``_MAX_BACKOFF_S``. A non-retryable exception, or a failure of the final
+    attempt, propagates unchanged.
+    """
+    # Clamp misconfigured (negative) overrides so ``operation`` always runs at least once
+    # and ``time.sleep`` never receives a negative value.
+    max_retries = max(_MAX_RETRIES, 0)
+    delay = _INITIAL_BACKOFF_S
+    for attempt in range(1, max_retries + 1):
+        try:
+            return operation()
+        except Exception as exc:  # noqa: PERF203 — retry loop: per-attempt try/except is intentional
+            if not is_retryable(exc):
+                raise
+            sleep_s = max(0.0, min(delay, _MAX_BACKOFF_S))
+            _log.warning(
+                "Transient gen-ai error (attempt %d/%d): %s; retrying in %.1fs",
+                attempt,
+                max_retries + 1,
+                exc,
+                sleep_s,
+            )
+            time.sleep(sleep_s)
+            delay *= _BACKOFF_FACTOR
+    # Final attempt: retries are exhausted, so any failure propagates to the caller.
+    return operation()
+
 
 @dataclass
 class _SseAccumulator:
@@ -114,12 +219,27 @@ def parse_sse_lines(lines: Iterable[str]) -> ChatResult:
         if not line or line.startswith("event: ") or not line.startswith(SSE_DATA_PREFIX):
             continue
         data_str = line[len(SSE_DATA_PREFIX) :]
+        # Checked on the raw payload, before JSON parsing, so the marker is still treated as
+        # transient when the payload is not valid JSON (which would otherwise be skipped).
+        if _METADATA_SYNC_MARKER in data_str:
+            raise TransientChatError(
+                f"SSE transient error: {_METADATA_SYNC_MARKER}",
+                status_code=None,
+                detail=None,
+            )
         try:
             event_data = json.loads(data_str)
         except json.JSONDecodeError:
             continue
         if "statusCode" in event_data:
-            raise RuntimeError(f"SSE error {event_data.get('statusCode')}: {event_data.get('detail')}")
+            code = event_data.get("statusCode")
+            detail = event_data.get("detail")
+            message = f"SSE error {code}: {detail}"
+            # isinstance guard: an unhashable statusCode (list/object) must surface as a
+            # ChatError rather than a TypeError from the frozenset membership test.
+            if isinstance(code, int) and code in _RETRYABLE_STATUS_CODES:
+                raise TransientChatError(message, status_code=code, detail=detail)
+            raise ChatError(message, status_code=code, detail=detail)
         item = event_data.get("item")
         if not item:
             continue
@@ -149,12 +269,17 @@ def __init__(self, host: str, token: str, workspace_id: str, *, timeout: float =
         self._client = httpx.Client(timeout=timeout)
 
     def create_conversation(self) -> str:
-        resp = self._client.post(self._base, headers={**self._auth, "Content-Type": "application/json"})
-        resp.raise_for_status()
-        body = resp.json()
-        if "conversationId" not in body:
-            raise ValueError(f"GoodData /chat/conversations response missing 'conversationId': {body}")
-        return body["conversationId"]
+        def _do() -> str:
+            resp = self._client.post(self._base, headers={**self._auth, "Content-Type": "application/json"})
+            resp.raise_for_status()
+            body = resp.json()
+            if "conversationId" not in body:
+                raise ValueError(f"GoodData /chat/conversations response missing 'conversationId': {body}")
+            return body["conversationId"]
+
+        # NOTE: retrying create is not idempotent — a created-then-503 can leak an
+        # orphaned (ephemeral) conversation. Acceptable for eval; do not reuse blindly.
+        return _retry_transient(_do, is_retryable=_is_retryable_exc)
 
     def delete_conversation(self, conversation_id: str) -> None:
         try:
@@ -166,9 +291,15 @@ def send_message(self, conversation_id: str, question: str) -> ChatResult:
         url = f"{self._base}/{conversation_id}/messages"
         headers = {**self._auth, "Accept": "text/event-stream", "Content-Type": "application/json"}
         body = {"item": {"role": "user", "content": {"type": "text", "text": question}}}
-        with self._client.stream("POST", url, json=body, headers=headers) as resp:
-            resp.raise_for_status()
-            return parse_sse_lines(resp.iter_lines())
+
+        def _do() -> ChatResult:
+            with self._client.stream("POST", url, json=body, headers=headers) as resp:
+                resp.raise_for_status()
+                return parse_sse_lines(resp.iter_lines())
+
+        # NOTE(review): each retry re-POSTs the same user message to the same conversation;
+        # presumably fine for single-turn eval — confirm the backend drops the failed turn.
+        return _retry_transient(_do, is_retryable=_is_retryable_exc)
 
     def ask(self, item: DatasetItem) -> ChatResult:
         """Run one single-turn conversation: create, send, parse, clean up."""
diff --git a/packages/gooddata-eval/tests/test_sse_client.py b/packages/gooddata-eval/tests/test_sse_client.py
index 87a7d3a84..84f0bfa5b 100644
--- a/packages/gooddata-eval/tests/test_sse_client.py
+++ b/packages/gooddata-eval/tests/test_sse_client.py
@@ -1,8 +1,10 @@
 # (C) 2026 GoodData Corporation
 import json
 
+import httpx
 import pytest
-from gooddata_eval.core.chat.sse_client import parse_sse_lines
+from gooddata_eval.core.chat import sse_client as sse_mod
+from gooddata_eval.core.chat.sse_client import ChatClient, ChatError, TransientChatError, parse_sse_lines
 
 
 def test_parse_sse_lines_collects_text_and_visualization(fixtures_dir):
@@ -77,3 +79,186 @@ def test_parse_sse_lines_prefers_multipart_viz_over_adhoc_fallback():
     ]
     result = parse_sse_lines(lines)
     assert result.created_visualizations.objects[0].id == "real"
+
+
+@pytest.fixture(autouse=True)
+def _default_retry_budget(monkeypatch):
+    # The budget is read from the environment at import time; pin the defaults so the
+    # backoff assertions below hold even when CI retunes it via GOODDATA_EVAL_CHAT_*.
+    monkeypatch.setattr(sse_mod, "_MAX_RETRIES", 5)
+    monkeypatch.setattr(sse_mod, "_INITIAL_BACKOFF_S", 5.0)
+    monkeypatch.setattr(sse_mod, "_BACKOFF_FACTOR", 2.0)
+    monkeypatch.setattr(sse_mod, "_MAX_BACKOFF_S", 60.0)
+
+
+@pytest.mark.parametrize("code", [429, 502, 503, 504])
+def test_parse_sse_lines_transient_status_codes(code):
+    with pytest.raises(TransientChatError) as ei:
+        parse_sse_lines([f'data: {{"statusCode": {code}, "detail": null}}'])
+    assert ei.value.status_code == code
+
+
+def test_parse_sse_lines_metadata_sync_is_transient():
+    with pytest.raises(TransientChatError):
+        parse_sse_lines(['data: {"reasonCode": "METADATA_SYNC_IN_PROGRESS"}'])
+
+
+def test_parse_sse_lines_metadata_sync_marker_in_malformed_json_is_transient():
+    # marker present but the data payload is not valid JSON -> still transient, not swallowed
+    with pytest.raises(TransientChatError):
+        parse_sse_lines(["data: {bad json METADATA_SYNC_IN_PROGRESS"])
+
+
+def test_parse_sse_lines_non_retryable_status_is_chat_error_not_transient():
+    with pytest.raises(ChatError) as ei:
+        parse_sse_lines(['data: {"statusCode": 400, "detail": "bad"}'])
+    assert not isinstance(ei.value, TransientChatError)
+    assert ei.value.status_code == 400
+
+
+def test_parse_sse_lines_unhashable_status_code_is_chat_error():
+    with pytest.raises(ChatError) as ei:
+        parse_sse_lines(['data: {"statusCode": [503], "detail": "bad"}'])
+    assert not isinstance(ei.value, TransientChatError)
+
+
+def _client_with_handler(handler):
+    client = ChatClient(host="https://example.invalid", token="t", workspace_id="w")
+    # Close the real client created by __init__ before swapping in the mock transport.
+    client._client.close()
+    client._client = httpx.Client(transport=httpx.MockTransport(handler))
+    return client
+
+
+_TRANSIENT_SSE = b'data: {"statusCode": 503, "detail": null}\n'
+_NONRETRY_SSE = b'data: {"statusCode": 400, "detail": "bad"}\n'
+_OK_SSE = b'data: {"item": {"role": "assistant", "content": {"type": "text", "text": "ok"}}}\n'
+
+
+def test_send_message_retries_transient_then_succeeds(monkeypatch):
+    sleeps = []
+    monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
+    calls = {"n": 0}
+
+    def handler(request):
+        calls["n"] += 1
+        return httpx.Response(200, content=_TRANSIENT_SSE if calls["n"] < 3 else _OK_SSE)
+
+    client = _client_with_handler(handler)
+    result = client.send_message("conv", "q")
+    assert result.text_response == "ok"
+    assert calls["n"] == 3
+    assert sleeps == [5, 10]
+
+
+def test_send_message_backoff_schedule_then_raises(monkeypatch):
+    sleeps = []
+    monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
+    calls = {"n": 0}
+
+    def handler(request):
+        calls["n"] += 1
+        return httpx.Response(200, content=_TRANSIENT_SSE)
+
+    client = _client_with_handler(handler)
+    with pytest.raises(TransientChatError):
+        client.send_message("conv", "q")
+    assert calls["n"] == 6  # 1 initial + 5 retries
+    assert sleeps == [5, 10, 20, 40, 60]
+
+
+def test_send_message_does_not_retry_non_transient(monkeypatch):
+    sleeps = []
+    monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
+    calls = {"n": 0}
+
+    def handler(request):
+        calls["n"] += 1
+        return httpx.Response(200, content=_NONRETRY_SSE)
+
+    client = _client_with_handler(handler)
+    with pytest.raises(ChatError) as ei:
+        client.send_message("conv", "q")
+    assert not isinstance(ei.value, TransientChatError)
+    assert calls["n"] == 1
+    assert sleeps == []
+
+
+def test_create_conversation_retries_then_succeeds(monkeypatch):
+    sleeps = []
+    monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
+    calls = {"n": 0}
+
+    def handler(request):
+        calls["n"] += 1
+        if calls["n"] < 3:
+            return httpx.Response(503)
+        return httpx.Response(200, json={"conversationId": "abc"})
+
+    client = _client_with_handler(handler)
+    assert client.create_conversation() == "abc"
+    assert calls["n"] == 3
+    assert sleeps == [5, 10]
+
+
+def test_create_conversation_does_not_retry_4xx(monkeypatch):
+    sleeps = []
+    monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
+    calls = {"n": 0}
+
+    def handler(request):
+        calls["n"] += 1
+        return httpx.Response(400)
+
+    client = _client_with_handler(handler)
+    with pytest.raises(httpx.HTTPStatusError):
+        client.create_conversation()
+    assert calls["n"] == 1
+    assert sleeps == []
+
+
+def test_retry_transient_negative_max_retries_still_attempts_once(monkeypatch):
+    monkeypatch.setattr(sse_mod, "_MAX_RETRIES", -1)
+    assert sse_mod._retry_transient(lambda: "ok", is_retryable=lambda exc: True) == "ok"
+
+
+def test_int_env_uses_default_when_unset(monkeypatch):
+    monkeypatch.delenv("GD_TEST_INT", raising=False)
+    assert sse_mod._int_env("GD_TEST_INT", 5) == 5
+
+
+def test_int_env_uses_default_when_blank(monkeypatch):
+    monkeypatch.setenv("GD_TEST_INT", "")
+    assert sse_mod._int_env("GD_TEST_INT", 5) == 5
+
+
+def test_int_env_uses_default_when_whitespace(monkeypatch):
+    monkeypatch.setenv("GD_TEST_INT", "   ")
+    assert sse_mod._int_env("GD_TEST_INT", 5) == 5
+
+
+def test_int_env_reads_override(monkeypatch):
+    monkeypatch.setenv("GD_TEST_INT", "2")
+    assert sse_mod._int_env("GD_TEST_INT", 5) == 2
+
+
+def test_int_env_invalid_value_names_variable(monkeypatch):
+    monkeypatch.setenv("GD_TEST_INT", "five")
+    with pytest.raises(ValueError, match="GD_TEST_INT"):
+        sse_mod._int_env("GD_TEST_INT", 5)
+
+
+def test_float_env_uses_default_when_unset(monkeypatch):
+    monkeypatch.delenv("GD_TEST_FLOAT", raising=False)
+    assert sse_mod._float_env("GD_TEST_FLOAT", 5.0) == 5.0
+
+
+def test_float_env_reads_override(monkeypatch):
+    monkeypatch.setenv("GD_TEST_FLOAT", "1.5")
+    assert sse_mod._float_env("GD_TEST_FLOAT", 5.0) == 1.5
+
+
+def test_float_env_invalid_value_names_variable(monkeypatch):
+    monkeypatch.setenv("GD_TEST_FLOAT", "fast")
+    with pytest.raises(ValueError, match="GD_TEST_FLOAT"):
+        sse_mod._float_env("GD_TEST_FLOAT", 5.0)