Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 88 additions & 24 deletions src/ucode/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@

from ucode.config_io import ToolSpec
from ucode.databricks import (
BEDROCK_PROVIDER_TYPES,
get_databricks_token,
install_databricks_cli,
map_bedrock_claude_models,
resolve_provider_service,
)
from ucode.state import load_state, save_state
from ucode.state import get_provider_service, load_state, save_state
from ucode.telemetry import agent_version
from ucode.ui import (
console,
Expand Down Expand Up @@ -254,16 +258,52 @@ def resolve_launch_model(
return state, model


def configure_tool(tool: str, state: dict, model: str | None = None) -> dict:
def resolve_provider_models(
tool: str, state: dict, provider: str | None
) -> tuple[dict | None, str | None]:
"""Validate ``provider`` for ``tool`` and return the model ids to pin.

Returns ``(provider_models, error)``. ``provider_models`` is a
``{family: model_id}`` dict for a Bedrock-backed claude service (whose
provider-side ids must be pinned explicitly), or None for an Anthropic/
canonical service or when ``provider`` is None. A non-None ``error`` means
the provider is invalid for the tool (wrong type, missing, feature off, or a
Bedrock service with no Claude models) and the caller should not launch.
"""
if not provider:
return None, None
token = get_databricks_token(state["workspace"], state.get("profile"))
service, error = resolve_provider_service(tool, provider, state["workspace"], token)
if error or service is None:
return None, error
if service["provider_type"] in BEDROCK_PROVIDER_TYPES:
return map_bedrock_claude_models(service.get("targets") or []), None
return None, None


def configure_tool(
tool: str,
state: dict,
model: str | None = None,
provider: str | None = None,
provider_models: dict[str, str] | None = None,
) -> dict:
result: dict | tuple[dict, str]
if tool == "codex":
result = codex.write_tool_config(state, model)
result = codex.write_tool_config(state, model, provider=provider)
elif tool == "claude":
# A Model Provider Service routes by header and pins no Databricks
# model, so the usual "model required" guard doesn't apply to claude.
if not model and not provider:
raise RuntimeError(f"A {tool} model must be selected before configuration.")
result = claude.write_tool_config(
state, model, provider=provider, provider_models=provider_models
)
else:
# provider routing is claude/codex-only; every other tool needs a model.
if not model:
raise RuntimeError(f"A {tool} model must be selected before configuration.")
if tool == "claude":
result = claude.write_tool_config(state, model)
elif tool == "gemini":
if tool == "gemini":
result = gemini.write_tool_config(state, model)
elif tool == "copilot":
result = copilot.write_tool_config(state, model)
Expand Down Expand Up @@ -325,24 +365,37 @@ def _availability_failure_detail(tool: str, state: dict) -> str:

def configure_single_tool(tool: str, state: dict) -> dict:
"""Check availability, configure, and persist state for one tool only."""
with spinner(f"Checking {TOOL_SPECS[tool]['display']} availability..."):
ok = check_gateway_endpoint(state, tool)
if not ok:
detail = _availability_failure_detail(tool, state)
raise RuntimeError(
f"{TOOL_SPECS[tool]['display']} is not available on this workspace.{detail}"
)
if tool == "codex":
state = configure_tool("codex", state)
else:
state, model = resolve_launch_model(tool, state, None)
state = configure_tool(tool, state, model)
provider = get_provider_service(state, tool)
# A Model Provider Service routes through the same gateway and pins no
# Databricks model, so the per-tool model availability check doesn't apply.
if not provider:
with spinner(f"Checking {TOOL_SPECS[tool]['display']} availability..."):
ok = check_gateway_endpoint(state, tool)
if not ok:
detail = _availability_failure_detail(tool, state)
raise RuntimeError(
f"{TOOL_SPECS[tool]['display']} is not available on this workspace.{detail}"
)
state = _configure_one(tool, state, provider)
available_tools = list(set((state.get("available_tools") or []) + [tool]))
state["available_tools"] = available_tools
save_state(state)
return state


def _configure_one(tool: str, state: dict, provider: str | None) -> dict:
    """Configure a single tool, going through ``provider`` if one is given.

    Without a provider, codex is configured with no explicit model and every
    other tool first gets its launch model from ``resolve_launch_model``.
    With a provider, the model ids come from ``resolve_provider_models`` and
    no model is passed to ``configure_tool``.

    Raises:
        RuntimeError: if ``resolve_provider_models`` reports an error for the
            provider.
    """
    if not provider:
        if tool == "codex":
            return configure_tool("codex", state)
        state, launch_model = resolve_launch_model(tool, state, None)
        return configure_tool(tool, state, launch_model)

    pinned_models, problem = resolve_provider_models(tool, state, provider)
    if problem:
        raise RuntimeError(problem)
    return configure_tool(
        tool,
        state,
        None,
        provider=provider,
        provider_models=pinned_models,
    )


def configure_selected_tools(state: dict, tools: list[str]) -> dict:
"""Configure the given tools. Caller is responsible for ensuring each tool
is available on the workspace.
Expand All @@ -352,11 +405,7 @@ def configure_selected_tools(state: dict, tools: list[str]) -> dict:
run is preserved.
"""
for tool in tools:
if tool == "codex":
state = configure_tool("codex", state)
else:
state, model = resolve_launch_model(tool, state, None)
state = configure_tool(tool, state, model)
state = _configure_one(tool, state, get_provider_service(state, tool))

existing = state.get("available_tools") or []
state["available_tools"] = sorted(set(existing) | set(tools))
Expand Down Expand Up @@ -440,6 +489,21 @@ def validate_tool(tool: str) -> tuple[bool, str]:
return False, "timed out"


def provider_permission_error(tool: str, state: dict, err: str) -> str:
    """Turn a gateway connection-permission error into an actionable message.

    When ``tool`` has a provider service recorded in ``state`` and ``err``
    contains ``USE CONNECTION on SCHEMA_CONNECTION``, the returned text names
    that provider service and tells the user to ask its owner for access.
    In every other case ``err`` is returned unchanged.
    """
    provider = get_provider_service(state, tool)
    # Guard clause: no provider in play, or the error is some other failure.
    if not provider or "USE CONNECTION on SCHEMA_CONNECTION" not in err:
        return err
    return (
        "You don't have EXECUTE permission on the model provider service "
        f"'{provider}'. Ask its owner to grant you access, then re-run "
        "`ucode configure`."
    )


def validate_all_tools(state: dict) -> None:
from rich.panel import Panel # local to avoid bumping module-level deps

Expand Down Expand Up @@ -470,7 +534,7 @@ def validate_all_tools(state: dict) -> None:
if ok:
print_success(f"{spec['display']} is working")
else:
print_err(f"{spec['display']}: {err}")
print_err(f"{spec['display']}: {provider_permission_error(tool, state, err)}")
managed = bool(state.get("managed_configs", {}).get(tool))
restore_file(spec["config_path"], spec["backup_path"], managed)
# Rollback settings.json for Pi
Expand Down
52 changes: 42 additions & 10 deletions src/ucode/agents/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,28 +125,39 @@ def _web_search_mcp_entry(workspace: str, search_model: str, profile: str | None

def render_overlay(
workspace: str,
model: str,
model: str | None,
claude_models: dict[str, str] | None = None,
disable_web_search: bool = False,
profile: str | None = None,
use_pat: bool = False,
provider: str | None = None,
provider_models: dict[str, str] | None = None,
) -> tuple[dict, list[list[str]]]:
"""Return (overlay, managed_key_paths) for Claude settings.json.

NOTE: MCP servers are NOT written here. Claude Code reads `mcpServers`
from `~/.claude.json`, not `~/.claude/settings.json` — registration goes
through `claude mcp add-json` (see `_register_web_search_mcp`)."""
through `claude mcp add-json` (see `_register_web_search_mcp`).

When `provider` is set (a `<catalog>.<schema>.<name>` Model Provider
Service), the request is routed to that external provider via the
`Databricks-Model-Provider-Service` header. An Anthropic-backed provider
understands Claude Code's own canonical model names, so no model id is
pinned. A Bedrock-backed provider exposes different model ids (e.g.
`us.anthropic.claude-sonnet-4-6`), passed in `provider_models` by family —
those get pinned via the `ANTHROPIC_DEFAULT_*_MODEL` env vars."""
base_url = build_tool_base_url("claude", workspace)
# ANTHROPIC_CUSTOM_HEADERS is parsed as `key: value` pairs separated by
# newlines (Anthropic SDK convention). Setting User-Agent here overrides
# the SDK's default UA on outbound requests so the gateway can attribute
# traffic to ucode.
custom_headers = "\n".join(
[
"x-databricks-use-coding-agent-mode: true",
f"User-Agent: ucode/{ucode_version()} claude/{agent_version('claude')}",
]
)
header_lines = [
"x-databricks-use-coding-agent-mode: true",
f"User-Agent: ucode/{ucode_version()} claude/{agent_version('claude')}",
]
if provider:
header_lines.append(f"Databricks-Model-Provider-Service: {provider}")
custom_headers = "\n".join(header_lines)
env: dict[str, str] = {
"ANTHROPIC_BASE_URL": base_url,
"ANTHROPIC_CUSTOM_HEADERS": custom_headers,
Expand All @@ -160,7 +171,21 @@ def render_overlay(
# only one row per model. `ucode claude -- --model X` still overrides for a
# single session via Claude Code's own --model flag.
_ = model # API stability; no longer pinned via env.
if claude_models:
# A Bedrock-backed provider needs its provider-side ids pinned verbatim
# (Claude Code's canonical names aren't routable there). These come from the
# service's targets, already de-duped to one id per family upstream.
if provider and provider_models:
if provider_models.get("opus"):
env["ANTHROPIC_DEFAULT_OPUS_MODEL"] = provider_models["opus"]
if provider_models.get("sonnet"):
env["ANTHROPIC_DEFAULT_SONNET_MODEL"] = provider_models["sonnet"]
if provider_models.get("haiku"):
env["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = provider_models["haiku"]
# With an Anthropic Model Provider Service, the header routes to the external
# provider and Claude Code's own canonical model names are sent verbatim —
# pinning a Databricks model id here would mislabel the picker and isn't
# routable.
elif claude_models and not provider:
# Picker rows show the raw routable id (e.g. "system.ai.claude-opus-4-8[1m]")
# so users can see which gateway-routable model is behind each shortcut.
# We deliberately don't set the `_NAME` companion env vars — the raw id
Expand Down Expand Up @@ -244,7 +269,12 @@ def _unregister_web_search_mcp() -> None:
pass


def write_tool_config(state: dict, model: str) -> dict:
def write_tool_config(
state: dict,
model: str | None,
provider: str | None = None,
provider_models: dict[str, str] | None = None,
) -> dict:
backup_existing_file(CLAUDE_SETTINGS_PATH, CLAUDE_BACKUP_PATH)
web_search_model = _resolve_web_search_model(state)
overlay, managed_keys = render_overlay(
Expand All @@ -254,6 +284,8 @@ def write_tool_config(state: dict, model: str) -> dict:
disable_web_search=web_search_model is not None,
profile=state.get("profile"),
use_pat=bool(state.get("use_pat")),
provider=provider,
provider_models=provider_models,
)
tracing_env_vars = tracing_env(state, "claude")
stop_hook_command = claude_tracing_stop_hook_command() if tracing_env_vars else None
Expand Down
57 changes: 47 additions & 10 deletions src/ucode/agents/codex.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,26 @@ def _use_legacy_layout() -> bool:
return parsed < MINIMUM_CODEX_VERSION


def _provider_block(workspace: str, databricks_profile: str | None, use_pat: bool = False) -> dict:
def _provider_block(
workspace: str,
databricks_profile: str | None,
use_pat: bool = False,
provider: str | None = None,
) -> dict:
auth_argv = build_auth_token_argv(workspace, databricks_profile, use_pat=use_pat)
base_url = build_tool_base_url("codex", workspace)
http_headers = {
"User-Agent": f"ucode/{ucode_version()} codex/{agent_version('codex')}",
}
# Route to an external Model Provider Service; the gateway selects the
# provider from this header on every request.
if provider:
http_headers["Databricks-Model-Provider-Service"] = provider
return {
"name": "Databricks AI Gateway",
"base_url": base_url,
"wire_api": "responses",
"http_headers": {
"User-Agent": f"ucode/{ucode_version()} codex/{agent_version('codex')}",
},
"http_headers": http_headers,
# Run the `ucode auth-token` executable directly (not via `sh -c`) so the
# helper works on Windows, where there is no POSIX shell (issue #116).
"auth": {
Expand All @@ -128,12 +138,15 @@ def render_overlay(
model: str | None = None,
databricks_profile: str | None = None,
use_pat: bool = False,
provider: str | None = None,
) -> dict:
overlay: dict = {"model_provider": CODEX_MODEL_PROVIDER_NAME}
if model:
overlay["model"] = model
overlay["model_providers"] = {
CODEX_MODEL_PROVIDER_NAME: _provider_block(workspace, databricks_profile, use_pat),
CODEX_MODEL_PROVIDER_NAME: _provider_block(
workspace, databricks_profile, use_pat, provider
),
}
return overlay

Expand All @@ -143,6 +156,7 @@ def render_legacy_overlay(
model: str | None = None,
databricks_profile: str | None = None,
use_pat: bool = False,
provider: str | None = None,
) -> dict:
"""Overlay for Codex CLI < 0.134.0, which only reads `~/.codex/config.toml`.

Expand All @@ -156,7 +170,9 @@ def render_legacy_overlay(
"profile": CODEX_PROFILE_NAME,
"profiles": {CODEX_PROFILE_NAME: profile_block},
"model_providers": {
CODEX_MODEL_PROVIDER_NAME: _provider_block(workspace, databricks_profile, use_pat),
CODEX_MODEL_PROVIDER_NAME: _provider_block(
workspace, databricks_profile, use_pat, provider
),
},
}

Expand Down Expand Up @@ -293,9 +309,12 @@ def _parse_gpt(model: str | None) -> tuple[int, int | None, int | None, str] | N
)


def write_tool_config(state: dict, model: str | None = None) -> dict:
def write_tool_config(state: dict, model: str | None = None, provider: str | None = None) -> dict:
workspace = state["workspace"]
chosen_model = _codex_model_id(model or default_model(state))
# With a Model Provider Service the gateway routes by header and Codex sends
# its own canonical model name (e.g. `gpt-5`) — leave `model` unset so no
# Databricks endpoint id is pinned.
chosen_model = None if provider else _codex_model_id(model or default_model(state))
databricks_profile = state.get("profile")

if _use_legacy_layout():
Expand All @@ -305,10 +324,20 @@ def write_tool_config(state: dict, model: str | None = None) -> dict:
# ucode's entry from the shared file.
backup_existing_file(LEGACY_CODEX_CONFIG_PATH, LEGACY_CODEX_BACKUP_PATH)
overlay = render_legacy_overlay(
workspace, chosen_model, databricks_profile, use_pat=bool(state.get("use_pat"))
workspace,
chosen_model,
databricks_profile,
use_pat=bool(state.get("use_pat")),
provider=provider,
)
doc = read_toml_safe(LEGACY_CODEX_CONFIG_PATH)
deep_merge_dict(doc, overlay)
if provider:
# deep_merge can't drop keys, so clear a `model` pinned by an
# earlier non-provider run that the provider overlay omits.
profiles = doc.get("profiles")
if isinstance(profiles, dict) and isinstance(profiles.get(CODEX_PROFILE_NAME), dict):
profiles[CODEX_PROFILE_NAME].pop("model", None)
write_toml_file(LEGACY_CODEX_CONFIG_PATH, doc)
state = mark_tool_managed(state, "codex", LEGACY_MANAGED_KEYS)
save_state(state)
Expand All @@ -317,10 +346,18 @@ def write_tool_config(state: dict, model: str | None = None) -> dict:
_remove_legacy_ucode_profile()
backup_existing_file(CODEX_CONFIG_PATH, CODEX_BACKUP_PATH)
overlay = render_overlay(
workspace, chosen_model, databricks_profile, use_pat=bool(state.get("use_pat"))
workspace,
chosen_model,
databricks_profile,
use_pat=bool(state.get("use_pat")),
provider=provider,
)
doc = read_toml_safe(CODEX_CONFIG_PATH)
deep_merge_dict(doc, overlay)
if provider:
# deep_merge can't drop keys, so clear a `model` pinned by an earlier
# non-provider run that the provider overlay omits.
doc.pop("model", None)
write_toml_file(CODEX_CONFIG_PATH, doc)
state = mark_tool_managed(state, "codex", MANAGED_KEYS)
save_state(state)
Expand Down
Loading
Loading