Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 104 additions & 11 deletions packages/gooddata-eval/src/gooddata_eval/core/chat/sse_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,88 @@
"""

import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Iterable
from typing import Any, Callable, Iterable, TypeVar

import httpx

from gooddata_eval.core.models import ChatResult, DatasetItem

_log = logging.getLogger(__name__)

SSE_DATA_PREFIX = "data: "

_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 502, 503, 504})
_METADATA_SYNC_MARKER = "METADATA_SYNC_IN_PROGRESS"


class ChatError(RuntimeError):
"""Non-retryable error reported by the chat SSE stream."""

def __init__(self, message: str, *, status_code: int | None = None, detail: str | None = None) -> None:
super().__init__(message)
self.status_code = status_code
self.detail = detail


class TransientChatError(ChatError):
"""Retryable transient error: gen-ai temporarily unavailable or still syncing metadata."""


def _int_env(name: str, default: int) -> int:
"""Read an int from the environment, falling back to ``default`` when unset or blank."""
raw = os.getenv(name)
return int(raw) if raw else default


def _float_env(name: str, default: float) -> float:
"""Read a float from the environment, falling back to ``default`` when unset or blank."""
raw = os.getenv(name)
return float(raw) if raw else default


# Retry budget. Defaults give a ~2 min worst-case cap per send (5/10/20/40/60s);
# overridable via env so CI can retune without cutting a new gooddata-eval release.
_MAX_RETRIES = _int_env("GOODDATA_EVAL_CHAT_MAX_RETRIES", 5)
_INITIAL_BACKOFF_S = _float_env("GOODDATA_EVAL_CHAT_INITIAL_BACKOFF_S", 5.0)
_BACKOFF_FACTOR = _float_env("GOODDATA_EVAL_CHAT_BACKOFF_FACTOR", 2.0)
_MAX_BACKOFF_S = _float_env("GOODDATA_EVAL_CHAT_MAX_BACKOFF_S", 60.0)

T = TypeVar("T")


def _is_retryable_exc(exc: Exception) -> bool:
if isinstance(exc, TransientChatError):
return True
if isinstance(exc, httpx.HTTPStatusError):
return exc.response.status_code in _RETRYABLE_STATUS_CODES
return False


def _retry_transient(operation: Callable[[], T], *, is_retryable: Callable[[Exception], bool]) -> T:
"""Run ``operation``; retry retryable failures with bounded exponential backoff."""
delay = _INITIAL_BACKOFF_S
for attempt in range(_MAX_RETRIES + 1): # 0..N => N retries + 1 initial attempt
try:
return operation()
except Exception as exc: # noqa: PERF203 — retry loop: per-attempt try/except is intentional
if attempt == _MAX_RETRIES or not is_retryable(exc):
raise
sleep_s = min(delay, _MAX_BACKOFF_S)
_log.warning(
"Transient gen-ai error (attempt %d/%d): %s; retrying in %.0fs",
attempt + 1,
_MAX_RETRIES + 1,
exc,
sleep_s,
)
time.sleep(sleep_s)
delay *= _BACKOFF_FACTOR
raise AssertionError("unreachable") # loop either returns or raises


@dataclass
class _SseAccumulator:
Expand Down Expand Up @@ -114,12 +187,23 @@ def parse_sse_lines(lines: Iterable[str]) -> ChatResult:
if not line or line.startswith("event: ") or not line.startswith(SSE_DATA_PREFIX):
continue
data_str = line[len(SSE_DATA_PREFIX) :]
if _METADATA_SYNC_MARKER in data_str:
raise TransientChatError(
f"SSE transient error: {_METADATA_SYNC_MARKER}",
status_code=None,
detail=None,
)
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
continue
if "statusCode" in event_data:
raise RuntimeError(f"SSE error {event_data.get('statusCode')}: {event_data.get('detail')}")
code = event_data.get("statusCode")
detail = event_data.get("detail")
message = f"SSE error {code}: {detail}"
if code in _RETRYABLE_STATUS_CODES:
raise TransientChatError(message, status_code=code, detail=detail)
raise ChatError(message, status_code=code, detail=detail)
item = event_data.get("item")
if not item:
continue
Expand Down Expand Up @@ -149,12 +233,17 @@ def __init__(self, host: str, token: str, workspace_id: str, *, timeout: float =
self._client = httpx.Client(timeout=timeout)

def create_conversation(self) -> str:
resp = self._client.post(self._base, headers={**self._auth, "Content-Type": "application/json"})
resp.raise_for_status()
body = resp.json()
if "conversationId" not in body:
raise ValueError(f"GoodData /chat/conversations response missing 'conversationId': {body}")
return body["conversationId"]
def _do() -> str:
resp = self._client.post(self._base, headers={**self._auth, "Content-Type": "application/json"})
resp.raise_for_status()
body = resp.json()
if "conversationId" not in body:
raise ValueError(f"GoodData /chat/conversations response missing 'conversationId': {body}")
return body["conversationId"]

# NOTE: retrying create is not idempotent — a created-then-503 can leak an
# orphaned (ephemeral) conversation. Acceptable for eval; do not reuse blindly.
return _retry_transient(_do, is_retryable=_is_retryable_exc)

def delete_conversation(self, conversation_id: str) -> None:
try:
Expand All @@ -166,9 +255,13 @@ def send_message(self, conversation_id: str, question: str) -> ChatResult:
url = f"{self._base}/{conversation_id}/messages"
headers = {**self._auth, "Accept": "text/event-stream", "Content-Type": "application/json"}
body = {"item": {"role": "user", "content": {"type": "text", "text": question}}}
with self._client.stream("POST", url, json=body, headers=headers) as resp:
resp.raise_for_status()
return parse_sse_lines(resp.iter_lines())

def _do() -> ChatResult:
with self._client.stream("POST", url, json=body, headers=headers) as resp:
resp.raise_for_status()
return parse_sse_lines(resp.iter_lines())

return _retry_transient(_do, is_retryable=_is_retryable_exc)

def ask(self, item: DatasetItem) -> ChatResult:
"""Run one single-turn conversation: create, send, parse, clean up."""
Expand Down
147 changes: 146 additions & 1 deletion packages/gooddata-eval/tests/test_sse_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# (C) 2026 GoodData Corporation
import json

import httpx
import pytest
from gooddata_eval.core.chat.sse_client import parse_sse_lines
from gooddata_eval.core.chat import sse_client as sse_mod
from gooddata_eval.core.chat.sse_client import ChatClient, ChatError, TransientChatError, parse_sse_lines


def test_parse_sse_lines_collects_text_and_visualization(fixtures_dir):
Expand Down Expand Up @@ -77,3 +79,146 @@ def test_parse_sse_lines_prefers_multipart_viz_over_adhoc_fallback():
]
result = parse_sse_lines(lines)
assert result.created_visualizations.objects[0].id == "real"


@pytest.mark.parametrize("code", [429, 502, 503, 504])
def test_parse_sse_lines_transient_status_codes(code):
with pytest.raises(TransientChatError) as ei:
parse_sse_lines([f'data: {{"statusCode": {code}, "detail": null}}'])
assert ei.value.status_code == code


def test_parse_sse_lines_metadata_sync_is_transient():
with pytest.raises(TransientChatError):
parse_sse_lines(['data: {"reasonCode": "METADATA_SYNC_IN_PROGRESS"}'])


def test_parse_sse_lines_metadata_sync_marker_in_malformed_json_is_transient():
# marker present but the data payload is not valid JSON -> still transient, not swallowed
with pytest.raises(TransientChatError):
parse_sse_lines(["data: {bad json METADATA_SYNC_IN_PROGRESS"])


def test_parse_sse_lines_non_retryable_status_is_chat_error_not_transient():
with pytest.raises(ChatError) as ei:
parse_sse_lines(['data: {"statusCode": 400, "detail": "bad"}'])
assert not isinstance(ei.value, TransientChatError)
assert ei.value.status_code == 400


def _client_with_handler(handler):
client = ChatClient(host="https://example.invalid", token="t", workspace_id="w")
client._client = httpx.Client(transport=httpx.MockTransport(handler))
return client


_TRANSIENT_SSE = b'data: {"statusCode": 503, "detail": null}\n'
_NONRETRY_SSE = b'data: {"statusCode": 400, "detail": "bad"}\n'
_OK_SSE = b'data: {"item": {"role": "assistant", "content": {"type": "text", "text": "ok"}}}\n'


def test_send_message_retries_transient_then_succeeds(monkeypatch):
sleeps = []
monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
calls = {"n": 0}

def handler(request):
calls["n"] += 1
return httpx.Response(200, content=_TRANSIENT_SSE if calls["n"] < 3 else _OK_SSE)

client = _client_with_handler(handler)
result = client.send_message("conv", "q")
assert result.text_response == "ok"
assert calls["n"] == 3
assert sleeps == [5, 10]


def test_send_message_backoff_schedule_then_raises(monkeypatch):
sleeps = []
monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
calls = {"n": 0}

def handler(request):
calls["n"] += 1
return httpx.Response(200, content=_TRANSIENT_SSE)

client = _client_with_handler(handler)
with pytest.raises(TransientChatError):
client.send_message("conv", "q")
assert calls["n"] == 6 # 1 initial + 5 retries
assert sleeps == [5, 10, 20, 40, 60]


def test_send_message_does_not_retry_non_transient(monkeypatch):
sleeps = []
monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
calls = {"n": 0}

def handler(request):
calls["n"] += 1
return httpx.Response(200, content=_NONRETRY_SSE)

client = _client_with_handler(handler)
with pytest.raises(ChatError) as ei:
client.send_message("conv", "q")
assert not isinstance(ei.value, TransientChatError)
assert calls["n"] == 1
assert sleeps == []


def test_create_conversation_retries_then_succeeds(monkeypatch):
sleeps = []
monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
calls = {"n": 0}

def handler(request):
calls["n"] += 1
if calls["n"] < 3:
return httpx.Response(503)
return httpx.Response(200, json={"conversationId": "abc"})

client = _client_with_handler(handler)
assert client.create_conversation() == "abc"
assert calls["n"] == 3
assert sleeps == [5, 10]


def test_create_conversation_does_not_retry_4xx(monkeypatch):
sleeps = []
monkeypatch.setattr(sse_mod.time, "sleep", lambda s: sleeps.append(s))
calls = {"n": 0}

def handler(request):
calls["n"] += 1
return httpx.Response(400)

client = _client_with_handler(handler)
with pytest.raises(httpx.HTTPStatusError):
client.create_conversation()
assert calls["n"] == 1
assert sleeps == []


def test_int_env_uses_default_when_unset(monkeypatch):
monkeypatch.delenv("GD_TEST_INT", raising=False)
assert sse_mod._int_env("GD_TEST_INT", 5) == 5


def test_int_env_uses_default_when_blank(monkeypatch):
monkeypatch.setenv("GD_TEST_INT", "")
assert sse_mod._int_env("GD_TEST_INT", 5) == 5


def test_int_env_reads_override(monkeypatch):
monkeypatch.setenv("GD_TEST_INT", "2")
assert sse_mod._int_env("GD_TEST_INT", 5) == 2


def test_float_env_uses_default_when_unset(monkeypatch):
monkeypatch.delenv("GD_TEST_FLOAT", raising=False)
assert sse_mod._float_env("GD_TEST_FLOAT", 5.0) == 5.0


def test_float_env_reads_override(monkeypatch):
monkeypatch.setenv("GD_TEST_FLOAT", "1.5")
assert sse_mod._float_env("GD_TEST_FLOAT", 5.0) == 1.5
Loading