Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ output = wavespeed.run(
{"prompt": "Cat"},
timeout=36000.0, # Max wait time in seconds (default: 36000.0)
poll_interval=1.0, # Status check interval (default: 1.0)
enable_sync_mode=False, # Single request mode, no polling (default: False)
enable_sync_mode=False, # Best-effort sync result attempt (default: False)
)
```

### Sync Mode

Use `enable_sync_mode=True` for a single request that waits for the result (no polling).
Use `enable_sync_mode=True` to ask the API to wait for the result in the initial
request. If the server-side sync wait times out, the SDK raises an error with
the task ID/result URL; the task continues processing and can be queried later.

> **Note:** Not all models support sync mode. Check the model documentation for availability.

Expand Down
4 changes: 3 additions & 1 deletion src/wavespeed/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def run(
input: Input parameters for the model.
timeout: Maximum time to wait for completion (None = no timeout).
poll_interval: Interval between status checks in seconds.
enable_sync_mode: If True, use synchronous mode (single request).
enable_sync_mode: If True, use synchronous mode (best-effort single
request). If the server-side sync wait times out, an error is
raised with the task ID so the result can be queried later.
max_retries: Maximum retries for this request (overrides default setting).

Returns:
Expand Down
68 changes: 26 additions & 42 deletions src/wavespeed/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Client:
client = Client(api_key="your-api-key")
output = client.run("wavespeed-ai/z-image/turbo", {"prompt": "Cat"})

# With sync mode (single request, waits for result)
# With sync mode (best-effort single request, waits for result)
output = client.run("wavespeed-ai/z-image/turbo", {"prompt": "Cat"}, enable_sync_mode=True)

# With retry
Expand Down Expand Up @@ -334,6 +334,25 @@ def _is_retryable_error(self, error: Exception) -> bool:

return False

@staticmethod
def _format_sync_mode_error(data: dict[str, Any]) -> str:
"""Build an actionable error for a non-completed sync-mode response."""
request_id = data.get("id") or "unknown"
error = data.get("error") or "Unknown error"
urls = data.get("urls") or {}
result_url = urls.get("get") if isinstance(urls, dict) else None

is_sync_timeout = data.get("code") == 5004 or (
data.get("status") == "processing" and "Sync mode timed out" in error
)
if is_sync_timeout:
message = f"Sync mode timed out (task_id: {request_id}): {error}"
if result_url and result_url not in message:
message += f" Query the result later at: {result_url}"
return message

return f"Prediction failed (task_id: {request_id}): {error}"

def run(
self,
model: str,
Expand All @@ -351,9 +370,9 @@ def run(
input: Input parameters for the model.
timeout: Maximum time to wait for completion (None = no timeout).
poll_interval: Interval between status checks in seconds.
enable_sync_mode: If True, use synchronous mode (single request).
If sync mode fails with a gateway timeout (HTTP 502/504),
the SDK automatically falls back to async mode (submit + poll).
enable_sync_mode: If True, use synchronous mode (best-effort single
request). If the server-side sync wait times out, the SDK raises
an error with the task ID so the result can be queried later.
max_retries: Maximum task-level retries (overrides client setting).

Returns:
Expand All @@ -366,28 +385,19 @@ def run(
"""
task_retries = max_retries if max_retries is not None else self.max_retries
last_error = None
# Track whether we should fall back from sync to async mode.
# This happens when sync mode hits a gateway timeout (502/504) after
# exhausting connection-level retries — the gateway cannot hold the
# connection long enough, but the backend may still be healthy.
use_sync = enable_sync_mode

for attempt in range(task_retries + 1):
try:
request_id, sync_result = self._submit(
model, input, enable_sync_mode=use_sync, timeout=timeout
model, input, enable_sync_mode=enable_sync_mode, timeout=timeout

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Stop resubmitting timed-out sync requests

When enable_sync_mode=True and the initial POST ends with an HTTP 504/502 sync timeout, _submit raises a RuntimeError containing HTTP 5, so _is_retryable_error() treats it as task-retryable and this call resubmits the same sync job on every max_retries attempt. That undermines the new no-fallback/no-resubmit behavior for sync timeouts and can create duplicate predictions for users who configure task retries.

Useful? React with 👍 / 👎.

)

if use_sync:
if enable_sync_mode:
# In sync mode, extract outputs from the result
status = sync_result.get("data", {}).get("status")
if status != "completed":
error = (
sync_result.get("data", {}).get("error") or "Unknown error"
)
request_id = sync_result.get("data", {}).get("id", "unknown")
raise RuntimeError(
f"Prediction failed (task_id: {request_id}): {error}"
self._format_sync_mode_error(sync_result.get("data", {}))
)
data = sync_result.get("data", {})
return {"outputs": data.get("outputs", [])}
Expand All @@ -397,17 +407,6 @@ def run(
except Exception as e:
last_error = e

# Sync-to-async fallback: if sync mode got a gateway timeout
# (502/504) after all connection retries, switch to async mode
# and retry immediately without consuming a task-level retry.
if use_sync and self._is_gateway_timeout(e):
print(
"Sync mode hit gateway timeout, "
"falling back to async mode (submit + poll)..."
)
use_sync = False
continue

is_retryable = self._is_retryable_error(e)

if not is_retryable or attempt >= task_retries:
Expand All @@ -423,21 +422,6 @@ def run(
raise last_error
raise RuntimeError(f"All {task_retries + 1} attempts failed")

@staticmethod
def _is_gateway_timeout(error: Exception) -> bool:
"""Check if an error is a gateway timeout (HTTP 502 or 504).

Args:
error: The exception to check.

Returns:
True if the error indicates a gateway timeout.
"""
if isinstance(error, RuntimeError):
error_str = str(error)
return "HTTP 502" in error_str or "HTTP 504" in error_str
return False

def upload(self, file: str | BinaryIO, *, timeout: float | None = None) -> str:
"""Upload a file to WaveSpeed.

Expand Down
37 changes: 37 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,43 @@ def test_run_timeout(self, mock_post, mock_get, mock_sleep, mock_time):
client.run("wavespeed-ai/z-image/turbo", {"prompt": "test"}, timeout=10)
self.assertIn("timed out", str(ctx.exception))

@patch("wavespeed.api.client.requests.get")
@patch("wavespeed.api.client.requests.post")
def test_run_sync_mode_timeout_raises_without_fallback(self, mock_post, mock_get):
"""Test sync-mode timeout keeps the task queryable without async fallback."""
result_url = "https://api.wavespeed.ai/api/v3/predictions/req-timeout/result"
mock_post_response = MagicMock()
mock_post_response.status_code = 200
mock_post_response.json.return_value = {
"data": {
"id": "req-timeout",
"status": "processing",
"code": 5004,
"error": (
"Sync mode timed out after 90 seconds. The prediction is "
"still processing asynchronously."
),
"urls": {"get": result_url},
}
}
mock_post.return_value = mock_post_response

client = Client(api_key="test-key")
with self.assertRaises(RuntimeError) as ctx:
client.run(
"wavespeed-ai/z-image/turbo",
{"prompt": "test"},
enable_sync_mode=True,
max_retries=1,
)

error = str(ctx.exception)
self.assertIn("Sync mode timed out", error)
self.assertIn("req-timeout", error)
self.assertIn(result_url, error)
mock_post.assert_called_once()
mock_get.assert_not_called()


class TestModuleLevelRun(unittest.TestCase):
"""Tests for the module-level run() function."""
Expand Down
Loading