From d3a5b37cf8e45d2aac39e4fe1daa7d1fcfb42d75 Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Thu, 25 Jun 2026 22:13:36 +0000 Subject: [PATCH 01/11] Add Model Provider Service routing for claude and codex Route claude/codex through a Unity Catalog Model Provider Service (external Anthropic/OpenAI provider) via the Databricks-Model-Provider-Service header instead of pinning a Databricks model. - `ucode claude/codex --provider <catalog.schema.name>` routes per-invocation; verifies the MPS feature is enabled and fails with a clear message if not. - `ucode configure --model-provider` opt-in: lists matching services (anthropic for claude, openai for codex), persists the choice per tool; launches then use it automatically. Default configure path is unchanged. - Provider mode pins no Databricks model (the agent's own canonical names route through the header) and skips the heavy model discovery, fetching only a web-search model. - Friendlier USE CONNECTION permission error naming the service. Co-authored-by: Isaac --- src/ucode/agents/__init__.py | 76 ++++++---- src/ucode/agents/claude.py | 32 +++-- src/ucode/agents/codex.py | 57 ++++++-- src/ucode/cli.py | 267 +++++++++++++++++++++++++++++------ src/ucode/databricks.py | 75 ++++++++++ src/ucode/state.py | 27 ++++ src/ucode/ui.py | 26 ++++ tests/test_agent_claude.py | 25 ++++ tests/test_agent_codex.py | 37 +++++ tests/test_agents_init.py | 22 +++ tests/test_cli.py | 81 +++++++++++ tests/test_databricks.py | 64 +++++++++ tests/test_e2e.py | 134 ++++++++++++++++++ tests/test_state.py | 28 ++++ 14 files changed, 866 insertions(+), 85 deletions(-) diff --git a/src/ucode/agents/__init__.py b/src/ucode/agents/__init__.py index b94e855..414c354 100644 --- a/src/ucode/agents/__init__.py +++ b/src/ucode/agents/__init__.py @@ -20,7 +20,7 @@ from ucode.databricks import ( install_databricks_cli, ) -from ucode.state import load_state, save_state +from ucode.state import get_provider_service, load_state, save_state from ucode.telemetry import 
agent_version from ucode.ui import ( console, @@ -254,16 +254,23 @@ def resolve_launch_model( return state, model -def configure_tool(tool: str, state: dict, model: str | None = None) -> dict: +def configure_tool( + tool: str, state: dict, model: str | None = None, provider: str | None = None +) -> dict: result: dict | tuple[dict, str] if tool == "codex": - result = codex.write_tool_config(state, model) + result = codex.write_tool_config(state, model, provider=provider) + elif tool == "claude": + # A Model Provider Service routes by header and pins no Databricks + # model, so the usual "model required" guard doesn't apply to claude. + if not model and not provider: + raise RuntimeError(f"A {tool} model must be selected before configuration.") + result = claude.write_tool_config(state, model, provider=provider) else: + # provider routing is claude/codex-only; every other tool needs a model. if not model: raise RuntimeError(f"A {tool} model must be selected before configuration.") - if tool == "claude": - result = claude.write_tool_config(state, model) - elif tool == "gemini": + if tool == "gemini": result = gemini.write_tool_config(state, model) elif tool == "copilot": result = copilot.write_tool_config(state, model) @@ -325,24 +332,34 @@ def _availability_failure_detail(tool: str, state: dict) -> str: def configure_single_tool(tool: str, state: dict) -> dict: """Check availability, configure, and persist state for one tool only.""" - with spinner(f"Checking {TOOL_SPECS[tool]['display']} availability..."): - ok = check_gateway_endpoint(state, tool) - if not ok: - detail = _availability_failure_detail(tool, state) - raise RuntimeError( - f"{TOOL_SPECS[tool]['display']} is not available on this workspace.{detail}" - ) - if tool == "codex": - state = configure_tool("codex", state) - else: - state, model = resolve_launch_model(tool, state, None) - state = configure_tool(tool, state, model) + provider = get_provider_service(state, tool) + # A Model Provider Service 
routes through the same gateway and pins no + # Databricks model, so the per-tool model availability check doesn't apply. + if not provider: + with spinner(f"Checking {TOOL_SPECS[tool]['display']} availability..."): + ok = check_gateway_endpoint(state, tool) + if not ok: + detail = _availability_failure_detail(tool, state) + raise RuntimeError( + f"{TOOL_SPECS[tool]['display']} is not available on this workspace.{detail}" + ) + state = _configure_one(tool, state, provider) available_tools = list(set((state.get("available_tools") or []) + [tool])) state["available_tools"] = available_tools save_state(state) return state +def _configure_one(tool: str, state: dict, provider: str | None) -> dict: + """Write one tool's config, routing through ``provider`` when set.""" + if provider: + return configure_tool(tool, state, None, provider=provider) + if tool == "codex": + return configure_tool("codex", state) + state, model = resolve_launch_model(tool, state, None) + return configure_tool(tool, state, model) + + def configure_selected_tools(state: dict, tools: list[str]) -> dict: """Configure the given tools. Caller is responsible for ensuring each tool is available on the workspace. @@ -352,11 +369,7 @@ def configure_selected_tools(state: dict, tools: list[str]) -> dict: run is preserved. 
""" for tool in tools: - if tool == "codex": - state = configure_tool("codex", state) - else: - state, model = resolve_launch_model(tool, state, None) - state = configure_tool(tool, state, model) + state = _configure_one(tool, state, get_provider_service(state, tool)) existing = state.get("available_tools") or [] state["available_tools"] = sorted(set(existing) | set(tools)) @@ -440,6 +453,21 @@ def validate_tool(tool: str) -> tuple[bool, str]: return False, "timed out" +def provider_permission_error(tool: str, state: dict, err: str) -> str: + """Rewrite the opaque gateway connection-permission failure into an + actionable message naming the Model Provider Service the user must be + granted access to. Returns ``err`` unchanged when it doesn't apply. + """ + provider = get_provider_service(state, tool) + if provider and "USE CONNECTION on SCHEMA_CONNECTION" in err: + return ( + f"You don't have EXECUTE permission on the model provider service " + f"'{provider}'. Ask its owner to grant you access, then re-run " + f"`ucode configure`." 
+ ) + return err + + def validate_all_tools(state: dict) -> None: from rich.panel import Panel # local to avoid bumping module-level deps @@ -470,7 +498,7 @@ def validate_all_tools(state: dict) -> None: if ok: print_success(f"{spec['display']} is working") else: - print_err(f"{spec['display']}: {err}") + print_err(f"{spec['display']}: {provider_permission_error(tool, state, err)}") managed = bool(state.get("managed_configs", {}).get(tool)) restore_file(spec["config_path"], spec["backup_path"], managed) # Rollback settings.json for Pi diff --git a/src/ucode/agents/claude.py b/src/ucode/agents/claude.py index c335948..c733178 100644 --- a/src/ucode/agents/claude.py +++ b/src/ucode/agents/claude.py @@ -124,28 +124,36 @@ def _web_search_mcp_entry(workspace: str, search_model: str, profile: str | None def render_overlay( workspace: str, - model: str, + model: str | None, claude_models: dict[str, str] | None = None, disable_web_search: bool = False, profile: str | None = None, use_pat: bool = False, + provider: str | None = None, ) -> tuple[dict, list[list[str]]]: """Return (overlay, managed_key_paths) for Claude settings.json. NOTE: MCP servers are NOT written here. Claude Code reads `mcpServers` from `~/.claude.json`, not `~/.claude/settings.json` — registration goes - through `claude mcp add-json` (see `_register_web_search_mcp`).""" + through `claude mcp add-json` (see `_register_web_search_mcp`). + + When `provider` is set (a `..` Model Provider + Service), the request is routed to that external provider via the + `Databricks-Model-Provider-Service` header and no Databricks model id is + pinned — Claude Code uses its own canonical model names, which the provider + understands.""" base_url = build_tool_base_url("claude", workspace) # ANTHROPIC_CUSTOM_HEADERS is parsed as `key: value` pairs separated by # newlines (Anthropic SDK convention). Setting User-Agent here overrides # the SDK's default UA on outbound requests so the gateway can attribute # traffic to ucode. 
- custom_headers = "\n".join( - [ - "x-databricks-use-coding-agent-mode: true", - f"User-Agent: ucode/{ucode_version()} claude/{agent_version('claude')}", - ] - ) + header_lines = [ + "x-databricks-use-coding-agent-mode: true", + f"User-Agent: ucode/{ucode_version()} claude/{agent_version('claude')}", + ] + if provider: + header_lines.append(f"Databricks-Model-Provider-Service: {provider}") + custom_headers = "\n".join(header_lines) env: dict[str, str] = { "ANTHROPIC_BASE_URL": base_url, "ANTHROPIC_CUSTOM_HEADERS": custom_headers, @@ -159,7 +167,10 @@ def render_overlay( # only one row per model. `ucode claude -- --model X` still overrides for a # single session via Claude Code's own --model flag. _ = model # API stability; no longer pinned via env. - if claude_models: + # With a Model Provider Service, the header routes to the external provider + # and Claude Code's own canonical model names are sent verbatim — pinning a + # Databricks model id here would mislabel the picker and isn't routable. + if claude_models and not provider: # Picker rows show the raw routable id (e.g. "system.ai.claude-opus-4-8[1m]") # so users can see which gateway-routable model is behind each shortcut. 
# We deliberately don't set the `_NAME` companion env vars — the raw id @@ -243,7 +254,7 @@ def _unregister_web_search_mcp() -> None: pass -def write_tool_config(state: dict, model: str) -> dict: +def write_tool_config(state: dict, model: str | None, provider: str | None = None) -> dict: backup_existing_file(CLAUDE_SETTINGS_PATH, CLAUDE_BACKUP_PATH) web_search_model = _resolve_web_search_model(state) overlay, managed_keys = render_overlay( @@ -253,6 +264,7 @@ def write_tool_config(state: dict, model: str) -> dict: disable_web_search=web_search_model is not None, profile=state.get("profile"), use_pat=bool(state.get("use_pat")), + provider=provider, ) tracing_env_vars = tracing_env(state, "claude") stop_hook_command = claude_tracing_stop_hook_command() if tracing_env_vars else None diff --git a/src/ucode/agents/codex.py b/src/ucode/agents/codex.py index 4127afa..c0c8786 100644 --- a/src/ucode/agents/codex.py +++ b/src/ucode/agents/codex.py @@ -101,16 +101,26 @@ def _use_legacy_layout() -> bool: return parsed < MINIMUM_CODEX_VERSION -def _provider_block(workspace: str, databricks_profile: str | None, use_pat: bool = False) -> dict: +def _provider_block( + workspace: str, + databricks_profile: str | None, + use_pat: bool = False, + provider: str | None = None, +) -> dict: auth_command = build_auth_shell_command(workspace, databricks_profile, use_pat=use_pat) base_url = build_tool_base_url("codex", workspace) + http_headers = { + "User-Agent": f"ucode/{ucode_version()} codex/{agent_version('codex')}", + } + # Route to an external Model Provider Service; the gateway selects the + # provider from this header on every request. 
+ if provider: + http_headers["Databricks-Model-Provider-Service"] = provider return { "name": "Databricks AI Gateway", "base_url": base_url, "wire_api": "responses", - "http_headers": { - "User-Agent": f"ucode/{ucode_version()} codex/{agent_version('codex')}", - }, + "http_headers": http_headers, "auth": { "command": "sh", "args": ["-c", auth_command], @@ -125,12 +135,15 @@ def render_overlay( model: str | None = None, databricks_profile: str | None = None, use_pat: bool = False, + provider: str | None = None, ) -> dict: overlay: dict = {"model_provider": CODEX_MODEL_PROVIDER_NAME} if model: overlay["model"] = model overlay["model_providers"] = { - CODEX_MODEL_PROVIDER_NAME: _provider_block(workspace, databricks_profile, use_pat), + CODEX_MODEL_PROVIDER_NAME: _provider_block( + workspace, databricks_profile, use_pat, provider + ), } return overlay @@ -140,6 +153,7 @@ def render_legacy_overlay( model: str | None = None, databricks_profile: str | None = None, use_pat: bool = False, + provider: str | None = None, ) -> dict: """Overlay for Codex CLI < 0.134.0, which only reads `~/.codex/config.toml`. @@ -153,7 +167,9 @@ def render_legacy_overlay( "profile": CODEX_PROFILE_NAME, "profiles": {CODEX_PROFILE_NAME: profile_block}, "model_providers": { - CODEX_MODEL_PROVIDER_NAME: _provider_block(workspace, databricks_profile, use_pat), + CODEX_MODEL_PROVIDER_NAME: _provider_block( + workspace, databricks_profile, use_pat, provider + ), }, } @@ -290,9 +306,12 @@ def _parse_gpt(model: str | None) -> tuple[int, int | None, int | None, str] | N ) -def write_tool_config(state: dict, model: str | None = None) -> dict: +def write_tool_config(state: dict, model: str | None = None, provider: str | None = None) -> dict: workspace = state["workspace"] - chosen_model = _codex_model_id(model or default_model(state)) + # With a Model Provider Service the gateway routes by header and Codex sends + # its own canonical model name (e.g. 
`gpt-5`) — leave `model` unset so no + # Databricks endpoint id is pinned. + chosen_model = None if provider else _codex_model_id(model or default_model(state)) databricks_profile = state.get("profile") if _use_legacy_layout(): @@ -302,10 +321,20 @@ def write_tool_config(state: dict, model: str | None = None) -> dict: # ucode's entry from the shared file. backup_existing_file(LEGACY_CODEX_CONFIG_PATH, LEGACY_CODEX_BACKUP_PATH) overlay = render_legacy_overlay( - workspace, chosen_model, databricks_profile, use_pat=bool(state.get("use_pat")) + workspace, + chosen_model, + databricks_profile, + use_pat=bool(state.get("use_pat")), + provider=provider, ) doc = read_toml_safe(LEGACY_CODEX_CONFIG_PATH) deep_merge_dict(doc, overlay) + if provider: + # deep_merge can't drop keys, so clear a `model` pinned by an + # earlier non-provider run that the provider overlay omits. + profiles = doc.get("profiles") + if isinstance(profiles, dict) and isinstance(profiles.get(CODEX_PROFILE_NAME), dict): + profiles[CODEX_PROFILE_NAME].pop("model", None) write_toml_file(LEGACY_CODEX_CONFIG_PATH, doc) state = mark_tool_managed(state, "codex", LEGACY_MANAGED_KEYS) save_state(state) @@ -314,10 +343,18 @@ def write_tool_config(state: dict, model: str | None = None) -> dict: _remove_legacy_ucode_profile() backup_existing_file(CODEX_CONFIG_PATH, CODEX_BACKUP_PATH) overlay = render_overlay( - workspace, chosen_model, databricks_profile, use_pat=bool(state.get("use_pat")) + workspace, + chosen_model, + databricks_profile, + use_pat=bool(state.get("use_pat")), + provider=provider, ) doc = read_toml_safe(CODEX_CONFIG_PATH) deep_merge_dict(doc, overlay) + if provider: + # deep_merge can't drop keys, so clear a `model` pinned by an earlier + # non-provider run that the provider overlay omits. 
+ doc.pop("model", None) write_toml_file(CODEX_CONFIG_PATH, doc) state = mark_tool_managed(state, "codex", MANAGED_KEYS) save_state(state) diff --git a/src/ucode/cli.py b/src/ucode/cli.py index a110f1f..41504b0 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -19,6 +19,7 @@ ensure_provider_state, install_tool_binary, normalize_tool, + provider_permission_error, resolve_launch_model, validate_all_tools, validate_tool, @@ -42,7 +43,10 @@ get_databricks_profiles, get_databricks_token, install_databricks_cli, + is_model_provider_feature_unavailable, + list_model_provider_services, list_profile_entries, + list_tool_provider_services, normalize_workspace_url, resolve_pat_token, run_databricks_login, @@ -53,7 +57,15 @@ purge_cross_workspace_mcp_residue, revert_mcp_configs, ) -from ucode.state import STATE_PATH, clear_state, load_full_state, load_state, save_state +from ucode.state import ( + STATE_PATH, + clear_state, + get_provider_service, + load_full_state, + load_state, + save_state, + set_provider_service, +) from ucode.tracing import configure_tracing_command from ucode.ui import ( console, @@ -64,6 +76,8 @@ print_note, print_section, print_success, + print_warning, + prompt_for_selection, prompt_for_tools, prompt_for_workspace, set_verbosity, @@ -195,6 +209,7 @@ def configure_shared_state( tools: list[str] | None = None, force_login: bool = False, use_pat: bool | None = None, + skip_model_discovery: bool = False, ) -> dict: """Log into Databricks, enforce AI Gateway v2, fetch model lists, persist state. @@ -257,24 +272,36 @@ def configure_shared_state( claude_models = {} gemini_models = [] codex_models = [] - # UC-first, best-effort: one UC model-services call yields all families as - # `system.ai.` ids, bucketed by name. If a family comes back - # empty (workspace without UC model-services, or the listing failed), fall - # back to the per-family AI Gateway listing for that family only. 
- with spinner("Fetching available models..."): - ms_claude, ms_codex, ms_gemini, ms_reason = discover_model_services(workspace, token) + web_search_model: str | None = None + if skip_model_discovery: + # Provider mode: the agent routes through a Model Provider Service and + # pins no Databricks model, so the full family discovery is unused. Web + # search (claude only) still needs one Responses-capable model, so fetch + # just that with a single call. if want_claude: - claude_models, claude_reason = ms_claude, ms_reason - if not claude_models: - claude_models, claude_reason = discover_claude_models(workspace, token) - if want_gemini: - gemini_models, gemini_reason = ms_gemini, ms_reason - if not gemini_models: - gemini_models, gemini_reason = discover_gemini_models(workspace, token) - if want_codex: - codex_models, codex_reason = ms_codex, ms_reason - if not codex_models: - codex_models, codex_reason = discover_codex_models(workspace, token) + with spinner("Fetching web search model..."): + ws_models, _ = discover_codex_models(workspace, token) + if ws_models: + web_search_model = ws_models[0] + else: + # UC-first, best-effort: one UC model-services call yields all families as + # `system.ai.` ids, bucketed by name. If a family comes back + # empty (workspace without UC model-services, or the listing failed), fall + # back to the per-family AI Gateway listing for that family only. 
+ with spinner("Fetching available models..."): + ms_claude, ms_codex, ms_gemini, ms_reason = discover_model_services(workspace, token) + if want_claude: + claude_models, claude_reason = ms_claude, ms_reason + if not claude_models: + claude_models, claude_reason = discover_claude_models(workspace, token) + if want_gemini: + gemini_models, gemini_reason = ms_gemini, ms_reason + if not gemini_models: + gemini_models, gemini_reason = discover_gemini_models(workspace, token) + if want_codex: + codex_models, codex_reason = ms_codex, ms_reason + if not codex_models: + codex_models, codex_reason = discover_codex_models(workspace, token) opencode_models: dict[str, list[str]] = {} if claude_models: opencode_models["anthropic"] = list(claude_models.values()) @@ -297,14 +324,21 @@ def configure_shared_state( else: state.pop("use_pat", None) state["base_urls"] = build_shared_base_urls(workspace) - if want_claude: - state["claude_models"] = claude_models - if want_gemini: - state["gemini_models"] = gemini_models - if want_codex: - state["codex_models"] = codex_models - if fetch_all or "opencode" in tools: - state["opencode_models"] = opencode_models + if skip_model_discovery: + # Don't clobber any previously-discovered Databricks model lists; provider + # mode just doesn't refresh or use them. Persist the web-search model so + # claude's web_search MCP keeps working through the normal gateway. + if web_search_model: + state["web_search_model"] = web_search_model + else: + if want_claude: + state["claude_models"] = claude_models + if want_gemini: + state["gemini_models"] = gemini_models + if want_codex: + state["codex_models"] = codex_models + if fetch_all or "opencode" in tools: + state["opencode_models"] = opencode_models save_state(state) # Scrub MCP entries that ucode wrote for the previous workspace so the new # workspace's agent configs aren't stale. 
@@ -343,6 +377,75 @@ def _configure_shared_workspace_states( return states +def _provider_summary(tool: str, state: dict) -> str: + """Short label for the Configuration Complete box: 'Databricks' when no + Model Provider Service is configured, otherwise the external provider type + backing this tool (claude routes to Anthropic, codex to OpenAI).""" + if not get_provider_service(state, tool): + return "Databricks" + return {"claude": "Anthropic", "codex": "OpenAI"}.get(tool, "Model Provider Service") + + +def _maybe_select_provider_service(tool: str, state: dict) -> dict: + """Interactively let the user route claude/codex through a Model Provider + Service instead of Databricks models, and persist (or clear) the choice. + + No-op for tools other than claude/codex. Falls back to Databricks when no + matching provider services are found or the listing fails. + """ + if tool not in ("claude", "codex"): + return state + display = TOOL_SPECS[tool]["display"] + + def _use_databricks() -> dict: + new_state = set_provider_service(state, tool, None) + save_state(new_state) + return new_state + + # Probe first so we only offer the picker when it's actually usable. The + # caller already opted in via `--model-provider`, so explain any fallback + # rather than silently dropping back to Databricks. + token = get_databricks_token(state["workspace"], state.get("profile")) + with spinner("Checking for model provider services..."): + names, reason = list_tool_provider_services(tool, state["workspace"], token) + if reason is not None: + if is_model_provider_feature_unavailable(reason): + print_note( + "Model Provider Service feature is not available for this workspace; " + f"configuring {display} with Databricks models." + ) + else: + print_warning(f"Could not list model provider services: {reason}") + print_note("Falling back to Databricks models.") + return _use_databricks() + if not names: + # Feature is on but no service matches this tool's provider type. 
+ print_note(f"No model provider services available for {display}; using Databricks models.") + return _use_databricks() + + choice = prompt_for_selection( + f"How should {display} be configured?", + [ + ("databricks", "Databricks models"), + ("mps", "Model Provider Service (external provider)"), + ], + ) + if choice is None: + raise KeyboardInterrupt + if choice == "databricks": + return _use_databricks() + + selected = prompt_for_selection( + "Select a model provider service:", [(name, name) for name in names] + ) + if selected is None: + raise KeyboardInterrupt + state = set_provider_service(state, tool, selected) + save_state(state) + print_success(f"{display} will route through {selected}") + return state + + def configure_workspace_command( tool: str | None = None, selected_tools: list[str] | None = None, @@ -351,10 +454,15 @@ def configure_workspace_command( prompt_optional_updates: bool = True, use_pat: bool = False, skip_validate: bool = False, + use_model_provider: bool = False, ) -> int: if tool is not None and selected_tools is not None: raise RuntimeError("Use either --agent or --agents, not both.") + # The Databricks-vs-Model-Provider-Service picker is opt-in via + # `--model-provider`; without it, configure stays on the plain Databricks path. 
+ offer_provider = use_model_provider + workspace_entries = workspaces or [_prompt_for_configuration(tool)] if tool is not None: @@ -365,12 +473,15 @@ def configure_workspace_command( use_pat=use_pat, ) state = states[0] + if offer_provider: + state = _maybe_select_provider_service(tool, state) state = configure_single_tool(tool, state) spec = TOOL_SPECS[tool] console.print( Panel( f"[bold]Workspace:[/bold] [cyan]{state['workspace']}[/cyan]\n" - f"[bold]{spec['display']}:[/bold] [green]configured[/green]", + f"[bold]{spec['display']}:[/bold] [green]configured[/green] " + f"[dim](Provider: {_provider_summary(tool, state)})[/dim]", title="Configuration Complete", style="green", expand=False, @@ -384,7 +495,7 @@ def configure_workspace_command( if ok: print_success(f"{spec['display']} is working") else: - print_err(f"{spec['display']}: {err}") + print_err(f"{spec['display']}: {provider_permission_error(tool, state, err)}") managed = bool(state.get("managed_configs", {}).get(tool)) restore_file(spec["config_path"], spec["backup_path"], managed) available_tools = [t for t in (state.get("available_tools") or []) if t != tool] @@ -440,12 +551,21 @@ def configure_workspace_command( prompt_optional_updates=prompt_optional_updates, ) + # Offer the provider picker for the chosen claude/codex tools only when + # `--model-provider` was passed; otherwise stay on the Databricks path. 
+ if offer_provider: + for tool_name in picked: + state = _maybe_select_provider_service(tool_name, state) + state = configure_selected_tools(state, picked) summary_lines = [f"[bold]Workspace:[/bold] [cyan]{state['workspace']}[/cyan]"] for tool_name in picked: spec = TOOL_SPECS[tool_name] - summary_lines.append(f"[bold]{spec['display']}:[/bold] [green]configured[/green]") + summary_lines.append( + f"[bold]{spec['display']}:[/bold] [green]configured[/green] " + f"[dim](Provider: {_provider_summary(tool_name, state)})[/dim]" + ) console.print( Panel( "\n".join(summary_lines), @@ -494,6 +614,9 @@ def status() -> int: config_path = spec["config_path"] print_kv("Coding Agent", spec["display"]) print_kv("Configured", "yes" if configured else "no") + provider_service = get_provider_service(state, tool) + if configured and provider_service: + print_kv("Model Provider Service", provider_service) print_kv("Base URL", base_url) if configured and tool in MCP_CLIENTS: tool_mcp_servers = [ @@ -612,7 +735,8 @@ def _auto_configure_tool(tool: str) -> None: console.print( Panel( f"[bold]Workspace:[/bold] [cyan]{state['workspace']}[/cyan]\n" - f"[bold]{spec['display']}:[/bold] [green]configured[/green]", + f"[bold]{spec['display']}:[/bold] [green]configured[/green] " + f"[dim](Provider: {_provider_summary(tool, state)})[/dim]", title="Configuration Complete", style="green", expand=False, @@ -624,7 +748,7 @@ def _auto_configure_tool(tool: str) -> None: if ok: print_success(f"{spec['display']} is working") else: - print_err(f"{spec['display']}: {err}") + print_err(f"{spec['display']}: {provider_permission_error(tool, state, err)}") managed = bool(state.get("managed_configs", {}).get(tool)) restore_file(spec["config_path"], spec["backup_path"], managed) available_tools = [t for t in (state.get("available_tools") or []) if t != tool] @@ -633,7 +757,7 @@ def _auto_configure_tool(tool: str) -> None: raise RuntimeError(f"{spec['display']} validation failed — config reverted.") -def 
_launch_tool(tool_name: str, ctx: typer.Context) -> None: +def _launch_tool(tool_name: str, ctx: typer.Context, provider: str | None = None) -> None: try: tool = normalize_tool(tool_name) existing = load_state() @@ -648,17 +772,45 @@ def _launch_tool(tool_name: str, ctx: typer.Context) -> None: if needs_auto_configure: _auto_configure_tool(tool) state = ensure_provider_state(tool) + # An explicit --provider overrides the persisted choice; otherwise fall + # back to whatever `ucode configure` saved for this tool. + explicit_provider = provider is not None + provider = provider or get_provider_service(state, tool) + if provider and explicit_provider: + # Verify the feature only for an explicit --provider; a persisted + # choice was already validated at `ucode configure` time, so trust it + # and keep the launch fast. Surfaces a clear error up front instead of + # a cryptic gateway error mid-session. + token = get_databricks_token(state["workspace"], state.get("profile")) + _, reason = list_model_provider_services(state["workspace"], token) + if is_model_provider_feature_unavailable(reason): + raise RuntimeError( + "Model Provider Service feature is not available yet for this workspace." + ) # Re-fetch model lists on every launch so newly-added Databricks # endpoints show up without a manual `ucode configure` (and so that # tools like pi which read multiple model bundles never run on - # stale state from before a tool added a new bundle). + # stale state from before a tool added a new bundle). Under a provider + # this heavy discovery is skipped (only a web-search model is fetched). 
state = configure_shared_state( - state["workspace"], profile=state.get("profile"), tools=[tool] + state["workspace"], + profile=state.get("profile"), + tools=[tool], + skip_model_discovery=bool(provider), ) - state, resolved_model = resolve_launch_model(tool, state, None) - state = configure_tool(tool, state, resolved_model) + if provider: + # Routing through a Model Provider Service pins no Databricks model; + # the agent uses its own canonical model names (header selects the + # provider). Skip model resolution, which would otherwise fail when + # the workspace has no matching Databricks models. + resolved_model = None + else: + state, resolved_model = resolve_launch_model(tool, state, None) + state = configure_tool(tool, state, resolved_model, provider=provider) print_section(f"ucode with {TOOL_SPECS[tool]['display']}") - if resolved_model: + if provider: + print_kv("Provider", provider) + elif resolved_model: print_kv("Model", resolved_model) if tool in ("gemini", "opencode", "copilot", "pi"): print_note( @@ -676,15 +828,37 @@ def _launch_tool(tool_name: str, ctx: typer.Context) -> None: @app.command("codex", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) -def codex_cmd(ctx: typer.Context) -> None: +def codex_cmd( + ctx: typer.Context, + provider: Annotated[ + str | None, + typer.Option( + "--provider", + help="Route through a Unity Catalog Model Provider Service " + "(..). Skips Databricks model pinning; pass " + "before any `--` separator.", + ), + ] = None, +) -> None: """Launch Codex via Databricks.""" - _launch_tool("codex", ctx) + _launch_tool("codex", ctx, provider=provider) @app.command("claude", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) -def claude_cmd(ctx: typer.Context) -> None: +def claude_cmd( + ctx: typer.Context, + provider: Annotated[ + str | None, + typer.Option( + "--provider", + help="Route through a Unity Catalog Model Provider Service " + "(..). 
Skips Databricks model pinning; pass " + "before any `--` separator.", + ), + ] = None, +) -> None: """Launch Claude Code via Databricks.""" - _launch_tool("claude", ctx) + _launch_tool("claude", ctx, provider=provider) @app.command("gemini", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @@ -776,6 +950,15 @@ def configure( help="Also enable MLflow tracing for the configured workspace(s).", ), ] = False, + model_provider: Annotated[ + bool, + typer.Option( + "--model-provider", + help="Offer to route claude/codex through a Unity Catalog Model Provider " + "Service (external Anthropic/OpenAI provider) instead of Databricks models. " + "Without this flag, configure stays on the Databricks path.", + ), + ] = False, skip_upgrade: Annotated[ bool, typer.Option( @@ -824,6 +1007,8 @@ def configure( skip_kwargs["use_pat"] = True if skip_validate: skip_kwargs["skip_validate"] = True + if model_provider: + skip_kwargs["use_model_provider"] = True if agent is not None: tool = normalize_tool(agent) install_tool_binary( diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index 580de49..cb279e5 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -1261,6 +1261,81 @@ def build_mcp_service_url(workspace: str, full_name: str) -> str: return f"{workspace}/ai-gateway/mcp-services/{full_name}" +# Maps the gateway routing dialect a coding tool speaks to the Model Provider +# Service `provider_type` it can be backed by. claude speaks Anthropic's API; +# codex speaks OpenAI's. 
+_TOOL_PROVIDER_TYPES: dict[str, str] = { + "claude": "anthropic", + "codex": "openai", +} + + +def _provider_type_tag(provider_type: str | None) -> str: + """Shorten `EXTERNAL_MODEL_PROVIDER_TYPE_ANTHROPIC` to `anthropic`.""" + if not isinstance(provider_type, str): + return "" + prefix = "EXTERNAL_MODEL_PROVIDER_TYPE_" + tag = provider_type[len(prefix) :] if provider_type.startswith(prefix) else provider_type + return tag.lower() + + +def list_model_provider_services(workspace: str, token: str) -> tuple[list[dict], str | None]: + """List Unity Catalog Model Provider Services on the workspace. + + Returns ``(services, reason)`` where each service is + ``{"name": "..", "provider_type": "anthropic"|...}``. + A non-None ``reason`` means the listing call itself failed. + """ + hostname = workspace_hostname(workspace) + url = f"https://{hostname}/api/2.1/unity-catalog/model-provider-services" + payload, reason = _http_get_json(url, token, timeout=30) + if payload is None: + return [], reason + data = cast(dict, payload) if isinstance(payload, dict) else {} + services: list[dict] = [] + for service in data.get("model_provider_services") or []: + if not isinstance(service, dict): + continue + raw_name = service.get("name") + if not isinstance(raw_name, str) or not raw_name: + continue + # The API returns `model-provider-services/..`. + full_name = raw_name.split("/", 1)[1] if "/" in raw_name else raw_name + config = service.get("config") if isinstance(service.get("config"), dict) else {} + services.append( + { + "name": full_name, + "provider_type": _provider_type_tag(config.get("provider_type")), + } + ) + services.sort(key=lambda s: s["name"]) + return services, None + + +def is_model_provider_feature_unavailable(reason: str | None) -> bool: + """True when a model-provider-services API failure means the workspace + simply hasn't enabled the feature (HTTP 400 "feature is not available"), + as opposed to a transient or auth error. 
Callers use this to fall back to + Databricks models silently rather than surfacing a scary error. + """ + return bool(reason) and "feature is not available" in reason.lower() + + +def list_tool_provider_services( + tool: str, workspace: str, token: str +) -> tuple[list[str], str | None]: + """Provider-service names whose provider type matches ``tool``'s API dialect. + + Returns ``(names, reason)``; ``reason`` is non-None when the listing failed. + """ + wanted = _TOOL_PROVIDER_TYPES.get(tool) + services, reason = list_model_provider_services(workspace, token) + if reason is not None: + return [], reason + names = [s["name"] for s in services if not wanted or s["provider_type"] == wanted] + return names, None + + # `list_vector_search_catalog_schemas` walks Vector Search endpoints+indexes. # `list_uc_functions_catalog_schemas` walks UC catalogs+schemas in parallel and # keeps only schemas with at least one user function. diff --git a/src/ucode/state.py b/src/ucode/state.py index 471eae0..c2c8a84 100644 --- a/src/ucode/state.py +++ b/src/ucode/state.py @@ -194,3 +194,30 @@ def mark_tool_managed(state: dict, tool: str, managed_keys: list) -> dict: state["managed_configs"] = managed_configs state["last_tool"] = tool return state + + +def get_provider_service(state: dict, tool: str) -> str | None: + """Return the persisted Model Provider Service for ``tool``, if any. + + Launches route through this provider (skipping Databricks model pinning) + unless overridden by an explicit ``--provider`` flag. 
+ """ + providers = state.get("provider_services") + if not isinstance(providers, dict): + return None + name = providers.get(tool) + return name if isinstance(name, str) and name else None + + +def set_provider_service(state: dict, tool: str, full_name: str | None) -> dict: + """Persist (or clear, when ``full_name`` is None) ``tool``'s provider service.""" + providers = dict(state.get("provider_services") or {}) + if full_name: + providers[tool] = full_name + else: + providers.pop(tool, None) + if providers: + state["provider_services"] = providers + else: + state.pop("provider_services", None) + return state diff --git a/src/ucode/ui.py b/src/ucode/ui.py index a687d9c..cb55797 100644 --- a/src/ucode/ui.py +++ b/src/ucode/ui.py @@ -266,6 +266,32 @@ def prompt_for_tools(available: list[tuple[str, str]]) -> list[str]: return list(answer) if answer else [] +def prompt_for_selection(prompt: str, options: list[tuple[str, str]]) -> str | None: + """Single-select arrow-key picker. `options` is [(value, label), ...]. + + The prompt renders above the choices (questionary convention). Returns the + chosen value, or None if the user cancels (Ctrl-C / empty). 
+ """ + style = questionary.Style( + [ + ("pointer", "fg:cyan bold"), + ("highlighted", "noinherit"), + ("selected", "noinherit"), + ("answer", "fg:cyan"), + ] + ) + choices = [questionary.Choice(title=label, value=value) for value, label in options] + answer = questionary.select( + prompt, + choices=choices, + style=style, + pointer="›", + qmark="", + instruction="(use arrow keys)", + ).ask() + return answer + + def prompt_yes_no(prompt: str) -> bool: while True: response = console.input(f"{label(prompt)} {muted('(y/n)')} {muted('›')} ").strip().lower() diff --git a/tests/test_agent_claude.py b/tests/test_agent_claude.py index 1c801b9..86975b4 100644 --- a/tests/test_agent_claude.py +++ b/tests/test_agent_claude.py @@ -110,6 +110,31 @@ def test_model_overrides_not_set_when_no_models(self): env = overlay["env"] assert "ANTHROPIC_DEFAULT_SONNET_MODEL" not in env + def test_provider_adds_routing_header(self): + overlay, _ = claude.render_overlay(WS, "s4", provider="main.aarushi.aarushi-claude") + assert ( + "Databricks-Model-Provider-Service: main.aarushi.aarushi-claude" + in overlay["env"]["ANTHROPIC_CUSTOM_HEADERS"] + ) + + def test_provider_skips_model_pinning(self): + models = { + "opus": "databricks-claude-opus-4-7", + "sonnet": "databricks-claude-sonnet-4-6", + "haiku": "databricks-claude-haiku-4-6", + } + overlay, _ = claude.render_overlay( + WS, "s4", claude_models=models, provider="main.aarushi.aarushi-claude" + ) + env = overlay["env"] + assert "ANTHROPIC_DEFAULT_OPUS_MODEL" not in env + assert "ANTHROPIC_DEFAULT_SONNET_MODEL" not in env + assert "ANTHROPIC_DEFAULT_HAIKU_MODEL" not in env + + def test_no_provider_header_without_flag(self): + overlay, _ = claude.render_overlay(WS, "s4") + assert "Databricks-Model-Provider-Service" not in overlay["env"]["ANTHROPIC_CUSTOM_HEADERS"] + def test_picker_labels_show_raw_routable_id(self): # We deliberately don't set the `_NAME` companion env vars. 
Showing the # raw `system.ai.…` / `databricks-…` id in the picker label tells users diff --git a/tests/test_agent_codex.py b/tests/test_agent_codex.py index f8d6baf..db65ae7 100644 --- a/tests/test_agent_codex.py +++ b/tests/test_agent_codex.py @@ -61,6 +61,20 @@ def test_auth_refresh_interval(self): auth = overlay["model_providers"]["ucode-databricks"]["auth"] assert auth["refresh_interval_ms"] == 900_000 + def test_provider_adds_routing_header(self): + overlay = codex.render_overlay(WS, provider="main.aarushi.aarushi-openai") + headers = overlay["model_providers"]["ucode-databricks"]["http_headers"] + assert headers["Databricks-Model-Provider-Service"] == "main.aarushi.aarushi-openai" + + def test_provider_omits_model(self): + overlay = codex.render_overlay(WS, model=None, provider="main.aarushi.aarushi-openai") + assert "model" not in overlay + + def test_no_provider_header_without_flag(self): + overlay = codex.render_overlay(WS) + headers = overlay["model_providers"]["ucode-databricks"]["http_headers"] + assert "Databricks-Model-Provider-Service" not in headers + class TestRenderOverlayUserAgent: def test_user_agent_set_on_provider(self, monkeypatch): @@ -124,6 +138,29 @@ def test_preserves_databricks_model_id_when_openai_id_is_incompatible( doc = read_toml_safe(config_path) assert doc["model"] == "databricks-gpt-5-2-codex" + def test_provider_writes_header_and_drops_stale_model(self, tmp_path, monkeypatch): + config_path = tmp_path / ".codex" / "ucode.config.toml" + backup_path = tmp_path / "codex-ucode-config.backup.toml" + monkeypatch.setattr(codex, "CODEX_CONFIG_PATH", config_path) + monkeypatch.setattr(codex, "CODEX_BACKUP_PATH", backup_path) + monkeypatch.setattr(codex, "agent_version", lambda binary: "0.134.0") + monkeypatch.setattr(codex, "save_state", lambda state: None) + + # An earlier non-provider run pinned a model. 
+ codex.write_tool_config({"workspace": WS, "codex_models": ["gpt-5"]}) + assert read_toml_safe(config_path)["model"] == "gpt-5" + + # A provider run must clear it and add the routing header. + codex.write_tool_config( + {"workspace": WS, "codex_models": ["gpt-5"]}, + provider="main.aarushi.aarushi-openai", + ) + + doc = read_toml_safe(config_path) + assert "model" not in doc + headers = doc["model_providers"]["ucode-databricks"]["http_headers"] + assert headers["Databricks-Model-Provider-Service"] == "main.aarushi.aarushi-openai" + def test_removes_legacy_ucode_profile_from_shared_config(self, tmp_path, monkeypatch): config_dir = tmp_path / ".codex" config_dir.mkdir() diff --git a/tests/test_agents_init.py b/tests/test_agents_init.py index e4560f4..3cecadb 100644 --- a/tests/test_agents_init.py +++ b/tests/test_agents_init.py @@ -16,10 +16,32 @@ ensure_tool_binary_available, install_tool_binary, normalize_tool, + provider_permission_error, resolve_launch_model, ) +class TestProviderPermissionError: + _CONN_ERR = ( + "User does not have USE CONNECTION on SCHEMA_CONNECTION " + "'299433db-cb91-4b08-9761-edab72a27836'." 
+ ) + + def test_rewrites_when_provider_configured(self): + state = {"provider_services": {"codex": "main.aarushi.aarushi-test-openai"}} + out = provider_permission_error("codex", state, self._CONN_ERR) + assert "main.aarushi.aarushi-test-openai" in out + assert "EXECUTE" in out + assert "SCHEMA_CONNECTION" not in out + + def test_passthrough_without_provider(self): + assert provider_permission_error("codex", {}, self._CONN_ERR) == self._CONN_ERR + + def test_passthrough_for_unrelated_error(self): + state = {"provider_services": {"codex": "main.a.b"}} + assert provider_permission_error("codex", state, "boom") == "boom" + + class TestToolSpecs: def test_all_tools_present(self): assert set(TOOL_SPECS) == {"codex", "claude", "gemini", "opencode", "copilot", "pi"} diff --git a/tests/test_cli.py b/tests/test_cli.py index cf156bc..ff6f9c1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -634,6 +634,38 @@ def test_selected_tools_skip_picker(self, monkeypatch): assert install_calls == ["claude", "codex"] assert configured == [["claude", "codex"]] + def test_provider_picker_gated_by_flag(self, monkeypatch): + import ucode.cli as cli_mod + + state = {**MINIMAL_STATE, "available_tools": []} + monkeypatch.setattr(cli_mod, "configure_shared_state", lambda *a, **k: state) + monkeypatch.setattr(cli_mod, "check_gateway_endpoint", lambda s, t: t == "claude") + monkeypatch.setattr(cli_mod, "install_tool_binary", lambda *a, **k: True) + monkeypatch.setattr( + cli_mod, "configure_selected_tools", lambda s, tools: {**s, "available_tools": tools} + ) + monkeypatch.setattr(cli_mod, "validate_all_tools", lambda s: None) + picked_for: list[str] = [] + monkeypatch.setattr( + cli_mod, + "_maybe_select_provider_service", + lambda tool, s: picked_for.append(tool) or s, + ) + + # Default: no provider picker. 
+ cli_mod.configure_workspace_command( + selected_tools=["claude"], workspaces=[("https://w.com", None)] + ) + assert picked_for == [] + + # Opt-in: picker offered for the chosen tool. + cli_mod.configure_workspace_command( + selected_tools=["claude"], + workspaces=[("https://w.com", None)], + use_model_provider=True, + ) + assert picked_for == ["claude"] + def test_unavailable_selected_tool_errors_before_configure(self, monkeypatch): import ucode.cli as cli_mod @@ -1113,3 +1145,52 @@ def test_skips_purge_when_workspace_unchanged(self, monkeypatch): cli_mod.configure_shared_state("https://same.databricks.com") assert purge_calls == [] + + +class TestConfigureSharedStateSkipDiscovery: + """With skip_model_discovery (provider mode), the heavy family discovery is + skipped; only a single web-search model is fetched, and existing model lists + are preserved rather than clobbered.""" + + @staticmethod + def _stub(monkeypatch): + import ucode.cli as cli_mod + + monkeypatch.setattr(cli_mod, "normalize_workspace_url", lambda w: w) + monkeypatch.setattr(cli_mod, "ensure_databricks_auth", lambda w, p=None: None) + monkeypatch.setattr(cli_mod, "run_databricks_login", lambda w, p: None) + monkeypatch.setattr(cli_mod, "find_profile_name_for_host", lambda w: None) + monkeypatch.setattr(cli_mod, "get_databricks_token", lambda w, p: "token") + monkeypatch.setattr(cli_mod, "ensure_ai_gateway_v2", lambda w, t: None) + monkeypatch.setattr(cli_mod, "build_shared_base_urls", lambda w: {}) + monkeypatch.setattr(cli_mod, "save_state", lambda s: None) + + def test_skips_family_discovery_and_fetches_web_search_model(self, monkeypatch): + import ucode.cli as cli_mod + + ws = "https://prov.databricks.com" + self._stub(monkeypatch) + # Pretend a prior Databricks configure left models behind. 
+ monkeypatch.setattr( + cli_mod, + "load_state", + lambda: {"workspace": ws, "claude_models": {"opus": "databricks-claude-opus-4-8"}}, + ) + + def _boom(*a, **k): + raise AssertionError("discover_model_services must not run in provider mode") + + monkeypatch.setattr(cli_mod, "discover_model_services", _boom) + codex_calls: list = [] + monkeypatch.setattr( + cli_mod, + "discover_codex_models", + lambda w, t: codex_calls.append((w, t)) or (["databricks-gpt-5"], None), + ) + + state = cli_mod.configure_shared_state(ws, tools=["claude"], skip_model_discovery=True) + + assert codex_calls == [(ws, "token")] + assert state["web_search_model"] == "databricks-gpt-5" + # Existing model list preserved, not overwritten to {}. + assert state["claude_models"] == {"opus": "databricks-claude-opus-4-8"} diff --git a/tests/test_databricks.py b/tests/test_databricks.py index bac25d4..86a64ac 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -272,6 +272,70 @@ def flaky_get(url, token, timeout=10): assert calls["n"] == 3 # two failures, third succeeds +class TestListModelProviderServices: + _PAYLOAD = { + "model_provider_services": [ + { + "name": "model-provider-services/main.aarushi.anthropic-svc", + "config": {"provider_type": "EXTERNAL_MODEL_PROVIDER_TYPE_ANTHROPIC"}, + }, + { + "name": "model-provider-services/main.aarushi.openai-svc", + "config": {"provider_type": "EXTERNAL_MODEL_PROVIDER_TYPE_OPENAI"}, + }, + { + "name": "model-provider-services/main.bob.bedrock-svc", + "config": {"provider_type": "EXTERNAL_MODEL_PROVIDER_TYPE_BEDROCK"}, + }, + ] + } + + def test_strips_prefix_and_tags_provider_type(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=30: (self._PAYLOAD, None) + ) + services, reason = db_mod.list_model_provider_services(WS, "token") + assert reason is None + assert services[0] == {"name": "main.aarushi.anthropic-svc", "provider_type": "anthropic"} + assert {s["provider_type"] for s in services} 
== {"anthropic", "openai", "bedrock"} + + def test_returns_reason_on_failure(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=30: (None, "HTTP 500 Server Error") + ) + services, reason = db_mod.list_model_provider_services(WS, "token") + assert services == [] + assert reason == "HTTP 500 Server Error" + + def test_claude_filters_to_anthropic(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=30: (self._PAYLOAD, None) + ) + names, reason = db_mod.list_tool_provider_services("claude", WS, "token") + assert reason is None + assert names == ["main.aarushi.anthropic-svc"] + + def test_codex_filters_to_openai(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=30: (self._PAYLOAD, None) + ) + names, _ = db_mod.list_tool_provider_services("codex", WS, "token") + assert names == ["main.aarushi.openai-svc"] + + +class TestModelProviderFeatureUnavailable: + def test_detects_feature_not_available(self): + reason = ( + 'HTTP 400 Bad Request: {"error_code":"BAD_REQUEST",' + '"message":"ModelProviderService feature is not available"}' + ) + assert db_mod.is_model_provider_feature_unavailable(reason) is True + + def test_false_for_other_errors(self): + assert db_mod.is_model_provider_feature_unavailable("HTTP 500 Server Error") is False + assert db_mod.is_model_provider_feature_unavailable(None) is False + + class TestListMcpServices: def test_accepts_entries_without_connection_status(self, monkeypatch): payload = { diff --git a/tests/test_e2e.py b/tests/test_e2e.py index f0c64e0..79fb91e 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -28,6 +28,9 @@ fetch_codex_models, fetch_gemini_models, has_valid_databricks_auth, + is_model_provider_feature_unavailable, + list_model_provider_services, + list_tool_provider_services, workspace_hostname, ) from ucode.ui import normalize_workspace_url @@ -146,6 +149,36 @@ def 
test_fetch_codex_models_returns_list(self, e2e_workspace, e2e_token): assert isinstance(models, list) +# --------------------------------------------------------------------------- +# Model Provider Services discovery +# --------------------------------------------------------------------------- + + +class TestModelProviderServicesDiscovery: + def test_list_returns_services_or_skips_when_feature_off(self, e2e_workspace, e2e_token): + services, reason = list_model_provider_services(e2e_workspace, e2e_token) + if is_model_provider_feature_unavailable(reason): + pytest.skip("Model Provider Service feature not enabled on this workspace") + assert reason is None, f"listing failed: {reason}" + assert isinstance(services, list) + for svc in services: + assert set(svc) >= {"name", "provider_type"} + # Names are stripped of the `model-provider-services/` API prefix. + assert svc["name"] and "/" not in svc["name"] + + def test_tool_filter_matches_provider_type(self, e2e_workspace, e2e_token): + services, reason = list_model_provider_services(e2e_workspace, e2e_token) + if is_model_provider_feature_unavailable(reason): + pytest.skip("Model Provider Service feature not enabled on this workspace") + assert reason is None + claude_names, _ = list_tool_provider_services("claude", e2e_workspace, e2e_token) + codex_names, _ = list_tool_provider_services("codex", e2e_workspace, e2e_token) + assert set(claude_names) == { + s["name"] for s in services if s["provider_type"] == "anthropic" + } + assert set(codex_names) == {s["name"] for s in services if s["provider_type"] == "openai"} + + # --------------------------------------------------------------------------- # URL builders # --------------------------------------------------------------------------- @@ -475,6 +508,107 @@ def test_launch_claude_per_model( assert not failures, "Claude launch failures:\n" + "\n".join(failures) +class TestModelProviderLaunch: + """Launch claude/codex routed through a real Model Provider Service. 
+ + Picks the first matching service on the workspace, writes a provider config + (no Databricks model pinned), and runs the agent so a real request flows + through the MPS gateway. Skips when the feature is off, no service exists, or + the caller lacks permission on the backing connection. + """ + + @staticmethod + def _first_service(tool: str, workspace: str, token: str) -> str: + names, reason = list_tool_provider_services(tool, workspace, token) + if is_model_provider_feature_unavailable(reason): + pytest.skip("Model Provider Service feature not enabled on this workspace") + if reason is not None: + pytest.skip(f"could not list provider services: {reason}") + if not names: + pytest.skip(f"no {tool} model provider services available on this workspace") + return names[0] + + @staticmethod + def _skip_if_no_permission(combined: str, provider: str) -> None: + if "USE CONNECTION" in combined or "EXECUTE" in combined: + pytest.skip(f"no permission on provider {provider}: {combined[:200]}") + + def test_launch_claude_through_provider( + self, tmp_path, monkeypatch, e2e_state, e2e_workspace, e2e_token + ): + import ucode.config_io as config_io_mod + from ucode.agents import claude + + _require_binary("claude") + provider = self._first_service("claude", e2e_workspace, e2e_token) + + config_dir = tmp_path / "claude_config" + config_dir.mkdir() + monkeypatch.setattr(config_io_mod, "APP_DIR", tmp_path) + monkeypatch.setattr(claude, "CLAUDE_SETTINGS_PATH", config_dir / "settings.json") + monkeypatch.setattr(claude, "CLAUDE_BACKUP_PATH", tmp_path / "claude-settings.backup.json") + + with pytest.MonkeyPatch().context() as mp: + mp.setattr("ucode.state.save_state", lambda s: None) + # No model pinned — the provider header (written into the settings + # env block) routes the agent's own canonical model name. 
+ claude.write_tool_config( + {**e2e_state, "workspace": e2e_workspace}, None, provider=provider + ) + + env = { + **os.environ, + "CLAUDE_CONFIG_DIR": str(config_dir), + "ANTHROPIC_BASE_URL": build_tool_base_url("claude", e2e_workspace), + "ANTHROPIC_API_KEY": e2e_token, + "CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS": "1", + } + result = _run_agent(claude.validate_cmd("claude"), env=env, timeout=90) + combined = (result.stdout + result.stderr).strip() + self._skip_if_no_permission(combined, provider) + assert result.returncode == 0 and combined, ( + f"provider={provider} rc={result.returncode} " + f"stdout={result.stdout[:300]!r} stderr={result.stderr[:300]!r}" + ) + + def test_launch_codex_through_provider( + self, tmp_path, monkeypatch, e2e_state, e2e_workspace, e2e_token + ): + import ucode.config_io as config_io_mod + from ucode.agents import codex + + _require_binary("codex") + provider = self._first_service("codex", e2e_workspace, e2e_token) + + monkeypatch.setattr(config_io_mod, "APP_DIR", tmp_path) + config_dir = tmp_path / "codex_home" / ".codex" + config_dir.mkdir(parents=True) + monkeypatch.setattr(codex, "CODEX_CONFIG_PATH", config_dir / "ucode.config.toml") + monkeypatch.setattr(codex, "CODEX_BACKUP_PATH", tmp_path / "codex-config.backup.toml") + + with pytest.MonkeyPatch().context() as mp: + mp.setattr("ucode.state.save_state", lambda s: None) + codex.write_tool_config( + {**e2e_state, "workspace": e2e_workspace}, None, provider=provider + ) + + timeout_seconds = int(os.environ.get("UCODE_E2E_AGENT_TIMEOUT", "60")) + try: + result = _run_agent( + codex.validate_cmd("codex"), + env={**os.environ, "CODEX_HOME": str(config_dir)}, + timeout=timeout_seconds, + ) + except subprocess.TimeoutExpired: + pytest.fail(f"provider={provider} timed out after {timeout_seconds}s") + combined = (result.stdout + result.stderr).strip() + self._skip_if_no_permission(combined, provider) + assert result.returncode == 0 and combined, ( + f"provider={provider} 
rc={result.returncode} " + f"stdout={result.stdout[:300]!r} stderr={result.stderr[:300]!r}" + ) + + class TestGeminiLaunch: """Run gemini against every available gemini model.""" diff --git a/tests/test_state.py b/tests/test_state.py index 95d1440..f0a0967 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -12,11 +12,13 @@ STATE_VERSION, build_agent_state, clear_state, + get_provider_service, hydrate_state, load_full_state, load_state, mark_tool_managed, save_state, + set_provider_service, ) FAKE_WS = "https://example.databricks.com" @@ -139,6 +141,32 @@ def test_clear_when_no_state_is_noop(self): clear_state() # should not raise +class TestProviderService: + def test_get_returns_none_when_unset(self): + assert get_provider_service({}, "claude") is None + assert get_provider_service({"provider_services": {}}, "claude") is None + + def test_set_and_get_roundtrip(self): + state = set_provider_service({}, "claude", "main.a.anthropic") + assert state["provider_services"]["claude"] == "main.a.anthropic" + assert get_provider_service(state, "claude") == "main.a.anthropic" + assert get_provider_service(state, "codex") is None + + def test_set_none_clears_entry_and_key(self): + state = set_provider_service({}, "claude", "main.a.anthropic") + state = set_provider_service(state, "claude", None) + assert get_provider_service(state, "claude") is None + # Drop the empty container entirely rather than leaving {}. 
+ assert "provider_services" not in state + + def test_clearing_one_tool_keeps_the_other(self): + state = set_provider_service({}, "claude", "main.a.anthropic") + state = set_provider_service(state, "codex", "main.a.openai") + state = set_provider_service(state, "claude", None) + assert get_provider_service(state, "claude") is None + assert get_provider_service(state, "codex") == "main.a.openai" + + # --------------------------------------------------------------------------- # hydrate_state # --------------------------------------------------------------------------- From 520083e61be97530073cf33056e147210f7dc032 Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Thu, 25 Jun 2026 23:10:20 +0000 Subject: [PATCH 02/11] Trigger CI re-run after granting trace-table write access Co-authored-by: Isaac From 2ed9d082efbb9fcd5f61691b6d9d9706852b8192 Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Thu, 25 Jun 2026 23:40:38 +0000 Subject: [PATCH 03/11] Re-trigger CI after granting github-actions-sp UC schema access Co-authored-by: Isaac From 809f0c904bc69837f66c727f78793a431cc4445a Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Fri, 26 Jun 2026 06:49:13 +0000 Subject: [PATCH 04/11] Drop --model-provider flag; gate picker on interactive configure Show the Model Provider Service picker only on the fully interactive `ucode configure` path (no --agent/--agents). Naming agents signals the non-interactive flow and stays on Databricks. Also fall back to Databricks silently when the MPS feature isn't enabled on the workspace, instead of printing a per-tool note. --- src/ucode/cli.py | 38 +++++++++++--------------------------- tests/test_cli.py | 15 ++++++++------- 2 files changed, 19 insertions(+), 34 deletions(-) diff --git a/src/ucode/cli.py b/src/ucode/cli.py index 41504b0..31aef31 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -403,18 +403,15 @@ def _use_databricks() -> dict: return new_state # Probe first so we only offer the picker when it's actually usable. 
The - # caller already opted in via `--model-provider`, so explain any fallback - # rather than silently dropping back to Databricks. + # interactive path always reaches here, so warn on unexpected listing failures; + # a workspace that simply lacks the feature falls back to Databricks silently. token = get_databricks_token(state["workspace"], state.get("profile")) with spinner("Checking for model provider services..."): names, reason = list_tool_provider_services(tool, state["workspace"], token) if reason is not None: - if is_model_provider_feature_unavailable(reason): - print_note( - "Model Provider Service feature is not available for this workspace; " - f"configuring {display} with Databricks models." - ) - else: + # Most workspaces don't have the feature enabled — that's the common case, + # so fall back to Databricks silently. Only surface unexpected failures. + if not is_model_provider_feature_unavailable(reason): print_warning(f"Could not list model provider services: {reason}") print_note("Falling back to Databricks models.") return _use_databricks() @@ -423,8 +423,8 @@ def _use_databricks() -> dict: choice = prompt_for_selection( f"How should {display} be configured?", [ - ("databricks", "Databricks models"), - ("mps", "Model Provider Service (external provider)"), + ("databricks", "Databricks Hosted"), + ("mps", "External Models"), ], ) if choice is None: @@ -454,14 +454,14 @@ def configure_workspace_command( prompt_optional_updates: bool = True, use_pat: bool = False, skip_validate: bool = False, - use_model_provider: bool = False, ) -> int: if tool is not None and selected_tools is not None: raise RuntimeError("Use either --agent or --agents, not both.") - # The Databricks-vs-Model-Provider-Service picker is opt-in via - # `--model-provider`; without it, configure stays on the plain Databricks path. + # The Databricks-vs-Model-Provider-Service picker is shown only on the fully + # interactive path (`ucode configure` with no --agent/--agents). Naming agents + # explicitly signals the non-interactive flow, which stays on Databricks.
+ offer_provider = tool is None and selected_tools is None workspace_entries = workspaces or [_prompt_for_configuration(tool)] @@ -473,8 +470,6 @@ def configure_workspace_command( use_pat=use_pat, ) state = states[0] - if offer_provider: - state = _maybe_select_provider_service(tool, state) state = configure_single_tool(tool, state) spec = TOOL_SPECS[tool] console.print( @@ -551,8 +546,8 @@ def configure_workspace_command( prompt_optional_updates=prompt_optional_updates, ) - # Offer the provider picker for the chosen claude/codex tools only when - # `--model-provider` was passed; otherwise stay on the Databricks path. + # Offer the provider picker for the chosen claude/codex tools only on the + # interactive path (no --agents); otherwise stay on the Databricks path. if offer_provider: for tool_name in picked: state = _maybe_select_provider_service(tool_name, state) @@ -950,15 +945,6 @@ def configure( help="Also enable MLflow tracing for the configured workspace(s).", ), ] = False, - model_provider: Annotated[ - bool, - typer.Option( - "--model-provider", - help="Offer to route claude/codex through a Unity Catalog Model Provider " - "Service (external Anthropic/OpenAI provider) instead of Databricks models. 
" - "Without this flag, configure stays on the Databricks path.", - ), - ] = False, skip_upgrade: Annotated[ bool, typer.Option( @@ -1007,8 +993,6 @@ def configure( skip_kwargs["use_pat"] = True if skip_validate: skip_kwargs["skip_validate"] = True - if model_provider: - skip_kwargs["use_model_provider"] = True if agent is not None: tool = normalize_tool(agent) install_tool_binary( diff --git a/tests/test_cli.py b/tests/test_cli.py index ff6f9c1..e2e01a1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -634,7 +634,7 @@ def test_selected_tools_skip_picker(self, monkeypatch): assert install_calls == ["claude", "codex"] assert configured == [["claude", "codex"]] - def test_provider_picker_gated_by_flag(self, monkeypatch): + def test_provider_picker_gated_by_interactive_path(self, monkeypatch): import ucode.cli as cli_mod state = {**MINIMAL_STATE, "available_tools": []} @@ -652,18 +652,18 @@ def test_provider_picker_gated_by_flag(self, monkeypatch): lambda tool, s: picked_for.append(tool) or s, ) - # Default: no provider picker. + # Non-interactive (--agents passed): no provider picker. cli_mod.configure_workspace_command( selected_tools=["claude"], workspaces=[("https://w.com", None)] ) assert picked_for == [] - # Opt-in: picker offered for the chosen tool. - cli_mod.configure_workspace_command( - selected_tools=["claude"], - workspaces=[("https://w.com", None)], - use_model_provider=True, + # Interactive (`ucode configure`): picker offered for each picked tool. 
+ monkeypatch.setattr( + cli_mod, "_prompt_for_configuration", lambda tool=None: ("https://w.com", None) ) + monkeypatch.setattr(cli_mod, "prompt_for_tools", lambda options: ["claude"]) + cli_mod.configure_workspace_command() assert picked_for == ["claude"] def test_unavailable_selected_tool_errors_before_configure(self, monkeypatch): @@ -714,6 +714,7 @@ def fake_configure_shared_state( monkeypatch.setattr(cli_mod, "save_state", lambda state: saved.append(state["workspace"])) monkeypatch.setattr(cli_mod, "check_gateway_endpoint", lambda state, tool: True) monkeypatch.setattr(cli_mod, "prompt_for_tools", lambda available: ["codex"]) + monkeypatch.setattr(cli_mod, "_maybe_select_provider_service", lambda tool, state: state) monkeypatch.setattr(cli_mod, "install_tool_binary", lambda *args, **kwargs: True) monkeypatch.setattr( cli_mod, From 3cc23e777052a39b086bc6ac227ddb616ade87eb Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Fri, 26 Jun 2026 16:27:44 +0000 Subject: [PATCH 05/11] Rename provider picker options to Databricks Hosted / External Models --- src/ucode/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ucode/cli.py b/src/ucode/cli.py index 31aef31..3d9bdb7 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -423,8 +423,8 @@ def _use_databricks() -> dict: choice = prompt_for_selection( f"How should {display} be configured?", [ - ("databricks", "Databricks models"), - ("mps", "Model Provider Service (external provider)"), + ("databricks", "Databricks Hosted"), + ("mps", "External Models"), ], ) if choice is None: From 8dd3ac5e7e53d07f1500d2adb0a80422bcea5ef1 Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Fri, 26 Jun 2026 17:07:11 +0000 Subject: [PATCH 06/11] CI: dump Claude tracing hook log on e2e tracing failure --- .github/workflows/ci.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f9fb8ca..819096e 100644 --- a/.github/workflows/ci.yml 
+++ b/.github/workflows/ci.yml @@ -60,3 +60,15 @@ jobs: # independent and one shouldn't mask the other. - if: ${{ !cancelled() }} run: uv run --extra tracing pytest tests/test_e2e_tracing.py -v < /dev/null + # Diagnostic: the Claude Stop hook writes its trace-creation log to + # $cwd/.claude/mlflow/claude_tracing.log. If the tracing test failed + # because no root `claude_code_conversation` span landed, this shows + # whether the hook fired on the runner and what its MLflow write did. + - if: ${{ failure() }} + name: Dump Claude tracing hook log + run: | + echo "=== claude_tracing.log ===" + cat .claude/mlflow/claude_tracing.log 2>/dev/null || echo "(no hook log — Stop hook never ran)" + echo "=== installed claude-code + mlflow CLI ===" + claude --version || true + "$(uv tool dir --bin)/mlflow" --version || true From debca5b1473ed2fead84ab8ce95692c59f1f78d1 Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Fri, 26 Jun 2026 17:21:26 +0000 Subject: [PATCH 07/11] Force synchronous MLflow trace export in tracing hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Claude Stop hook (mlflow autolog claude stop-hook) is a short-lived one-shot process. With MLflow's default async trace logging, the root claude_code_conversation span is queued and the hook's best-effort flush can lose it before the process exits — observed on CI runners, leaving an orphaned llm span and no queryable trace (e2e tracing test failure). Set MLFLOW_ENABLE_ASYNC_TRACE_LOGGING=false so export is synchronous. 
--- src/ucode/tracing.py | 8 ++++++++ tests/test_tracing.py | 1 + 2 files changed, 9 insertions(+) diff --git a/src/ucode/tracing.py b/src/ucode/tracing.py index 84e81bb..fa5fce3 100644 --- a/src/ucode/tracing.py +++ b/src/ucode/tracing.py @@ -133,6 +133,14 @@ def tracing_env(state: dict, tool: str) -> dict[str, str]: env = { "MLFLOW_TRACKING_URI": str(cfg["tracking_uri"]), "MLFLOW_EXPERIMENT_ID": str(entry["experiment_id"]), + # The trace is emitted by the `mlflow autolog claude stop-hook` Stop + # hook, which is always a short-lived one-shot process. With async + # trace logging (MLflow's default) the root `claude_code_conversation` + # span is queued and the hook's best-effort flush can lose it if the + # export hasn't drained before the process exits — observed on CI + # runners, where it leaves an orphaned `llm` span and no queryable + # trace. Force synchronous export so the span is written before exit. + "MLFLOW_ENABLE_ASYNC_TRACE_LOGGING": "false", } warehouse_id = cfg.get("sql_warehouse_id") if warehouse_id: diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 088b2e1..27e5891 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -82,6 +82,7 @@ def test_uri_and_experiment(self): "MLFLOW_TRACKING_URI": "databricks://p", "MLFLOW_EXPERIMENT_ID": "111", "MLFLOW_TRACING_SQL_WAREHOUSE_ID": "wh123", + "MLFLOW_ENABLE_ASYNC_TRACE_LOGGING": "false", } def test_empty_for_non_claude_agents(self): From 99c68bbce7ec59173d0779908ffcf4473c40134a Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Fri, 26 Jun 2026 17:33:27 +0000 Subject: [PATCH 08/11] CI: probe MLflow span export on tracing failure --- .github/workflows/ci.yml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 819096e..0bfbc21 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,3 +72,27 @@ jobs: echo "=== installed claude-code + mlflow CLI ===" claude --version || true "$(uv 
tool dir --bin)/mlflow" --version || true + # Diagnostic: reproduce the client-side span export from the runner with + # verbose logging. The hook's `log_spans` failure is logged at WARNING to + # mlflow's own logger (not the hook file log), so surface it directly to + # see why the root span never reaches the trace server from CI. + - if: ${{ failure() }} + name: Probe MLflow span export + env: + DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_BEARER }} + run: | + uv run --extra tracing python - <<'PY' 2>&1 || true + import logging + logging.basicConfig(level=logging.DEBUG) + import os, mlflow + mlflow.set_tracking_uri("databricks") + mlflow.set_experiment(experiment_id="2190569664060193") + print("=== starting probe span ===") + with mlflow.start_span(name="ci_export_probe") as span: + span.set_inputs({"probe": True}) + try: + mlflow.flush_trace_async_logging() + except Exception as e: + print("flush error:", repr(e)) + print("=== probe done ===") + PY From 7e4654e17e22b52bf79961abdcfe54860787310d Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Fri, 26 Jun 2026 17:44:39 +0000 Subject: [PATCH 09/11] CI: probe span export under mlflow 3.11.1 vs 3.12.0 --- .github/workflows/ci.yml | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0bfbc21..f57a5e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,22 +77,21 @@ jobs: # mlflow's own logger (not the hook file log), so surface it directly to # see why the root span never reaches the trace server from CI. 
- if: ${{ failure() }} - name: Probe MLflow span export + name: Probe MLflow span export (3.11.1 vs 3.12.0) env: DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_BEARER }} run: | - uv run --extra tracing python - <<'PY' 2>&1 || true - import logging - logging.basicConfig(level=logging.DEBUG) - import os, mlflow + for V in 3.11.1 3.12.0; do + echo "=== probing mlflow==$V ===" + uv run --with "mlflow[databricks]==$V" python - "$V" <<'PY' 2>&1 | grep -iE "probe|trace_id|version|error|exception|Failed to log" || true + import sys, mlflow + v = sys.argv[1] mlflow.set_tracking_uri("databricks") mlflow.set_experiment(experiment_id="2190569664060193") - print("=== starting probe span ===") - with mlflow.start_span(name="ci_export_probe") as span: - span.set_inputs({"probe": True}) - try: - mlflow.flush_trace_async_logging() - except Exception as e: - print("flush error:", repr(e)) - print("=== probe done ===") + print("mlflow version", mlflow.__version__) + with mlflow.start_span(name=f"ci_probe_{v.replace('.','_')}") as span: + span.set_inputs({"v": v}) + print("trace_id", span.trace_id) + mlflow.flush_trace_async_logging() PY + done From ee31cf99090748cd26050fc76f0bd7929f16b4d1 Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Wed, 1 Jul 2026 01:30:26 +0000 Subject: [PATCH 10/11] Add Amazon Bedrock model provider routing for claude Bedrock-backed Model Provider Services expose Claude under provider-side model ids (e.g. us.anthropic.claude-sonnet-4-6) rather than Claude Code's canonical names, so ucode pins them explicitly via the ANTHROPIC_DEFAULT_* env vars. Maps service targets to opus/sonnet/haiku families, preferring the highest version and broadest-routing region profile, and validates that a Bedrock service exposes at least one Claude model before routing to it. 
Co-authored-by: Isaac --- src/ucode/agents/__init__.py | 42 +++++++++- src/ucode/agents/claude.py | 36 +++++++-- src/ucode/cli.py | 28 +++---- src/ucode/databricks.py | 145 ++++++++++++++++++++++++++++++++--- tests/test_agent_claude.py | 23 ++++++ tests/test_agents_init.py | 39 ++++++++++ tests/test_databricks.py | 140 +++++++++++++++++++++++++++++++-- tests/test_e2e.py | 11 ++- 8 files changed, 422 insertions(+), 42 deletions(-) diff --git a/src/ucode/agents/__init__.py b/src/ucode/agents/__init__.py index 414c354..2eedf98 100644 --- a/src/ucode/agents/__init__.py +++ b/src/ucode/agents/__init__.py @@ -18,7 +18,11 @@ from ucode.config_io import ToolSpec from ucode.databricks import ( + BEDROCK_PROVIDER_TYPES, + get_databricks_token, install_databricks_cli, + map_bedrock_claude_models, + resolve_provider_service, ) from ucode.state import get_provider_service, load_state, save_state from ucode.telemetry import agent_version @@ -254,8 +258,35 @@ def resolve_launch_model( return state, model +def resolve_provider_models( + tool: str, state: dict, provider: str | None +) -> tuple[dict | None, str | None]: + """Validate ``provider`` for ``tool`` and return the model ids to pin. + + Returns ``(provider_models, error)``. ``provider_models`` is a + ``{family: model_id}`` dict for a Bedrock-backed claude service (whose + provider-side ids must be pinned explicitly), or None for an Anthropic/ + canonical service or when ``provider`` is None. A non-None ``error`` means + the provider is invalid for the tool (wrong type, missing, feature off, or a + Bedrock service with no Claude models) and the caller should not launch. 
+ """ + if not provider: + return None, None + token = get_databricks_token(state["workspace"], state.get("profile")) + service, error = resolve_provider_service(tool, provider, state["workspace"], token) + if error or service is None: + return None, error + if service["provider_type"] in BEDROCK_PROVIDER_TYPES: + return map_bedrock_claude_models(service.get("targets") or []), None + return None, None + + def configure_tool( - tool: str, state: dict, model: str | None = None, provider: str | None = None + tool: str, + state: dict, + model: str | None = None, + provider: str | None = None, + provider_models: dict[str, str] | None = None, ) -> dict: result: dict | tuple[dict, str] if tool == "codex": @@ -265,7 +296,9 @@ def configure_tool( # model, so the usual "model required" guard doesn't apply to claude. if not model and not provider: raise RuntimeError(f"A {tool} model must be selected before configuration.") - result = claude.write_tool_config(state, model, provider=provider) + result = claude.write_tool_config( + state, model, provider=provider, provider_models=provider_models + ) else: # provider routing is claude/codex-only; every other tool needs a model. 
if not model: @@ -353,7 +386,10 @@ def configure_single_tool(tool: str, state: dict) -> dict: def _configure_one(tool: str, state: dict, provider: str | None) -> dict: """Write one tool's config, routing through ``provider`` when set.""" if provider: - return configure_tool(tool, state, None, provider=provider) + provider_models, error = resolve_provider_models(tool, state, provider) + if error: + raise RuntimeError(error) + return configure_tool(tool, state, None, provider=provider, provider_models=provider_models) if tool == "codex": return configure_tool("codex", state) state, model = resolve_launch_model(tool, state, None) diff --git a/src/ucode/agents/claude.py b/src/ucode/agents/claude.py index c733178..7fc4bc1 100644 --- a/src/ucode/agents/claude.py +++ b/src/ucode/agents/claude.py @@ -130,6 +130,7 @@ def render_overlay( profile: str | None = None, use_pat: bool = False, provider: str | None = None, + provider_models: dict[str, str] | None = None, ) -> tuple[dict, list[list[str]]]: """Return (overlay, managed_key_paths) for Claude settings.json. @@ -139,9 +140,11 @@ def render_overlay( When `provider` is set (a `<catalog>.<schema>.<name>` Model Provider Service), the request is routed to that external provider via the - `Databricks-Model-Provider-Service` header and no Databricks model id is - pinned — Claude Code uses its own canonical model names, which the provider - understands.""" + `Databricks-Model-Provider-Service` header. An Anthropic-backed provider + understands Claude Code's own canonical model names, so no model id is + pinned. A Bedrock-backed provider exposes different model ids (e.g. + `us.anthropic.claude-sonnet-4-6`), passed in `provider_models` by family — + those get pinned via the `ANTHROPIC_DEFAULT_*_MODEL` env vars.""" base_url = build_tool_base_url("claude", workspace) # ANTHROPIC_CUSTOM_HEADERS is parsed as `key: value` pairs separated by # newlines (Anthropic SDK convention).
Setting User-Agent here overrides @@ -167,10 +170,21 @@ def render_overlay( # only one row per model. `ucode claude -- --model X` still overrides for a # single session via Claude Code's own --model flag. _ = model # API stability; no longer pinned via env. - # With a Model Provider Service, the header routes to the external provider - # and Claude Code's own canonical model names are sent verbatim — pinning a - # Databricks model id here would mislabel the picker and isn't routable. - if claude_models and not provider: + # A Bedrock-backed provider needs its provider-side ids pinned verbatim + # (Claude Code's canonical names aren't routable there). These come from the + # service's targets, already de-duped to one id per family upstream. + if provider and provider_models: + if provider_models.get("opus"): + env["ANTHROPIC_DEFAULT_OPUS_MODEL"] = provider_models["opus"] + if provider_models.get("sonnet"): + env["ANTHROPIC_DEFAULT_SONNET_MODEL"] = provider_models["sonnet"] + if provider_models.get("haiku"): + env["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = provider_models["haiku"] + # With an Anthropic Model Provider Service, the header routes to the external + # provider and Claude Code's own canonical model names are sent verbatim — + # pinning a Databricks model id here would mislabel the picker and isn't + # routable. + elif claude_models and not provider: # Picker rows show the raw routable id (e.g. "system.ai.claude-opus-4-8[1m]") # so users can see which gateway-routable model is behind each shortcut. 
# We deliberately don't set the `_NAME` companion env vars — the raw id @@ -254,7 +268,12 @@ def _unregister_web_search_mcp() -> None: pass -def write_tool_config(state: dict, model: str | None, provider: str | None = None) -> dict: +def write_tool_config( + state: dict, + model: str | None, + provider: str | None = None, + provider_models: dict[str, str] | None = None, +) -> dict: backup_existing_file(CLAUDE_SETTINGS_PATH, CLAUDE_BACKUP_PATH) web_search_model = _resolve_web_search_model(state) overlay, managed_keys = render_overlay( @@ -265,6 +284,7 @@ def write_tool_config(state: dict, model: str | None, provider: str | None = Non profile=state.get("profile"), use_pat=bool(state.get("use_pat")), provider=provider, + provider_models=provider_models, ) tracing_env_vars = tracing_env(state, "claude") stop_hook_command = claude_tracing_stop_hook_command() if tracing_env_vars else None diff --git a/src/ucode/cli.py b/src/ucode/cli.py index 3d9bdb7..5ae5116 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -21,6 +21,7 @@ normalize_tool, provider_permission_error, resolve_launch_model, + resolve_provider_models, validate_all_tools, validate_tool, ) @@ -44,7 +45,6 @@ get_databricks_token, install_databricks_cli, is_model_provider_feature_unavailable, - list_model_provider_services, list_profile_entries, list_tool_provider_services, normalize_workspace_url, @@ -769,19 +769,17 @@ def _launch_tool(tool_name: str, ctx: typer.Context, provider: str | None = None state = ensure_provider_state(tool) # An explicit --provider overrides the persisted choice; otherwise fall # back to whatever `ucode configure` saved for this tool. - explicit_provider = provider is not None provider = provider or get_provider_service(state, tool) - if provider and explicit_provider: - # Verify the feature only for an explicit --provider; a persisted - # choice was already validated at `ucode configure` time, so trust it - # and keep the launch fast. 
Surfaces a clear error up front instead of - # a cryptic gateway error mid-session. - token = get_databricks_token(state["workspace"], state.get("profile")) - _, reason = list_model_provider_services(state["workspace"], token) - if is_model_provider_feature_unavailable(reason): - raise RuntimeError( - "Model Provider Service feature is not available yet for this workspace." - ) + # Validate the provider service before launching — it must exist, be a + # provider type this tool can route to (e.g. claude can't use an OpenAI + # or Foundry service), and, for Bedrock, expose Claude models to pin. + # Surfaces a clear error up front instead of a cryptic gateway failure + # mid-session. For a Bedrock service this also returns the model ids. + provider_models = None + if provider: + provider_models, error = resolve_provider_models(tool, state, provider) + if error: + raise RuntimeError(error) # Re-fetch model lists on every launch so newly-added Databricks # endpoints show up without a manual `ucode configure` (and so that # tools like pi which read multiple model bundles never run on @@ -801,7 +799,9 @@ def _launch_tool(tool_name: str, ctx: typer.Context, provider: str | None = None resolved_model = None else: state, resolved_model = resolve_launch_model(tool, state, None) - state = configure_tool(tool, state, resolved_model, provider=provider) + state = configure_tool( + tool, state, resolved_model, provider=provider, provider_models=provider_models + ) print_section(f"ucode with {TOOL_SPECS[tool]['display']}") if provider: print_kv("Provider", provider) diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index cb279e5..ca7a803 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -1262,13 +1262,25 @@ def build_mcp_service_url(workspace: str, full_name: str) -> str: # Maps the gateway routing dialect a coding tool speaks to the Model Provider -# Service `provider_type` it can be backed by. 
claude speaks Anthropic's API; -# codex speaks OpenAI's. -_TOOL_PROVIDER_TYPES: dict[str, str] = { - "claude": "anthropic", - "codex": "openai", +# Service `provider_type`s it can be backed by. claude speaks Anthropic's API, +# which both the `anthropic` and `amazon_bedrock` provider types serve (Bedrock +# just exposes different model ids); codex speaks OpenAI's. Tags are the short +# form produced by `_provider_type_tag` (e.g. `amazon_bedrock`). +_TOOL_PROVIDER_TYPES: dict[str, tuple[str, ...]] = { + "claude": ("anthropic", "amazon_bedrock"), + "codex": ("openai",), } +# Provider types that expose Bedrock-style model ids (e.g. +# `us.anthropic.claude-sonnet-4-6`) instead of the agent's canonical model +# names, so ucode must pin them explicitly. +BEDROCK_PROVIDER_TYPES: tuple[str, ...] = ("amazon_bedrock",) + + +def tool_supports_provider_type(tool: str, provider_type: str) -> bool: + """True when ``tool``'s API dialect can be backed by ``provider_type``.""" + return provider_type in _TOOL_PROVIDER_TYPES.get(tool, ()) + def _provider_type_tag(provider_type: str | None) -> str: """Shorten `EXTERNAL_MODEL_PROVIDER_TYPE_ANTHROPIC` to `anthropic`.""" @@ -1283,8 +1295,10 @@ def list_model_provider_services(workspace: str, token: str) -> tuple[list[dict] """List Unity Catalog Model Provider Services on the workspace. Returns ``(services, reason)`` where each service is - ``{"name": "<catalog>.<schema>.<name>", "provider_type": "anthropic"|...}``. - A non-None ``reason`` means the listing call itself failed. + ``{"name": "<catalog>.<schema>.<name>", "provider_type": "anthropic"|..., + "targets": [model_id, ...], "allow_all_targets": bool}``. ``targets`` is the + provider-side model ids the service exposes (used to pin Bedrock model + names). A non-None ``reason`` means the listing call itself failed.
""" hostname = workspace_hostname(workspace) url = f"https://{hostname}/api/2.1/unity-catalog/model-provider-services" @@ -1302,10 +1316,17 @@ def list_model_provider_services(workspace: str, token: str) -> tuple[list[dict] # The API returns `model-provider-services/..`. full_name = raw_name.split("/", 1)[1] if "/" in raw_name else raw_name config = service.get("config") if isinstance(service.get("config"), dict) else {} + targets = [] + for target in config.get("targets") or []: + model_id = target.get("model") if isinstance(target, dict) else None + if isinstance(model_id, str) and model_id: + targets.append(model_id) services.append( { "name": full_name, "provider_type": _provider_type_tag(config.get("provider_type")), + "targets": targets, + "allow_all_targets": bool(config.get("allow_all_targets")), } ) services.sort(key=lambda s: s["name"]) @@ -1328,14 +1349,120 @@ def list_tool_provider_services( Returns ``(names, reason)``; ``reason`` is non-None when the listing failed. """ - wanted = _TOOL_PROVIDER_TYPES.get(tool) services, reason = list_model_provider_services(workspace, token) if reason is not None: return [], reason - names = [s["name"] for s in services if not wanted or s["provider_type"] == wanted] + names = [s["name"] for s in services if service_usable_for_tool(tool, s)] return names, None +def service_usable_for_tool(tool: str, service: dict) -> bool: + """True when ``tool`` can actually route through ``service``. + + Beyond the provider-type match, a Bedrock service is only usable for claude + if it exposes at least one Claude model in its targets — otherwise there's no + routable model id to pin. (Anthropic services use canonical names, so any + match is usable.) 
+ """ + provider_type = service.get("provider_type", "") + if not tool_supports_provider_type(tool, provider_type): + return False + if provider_type in BEDROCK_PROVIDER_TYPES: + return bool(map_bedrock_claude_models(service.get("targets") or [])) + return True + + +def resolve_provider_service( + tool: str, service_name: str, workspace: str, token: str +) -> tuple[dict | None, str | None]: + """Validate that ``service_name`` exists and is usable by ``tool``. + + Returns ``(service, error)``. On success ``service`` is the full service + dict (``name``/``provider_type``/``targets``/``allow_all_targets``) and + ``error`` is None. On failure ``service`` is None and ``error`` is an + actionable message: the feature is off, the listing failed, the service + doesn't exist, or its provider type isn't one ``tool`` can route to (e.g. + pointing claude at an OpenAI service). + """ + services, reason = list_model_provider_services(workspace, token) + if is_model_provider_feature_unavailable(reason): + return None, "Model Provider Service feature is not available yet for this workspace." + if reason is not None: + return None, f"Could not list model provider services: {reason}" + match = next((s for s in services if s["name"] == service_name), None) + if match is None: + usable = [ + s["name"] for s in services if tool_supports_provider_type(tool, s["provider_type"]) + ] + suffix = f" Available for {tool}: {', '.join(usable)}." if usable else "" + return None, f"Model provider service '{service_name}' was not found.{suffix}" + provider_type = match["provider_type"] + if not tool_supports_provider_type(tool, provider_type): + supported = ", ".join(_TOOL_PROVIDER_TYPES.get(tool, ())) or "none" + return None, ( + f"Model provider service '{service_name}' is a '{provider_type}' provider, " + f"which {tool} can't route to (supported: {supported})." 
+ ) + if provider_type in BEDROCK_PROVIDER_TYPES and not map_bedrock_claude_models( + match.get("targets") or [] + ): + return None, ( + f"Model provider service '{service_name}' exposes no Claude models — " + f"add Claude targets to it or pick a different service." + ) + return match, None + + +# Bedrock exposes Claude under provider-side ids like +# `us.anthropic.claude-sonnet-4-6`, `global.anthropic.claude-opus-4-8`, or the +# region-less `anthropic.claude-opus-4-8`. We map each service target to a +# Claude family and keep the best id per family. Claude Code only takes one +# default per family; users switch to any other listed region profile at runtime +# with `/model <id>` or `--model`. +_BEDROCK_CLAUDE_FAMILIES = ("opus", "sonnet", "haiku") +# When the same model/version is offered under several cross-region inference +# profiles, prefer the broadest-routing one as the pinned default. +_BEDROCK_REGION_RANK = {"global": 5, "us": 4, "eu": 3, "apac": 2, "": 1} + + +def _bedrock_target_family(model_id: str) -> str | None: + lowered = model_id.lower() + if "claude" not in lowered: + return None + return next((fam for fam in _BEDROCK_CLAUDE_FAMILIES if fam in lowered), None) + + +def _bedrock_region_rank(model_id: str) -> int: + """Rank a target's cross-region inference profile (`us.`/`eu.`/`global.`/ + region-less) so ties on model version resolve deterministically.""" + head = model_id.lower().split("anthropic.", 1)[0].rstrip(".") + return _BEDROCK_REGION_RANK.get(head, 0) + + +def _bedrock_sort_key(model_id: str) -> tuple: + """Order targets best-first: highest model version, then preferred region.""" + version = tuple(int(n) for n in re.findall(r"\d+", model_id)) + return (version, _bedrock_region_rank(model_id)) + + +def map_bedrock_claude_models(targets: list[str]) -> dict[str, str]: + """Map Bedrock service targets to ``{family: model_id}`` for opus/sonnet/ + haiku, choosing the highest-versioned id per family and, on a version tie, + the broadest-routing
region profile. Targets that don't name a Claude family + are ignored.""" + best_key: dict[str, tuple] = {} + result: dict[str, str] = {} + for model_id in targets: + family = _bedrock_target_family(model_id) + if not family: + continue + key = _bedrock_sort_key(model_id) + if family not in best_key or key > best_key[family]: + best_key[family] = key + result[family] = model_id + return result + + # `list_vector_search_catalog_schemas` walks Vector Search endpoints+indexes. # `list_uc_functions_catalog_schemas` walks UC catalogs+schemas in parallel and # keeps only schemas with at least one user function. diff --git a/tests/test_agent_claude.py b/tests/test_agent_claude.py index 86975b4..838cf36 100644 --- a/tests/test_agent_claude.py +++ b/tests/test_agent_claude.py @@ -135,6 +135,29 @@ def test_no_provider_header_without_flag(self): overlay, _ = claude.render_overlay(WS, "s4") assert "Databricks-Model-Provider-Service" not in overlay["env"]["ANTHROPIC_CUSTOM_HEADERS"] + def test_bedrock_provider_pins_model_ids(self): + provider_models = { + "opus": "global.anthropic.claude-opus-4-8", + "sonnet": "us.anthropic.claude-sonnet-4-6", + "haiku": "anthropic.claude-haiku-4-5", + } + overlay, _ = claude.render_overlay( + WS, + None, + provider="main.bob.bedrock-svc", + provider_models=provider_models, + ) + env = overlay["env"] + assert env["ANTHROPIC_DEFAULT_OPUS_MODEL"] == "global.anthropic.claude-opus-4-8" + assert env["ANTHROPIC_DEFAULT_SONNET_MODEL"] == "us.anthropic.claude-sonnet-4-6" + assert env["ANTHROPIC_DEFAULT_HAIKU_MODEL"] == "anthropic.claude-haiku-4-5" + # Bedrock ids are pinned verbatim — no `[1m]` suffix mangling. + assert "[1m]" not in env["ANTHROPIC_DEFAULT_OPUS_MODEL"] + assert ( + "Databricks-Model-Provider-Service: main.bob.bedrock-svc" + in env["ANTHROPIC_CUSTOM_HEADERS"] + ) + def test_picker_labels_show_raw_routable_id(self): # We deliberately don't set the `_NAME` companion env vars. 
Showing the # raw `system.ai.…` / `databricks-…` id in the picker label tells users diff --git a/tests/test_agents_init.py b/tests/test_agents_init.py index 3cecadb..576fc22 100644 --- a/tests/test_agents_init.py +++ b/tests/test_agents_init.py @@ -202,6 +202,45 @@ def test_raises_when_no_models_available(self): resolve_launch_model("claude", {}, None) +class TestResolveProviderModels: + _STATE = {"workspace": "https://ws.databricks.com", "profile": None} + + def _patch(self, monkeypatch, service, error): + monkeypatch.setattr(agents_mod, "get_databricks_token", lambda w, p: "token") + monkeypatch.setattr( + agents_mod, "resolve_provider_service", lambda t, n, w, tok: (service, error) + ) + + def test_none_provider_returns_none(self): + models, error = agents_mod.resolve_provider_models("claude", self._STATE, None) + assert (models, error) == (None, None) + + def test_anthropic_returns_no_models(self, monkeypatch): + self._patch(monkeypatch, {"provider_type": "anthropic", "targets": []}, None) + models, error = agents_mod.resolve_provider_models("claude", self._STATE, "main.a.svc") + assert error is None + assert models is None + + def test_bedrock_returns_pinned_models(self, monkeypatch): + service = { + "provider_type": "amazon_bedrock", + "targets": ["us.anthropic.claude-sonnet-4-6", "global.anthropic.claude-opus-4-8"], + } + self._patch(monkeypatch, service, None) + models, error = agents_mod.resolve_provider_models("claude", self._STATE, "main.b.svc") + assert error is None + assert models == { + "sonnet": "us.anthropic.claude-sonnet-4-6", + "opus": "global.anthropic.claude-opus-4-8", + } + + def test_invalid_provider_returns_error(self, monkeypatch): + self._patch(monkeypatch, None, "boom") + models, error = agents_mod.resolve_provider_models("claude", self._STATE, "main.x.svc") + assert models is None + assert error == "boom" + + class TestInstallToolBinary: def test_non_strict_returns_false_when_npm_missing(self, monkeypatch): 
monkeypatch.setattr("ucode.agents.shutil.which", lambda _: None) diff --git a/tests/test_databricks.py b/tests/test_databricks.py index 86a64ac..38c35d2 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -285,7 +285,24 @@ class TestListModelProviderServices: }, { "name": "model-provider-services/main.bob.bedrock-svc", - "config": {"provider_type": "EXTERNAL_MODEL_PROVIDER_TYPE_BEDROCK"}, + "config": { + "provider_type": "EXTERNAL_MODEL_PROVIDER_TYPE_AMAZON_BEDROCK", + "allow_all_targets": False, + "targets": [ + { + "model": "us.anthropic.claude-sonnet-4-6", + "native_api_types": ["anthropic/v1/messages"], + }, + {"model": "global.anthropic.claude-opus-4-8"}, + ], + }, + }, + { + "name": "model-provider-services/main.bob.bedrock-titan-svc", + "config": { + "provider_type": "EXTERNAL_MODEL_PROVIDER_TYPE_AMAZON_BEDROCK", + "targets": [{"model": "amazon.titan-text-express-v1"}], + }, }, ] } @@ -296,8 +313,28 @@ def test_strips_prefix_and_tags_provider_type(self, monkeypatch): ) services, reason = db_mod.list_model_provider_services(WS, "token") assert reason is None - assert services[0] == {"name": "main.aarushi.anthropic-svc", "provider_type": "anthropic"} - assert {s["provider_type"] for s in services} == {"anthropic", "openai", "bedrock"} + assert services[0] == { + "name": "main.aarushi.anthropic-svc", + "provider_type": "anthropic", + "targets": [], + "allow_all_targets": False, + } + assert {s["provider_type"] for s in services} == { + "anthropic", + "openai", + "amazon_bedrock", + } + + def test_extracts_targets(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=30: (self._PAYLOAD, None) + ) + services, _ = db_mod.list_model_provider_services(WS, "token") + bedrock = next(s for s in services if s["name"] == "main.bob.bedrock-svc") + assert bedrock["targets"] == [ + "us.anthropic.claude-sonnet-4-6", + "global.anthropic.claude-opus-4-8", + ] def test_returns_reason_on_failure(self, monkeypatch): 
monkeypatch.setattr( @@ -307,13 +344,15 @@ def test_returns_reason_on_failure(self, monkeypatch): assert services == [] assert reason == "HTTP 500 Server Error" - def test_claude_filters_to_anthropic(self, monkeypatch): + def test_claude_includes_anthropic_and_usable_bedrock(self, monkeypatch): monkeypatch.setattr( db_mod, "_http_get_json", lambda url, token, timeout=30: (self._PAYLOAD, None) ) names, reason = db_mod.list_tool_provider_services("claude", WS, "token") assert reason is None - assert names == ["main.aarushi.anthropic-svc"] + # Anthropic + the Bedrock service with Claude targets; the Bedrock service + # exposing only Titan is hidden (no Claude models to pin). + assert names == ["main.aarushi.anthropic-svc", "main.bob.bedrock-svc"] def test_codex_filters_to_openai(self, monkeypatch): monkeypatch.setattr( @@ -323,6 +362,97 @@ def test_codex_filters_to_openai(self, monkeypatch): assert names == ["main.aarushi.openai-svc"] +class TestMapBedrockClaudeModels: + def test_maps_families(self): + models = db_mod.map_bedrock_claude_models( + [ + "us.anthropic.claude-sonnet-4-6", + "global.anthropic.claude-opus-4-8", + "anthropic.claude-haiku-4-5", + "amazon.titan-text-express-v1", + ] + ) + assert models == { + "sonnet": "us.anthropic.claude-sonnet-4-6", + "opus": "global.anthropic.claude-opus-4-8", + "haiku": "anthropic.claude-haiku-4-5", + } + + def test_prefers_highest_version(self): + models = db_mod.map_bedrock_claude_models( + ["us.anthropic.claude-sonnet-4-5", "us.anthropic.claude-sonnet-4-6"] + ) + assert models["sonnet"] == "us.anthropic.claude-sonnet-4-6" + + def test_region_tie_break_prefers_global(self): + models = db_mod.map_bedrock_claude_models( + [ + "us.anthropic.claude-opus-4-8", + "global.anthropic.claude-opus-4-8", + "eu.anthropic.claude-opus-4-8", + ] + ) + assert models["opus"] == "global.anthropic.claude-opus-4-8" + + def test_empty_when_no_claude(self): + assert db_mod.map_bedrock_claude_models(["amazon.titan-text-express-v1"]) == {} + + 
+class TestResolveProviderService: + _PAYLOAD = TestListModelProviderServices._PAYLOAD + + def _patch(self, monkeypatch): + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=30: (self._PAYLOAD, None) + ) + + def test_anthropic_ok(self, monkeypatch): + self._patch(monkeypatch) + service, error = db_mod.resolve_provider_service( + "claude", "main.aarushi.anthropic-svc", WS, "token" + ) + assert error is None + assert service["provider_type"] == "anthropic" + + def test_bedrock_with_claude_ok(self, monkeypatch): + self._patch(monkeypatch) + service, error = db_mod.resolve_provider_service( + "claude", "main.bob.bedrock-svc", WS, "token" + ) + assert error is None + assert service["provider_type"] == "amazon_bedrock" + + def test_wrong_type_rejected(self, monkeypatch): + self._patch(monkeypatch) + service, error = db_mod.resolve_provider_service( + "claude", "main.aarushi.openai-svc", WS, "token" + ) + assert service is None + assert "can't route to" in error + + def test_bedrock_without_claude_rejected(self, monkeypatch): + self._patch(monkeypatch) + service, error = db_mod.resolve_provider_service( + "claude", "main.bob.bedrock-titan-svc", WS, "token" + ) + assert service is None + assert "no Claude models" in error + + def test_not_found_lists_usable(self, monkeypatch): + self._patch(monkeypatch) + service, error = db_mod.resolve_provider_service("claude", "main.x.missing", WS, "token") + assert service is None + assert "was not found" in error + assert "main.aarushi.anthropic-svc" in error + + def test_feature_unavailable(self, monkeypatch): + reason = "HTTP 400 Bad Request: ModelProviderService feature is not available" + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token, timeout=30: (None, reason)) + service, error = db_mod.resolve_provider_service("claude", "main.x.y", WS, "token") + assert service is None + assert "not available" in error + + class TestModelProviderFeatureUnavailable: def 
test_detects_feature_not_available(self): reason = ( diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 79fb91e..5d115f6 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -31,6 +31,7 @@ is_model_provider_feature_unavailable, list_model_provider_services, list_tool_provider_services, + service_usable_for_tool, workspace_hostname, ) from ucode.ui import normalize_workspace_url @@ -162,7 +163,7 @@ def test_list_returns_services_or_skips_when_feature_off(self, e2e_workspace, e2 assert reason is None, f"listing failed: {reason}" assert isinstance(services, list) for svc in services: - assert set(svc) >= {"name", "provider_type"} + assert set(svc) >= {"name", "provider_type", "targets", "allow_all_targets"} # Names are stripped of the `model-provider-services/` API prefix. assert svc["name"] and "/" not in svc["name"] @@ -173,10 +174,14 @@ def test_tool_filter_matches_provider_type(self, e2e_workspace, e2e_token): assert reason is None claude_names, _ = list_tool_provider_services("claude", e2e_workspace, e2e_token) codex_names, _ = list_tool_provider_services("codex", e2e_workspace, e2e_token) + # claude routes through Anthropic and Bedrock services (Bedrock only when + # it exposes Claude models); codex through OpenAI. 
assert set(claude_names) == { - s["name"] for s in services if s["provider_type"] == "anthropic" + s["name"] for s in services if service_usable_for_tool("claude", s) + } + assert set(codex_names) == { + s["name"] for s in services if service_usable_for_tool("codex", s) } - assert set(codex_names) == {s["name"] for s in services if s["provider_type"] == "openai"} # --------------------------------------------------------------------------- From 0d19adbe8543472bcbfc4adfaee0ecd49d434b9d Mon Sep 17 00:00:00 2001 From: AarushiShah-db Date: Wed, 1 Jul 2026 01:33:33 +0000 Subject: [PATCH 11/11] Revert diagnostic CI probes in ci.yml Co-authored-by: Isaac --- .github/workflows/ci.yml | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f57a5e0..f9fb8ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,38 +60,3 @@ jobs: # independent and one shouldn't mask the other. - if: ${{ !cancelled() }} run: uv run --extra tracing pytest tests/test_e2e_tracing.py -v < /dev/null - # Diagnostic: the Claude Stop hook writes its trace-creation log to - # $cwd/.claude/mlflow/claude_tracing.log. If the tracing test failed - # because no root `claude_code_conversation` span landed, this shows - # whether the hook fired on the runner and what its MLflow write did. - - if: ${{ failure() }} - name: Dump Claude tracing hook log - run: | - echo "=== claude_tracing.log ===" - cat .claude/mlflow/claude_tracing.log 2>/dev/null || echo "(no hook log — Stop hook never ran)" - echo "=== installed claude-code + mlflow CLI ===" - claude --version || true - "$(uv tool dir --bin)/mlflow" --version || true - # Diagnostic: reproduce the client-side span export from the runner with - # verbose logging. The hook's `log_spans` failure is logged at WARNING to - # mlflow's own logger (not the hook file log), so surface it directly to - # see why the root span never reaches the trace server from CI. 
- - if: ${{ failure() }} - name: Probe MLflow span export (3.11.1 vs 3.12.0) - env: - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_BEARER }} - run: | - for V in 3.11.1 3.12.0; do - echo "=== probing mlflow==$V ===" - uv run --with "mlflow[databricks]==$V" python - "$V" <<'PY' 2>&1 | grep -iE "probe|trace_id|version|error|exception|Failed to log" || true - import sys, mlflow - v = sys.argv[1] - mlflow.set_tracking_uri("databricks") - mlflow.set_experiment(experiment_id="2190569664060193") - print("mlflow version", mlflow.__version__) - with mlflow.start_span(name=f"ci_probe_{v.replace('.','_')}") as span: - span.set_inputs({"v": v}) - print("trace_id", span.trace_id) - mlflow.flush_trace_async_logging() - PY - done