Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/migrations/agent-trace/015_create_session_models.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CREATE TABLE IF NOT EXISTS session_models (
id INTEGER PRIMARY KEY,
tool_name TEXT NOT NULL,
session_id TEXT NOT NULL,
model_id TEXT NOT NULL,
model_id TEXT,
tool_version TEXT,
session_start_time_ms INTEGER NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
Expand Down
67 changes: 65 additions & 2 deletions cli/src/services/agent_trace_db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ pub struct DiffTraceInsert<'a> {
pub struct SessionModelUpsert<'a> {
pub tool_name: &'a str,
pub session_id: &'a str,
pub model_id: &'a str,
pub model_id: Option<&'a str>,
pub tool_version: Option<&'a str>,
pub session_start_time_ms: i64,
}
Expand All @@ -159,7 +159,7 @@ pub struct SessionModelUpsert<'a> {
pub struct SessionModelAttribution {
pub tool_name: String,
pub session_id: String,
pub model_id: String,
pub model_id: Option<String>,
pub tool_version: Option<String>,
pub session_start_time_ms: i64,
}
Expand Down Expand Up @@ -785,6 +785,69 @@ mod tests {
.expect("migration metadata query should succeed")
}

fn session_models_model_id_notnull<M: DbSpec>(db: &TursoDb<M>) -> i64 {
db.query_map("PRAGMA table_info(session_models)", (), |row| {
let name = row.get::<String>(1)?;
let not_null = row.get::<i64>(3)?;
Ok((name, not_null))
})
.expect("session_models table info should load")
.into_iter()
.find_map(|(name, not_null)| (name == "model_id").then_some(not_null))
.expect("session_models.model_id column should exist")
}

#[test]
fn session_model_upsert_and_lookup_round_trip_nullable_and_present_model_ids() {
let db_path = unique_test_db_path();
let db = AgentTraceDb::open_at(&db_path).expect("test DB should open");

assert_eq!(session_models_model_id_notnull(&db), 0);

db.upsert_session_model(SessionModelUpsert {
tool_name: "claude",
session_id: "missing-model-session",
model_id: None,
tool_version: None,
session_start_time_ms: 1_800_000_000_000_i64,
})
.expect("nullable model session upsert should succeed");
db.upsert_session_model(SessionModelUpsert {
tool_name: "claude",
session_id: "model-present-session",
model_id: Some("claude/sonnet-4"),
tool_version: Some("Claude Code 1.2.3"),
session_start_time_ms: 1_800_000_001_000_i64,
})
.expect("model-present session upsert should succeed");

let missing_model = db
.session_model_by_tool_and_session("claude", "missing-model-session")
.expect("nullable model session lookup should succeed")
.expect("nullable model session row should exist");
assert_eq!(missing_model.model_id, None);
assert_eq!(missing_model.tool_version, None);
assert_eq!(missing_model.session_start_time_ms, 1_800_000_000_000_i64);

let model_present = db
.session_model_by_tool_and_session("claude", "model-present-session")
.expect("model-present session lookup should succeed")
.expect("model-present session row should exist");
assert_eq!(
model_present.model_id,
Some(String::from("claude/sonnet-4"))
);
assert_eq!(
model_present.tool_version,
Some(String::from("Claude Code 1.2.3"))
);

drop(db);
if let Some(parent) = db_path.parent() {
fs::remove_dir_all(parent).expect("test DB directory should be removed");
}
}

#[test]
fn recent_diff_trace_patches_applies_bounded_window_ordering_and_parse_accounting() {
let db_path = unique_test_db_path();
Expand Down
37 changes: 16 additions & 21 deletions cli/src/services/hooks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct SessionModelPayload {
#[serde(rename = "sessionID")]
session_id: String,
time: u64,
model_id: String,
model_id: Option<String>,
tool_name: String,
tool_version: Option<String>,
}
Expand Down Expand Up @@ -788,7 +788,7 @@ where
model_id: payload.model_id.clone().or_else(|| {
session_attribution
.as_ref()
.map(|attribution| attribution.model_id.clone())
.and_then(|attribution| attribution.model_id.clone())
}),
tool_version: payload.tool_version.clone().or_else(|| {
session_attribution
Expand Down Expand Up @@ -818,7 +818,6 @@ fn run_session_model_subcommand_from_payload(
logger: Option<&dyn Logger>,
) -> Result<String> {
let payload = parse_session_model_payload(stdin_payload)?;

// Convert the u64 time to i64 for DB storage.
let session_start_time_ms = i64::try_from(payload.time).map_err(|_| {
anyhow!(StdinPayloadKind::SessionModel.validation_error(
Expand All @@ -829,7 +828,7 @@ fn run_session_model_subcommand_from_payload(
let upsert_payload = SessionModelUpsert {
tool_name: &payload.tool_name,
session_id: &payload.session_id,
model_id: &payload.model_id,
model_id: payload.model_id.as_deref(),
tool_version: payload.tool_version.as_deref(),
session_start_time_ms,
};
Expand Down Expand Up @@ -1001,8 +1000,9 @@ where
payload_kind.validation_error(d)
})?;
let time = required_u64_millisecond_field(payload, "time", payload_kind)?;
let model_id =
required_non_empty_string_field(payload, "model_id", |d| payload_kind.validation_error(d))?;
let model_id = Some(required_non_empty_string_field(payload, "model_id", |d| {
payload_kind.validation_error(d)
})?);
let tool_name = required_non_empty_string_field(payload, "tool_name", |d| {
payload_kind.validation_error(d)
})?;
Expand Down Expand Up @@ -1041,7 +1041,7 @@ fn parse_claude_session_model_payload(
}

let session_id = required_claude_session_id(payload, payload_kind)?;
let model_id = required_claude_model_id(payload, payload_kind)?;
let model_id = optional_claude_model_id(payload);
let time = extract_claude_event_time(payload);
let tool_name = "claude".to_string();
let tool_version = extract_claude_tool_version_from_payload(payload).or_else(|| {
Expand Down Expand Up @@ -1076,17 +1076,14 @@ fn required_claude_session_id(
))
}

fn required_claude_model_id(
payload: &serde_json::Map<String, Value>,
payload_kind: StdinPayloadKind,
) -> Result<String> {
fn optional_claude_model_id(payload: &serde_json::Map<String, Value>) -> Option<String> {
// Try direct string fields first.
for key in ["model", "model_id", "modelId"] {
if let Some(value) = payload.get(key) {
if let Some(s) = value.as_str() {
let trimmed = s.trim();
if !trimmed.is_empty() {
return Ok(normalize_claude_model_id(trimmed));
return Some(normalize_claude_model_id(trimmed));
}
}
// If model is an object, try nested identifier fields.
Expand All @@ -1096,7 +1093,7 @@ fn required_claude_model_id(
if let Some(s) = nested_value.as_str() {
let trimmed = s.trim();
if !trimmed.is_empty() {
return Ok(normalize_claude_model_id(trimmed));
return Some(normalize_claude_model_id(trimmed));
}
}
}
Expand All @@ -1105,9 +1102,7 @@ fn required_claude_model_id(
}
}

bail!(payload_kind.validation_error(
"missing non-empty model identifier (model, model_id, or model.id) for Claude SessionStart"
))
None
}

fn normalize_claude_model_id(model: &str) -> String {
Expand Down Expand Up @@ -2821,13 +2816,13 @@ mod tests {
}

fn session_model_attribution(
model_id: &str,
model_id: Option<&str>,
tool_version: Option<&str>,
) -> SessionModelAttribution {
SessionModelAttribution {
tool_name: String::from("claude"),
session_id: String::from("session-123"),
model_id: model_id.to_string(),
model_id: model_id.map(String::from),
tool_version: tool_version.map(String::from),
session_start_time_ms: 1_800_000_000_000_i64,
}
Expand All @@ -2845,7 +2840,7 @@ mod tests {
.expect("Claude SessionStart payload should parse");

assert_eq!(output.session_id, "session-123");
assert_eq!(output.model_id, "claude/sonnet-4");
assert_eq!(output.model_id, Some(String::from("claude/sonnet-4")));
assert_eq!(output.tool_name, "claude");
assert_eq!(output.tool_version, Some(String::from("1.2.3")));
}
Expand Down Expand Up @@ -2897,7 +2892,7 @@ mod tests {
assert_eq!(tool_name, "claude");
assert_eq!(session_id, "session-123");
Ok(Some(session_model_attribution(
"session-model",
Some("session-model"),
Some("Claude Code 1.2.3"),
)))
})
Expand All @@ -2916,7 +2911,7 @@ mod tests {

let resolved = resolve_diff_trace_attribution(&payload, |_tool_name, _session_id| {
Ok(Some(session_model_attribution(
"session-model",
Some("session-model"),
Some("stored-version"),
)))
})
Expand Down
Loading
Loading