diff --git a/java/README.md b/java/README.md index 6a0690343..334db459f 100644 --- a/java/README.md +++ b/java/README.md @@ -165,6 +165,73 @@ public String onlyContext(ToolInvocation invocation) { ... } public String report(@CopilotToolParam("Phase") String phase, ToolInvocation invocation, @CopilotToolParam("Limit") int limit) { ... } ``` +## Inline lambda tool definitions (experimental) + +For inline tool authoring at the session construction site, use `ToolDefinition.from(...)` with explicit parameter metadata: + +```java +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.rpc.ToolDefer; +import com.github.copilot.tool.Param; + +ToolDefinition search = ToolDefinition + .from( + "search_items", + "Searches indexed items by keyword", + Param.of(String.class, "keyword", "Search keyword"), + keyword -> "Searching for: " + keyword) + .skipPermission(true) + .defer(ToolDefer.AUTO); +``` + +### Parameter metadata with `Param.of(...)` + +`Param.of(type, name, description)` creates a required parameter. For optional parameters with defaults: + +```java +Param limit = Param.of(Integer.class, "limit", "Max results", false, "10"); +``` + +### Async handlers + +Use `fromAsync` for asynchronous tool handlers: + +```java +import java.util.concurrent.CompletableFuture; + +ToolDefinition fetchData = ToolDefinition.fromAsync( + "fetch_data", + "Fetches data from remote source", + Param.of(String.class, "url", "Data source URL"), + url -> CompletableFuture.supplyAsync(() -> fetchRemote(url)) +); +``` + +### ToolInvocation context injection + +Inline tools can access `ToolInvocation` runtime context using `fromWithToolInvocation`: + +```java +ToolDefinition reportPhase = ToolDefinition.fromWithToolInvocation( + "report_phase", + "Reports the current phase with invocation context", + Param.of(String.class, "phase", "The current phase"), + (phase, invocation) -> "phase=" + phase + ", toolCallId=" + invocation.getToolCallId() +); +``` + +For async with `ToolInvocation`, use `fromAsyncWithToolInvocation`. + +### Fluent option modifiers + +Chain fluent modifiers to set tool options: + +- `.skipPermission(boolean)` — bypass permission prompts +- `.defer(ToolDefer)` — control deferred execution (`AUTO`, `NEVER`) +- `.overridesBuiltInTool(boolean)` — shadow built-in tools + +For design context and decision rationale, see [ADR-006](docs/adr/adr-006-tool-definition-inline.md). + ## Memory Sessions can opt into persistent memory, allowing the agent to read and write memory across turns. Memory is configured per session and applies to both `createSession` and `resumeSession`. diff --git a/java/docs/adr/adr-006-tool-definition-inline.md b/java/docs/adr/adr-006-tool-definition-inline.md new file mode 100644 index 000000000..ad48527c1 --- /dev/null +++ b/java/docs/adr/adr-006-tool-definition-inline.md @@ -0,0 +1,118 @@ +# ADR-006: Inline tool definition with lambdas + +## Context and problem statement + +[ADR-005](adr-005-tool-definition.md) introduced an ergonomic Java tools API based on `@CopilotTool` method annotations, `@CopilotToolParam` parameter annotations, and `ToolDefinition.fromObject(...)` for reflection-based tool registration. That model works well when teams define tools as methods on a class. + +The next ergonomics goal is an inline style comparable to C# `CopilotTool.DefineTool(...)`, where developers can define a tool at the call site without creating a separate tool container class. + +For this decision, we evaluated two alternatives: + +* Method-reference registration (`ToolDefinition.from(tools::setCurrentPhase)`) +* Inline lambda registration (`ToolDefinition.from(..., phase -> ...)`) + +The key factor is metadata quality: tool name, description, parameter names, parameter descriptions, required/default semantics, and schema stability. + +## Considered options + +### Option 1: Method-reference API + +Example: + +```java +ToolDefinition setPhase = ToolDefinition.from(tools::setCurrentPhase); +``` + +In this model, metadata is sourced from existing method-level annotations (`@CopilotTool`, `@Param`) on the referenced method. + +Advantages: + +* Closest Java analog to C# method-group ergonomics +* High-quality metadata with minimal additional API surface +* Reuses ADR-005 metadata and invocation behavior directly + +Drawbacks: + +* Not truly inline: still requires a declared method (and usually annotations) elsewhere +* Does not solve the "define the whole tool at the call site" use case +* Method-reference resolution adds runtime/reflection complexity + +### Option 2: Inline lambda API with explicit metadata + +Example: + +```java +ToolDefinition setPhase = ToolDefinition.from( + "set_current_phase", + "Sets the current phase of the agent", + Param.of(String.class, "phase", "The phase to transition to"), + (String phase) -> { + currentPhase = phase; + return "Phase set to " + phase; + }); +``` + +In this model, handler logic is inline, and metadata is provided explicitly through `Param.of(...)` parameter definitions. + +Advantages: + +* True inline authoring at the session construction site +* No dependence on lambda parameter-name reflection or `-parameters` +* Deterministic metadata and schema generation +* Independent from annotation processing and generated companion classes + +Drawbacks: + +* Slightly more verbose than method-reference style because metadata is explicit +* Introduces new public API types for parameter definitions and typed lambda overloads +* Requires careful API design to stay concise for common one-parameter tools + +## Decision outcome + +Chosen: **Option 2 for ADR-006 scope** — inline lambda API with explicit metadata. + +Rationale: + +1. The primary requirement for this ADR is inline definition. Option 2 satisfies it directly; Option 1 does not. +1. Metadata quality is the critical requirement. Option 2 keeps metadata explicit and stable, instead of relying on fragile lambda introspection. +1. Option 2 can ship independently of method-reference support and without changes to annotation processing. +1. Option 2 preserves behavior parity with existing tool execution by delegating to `ToolDefinition` construction and current invocation semantics. + +Option 1 remains valuable and can be added independently as a separate ergonomic layer. It is not blocked by this decision. + +## Design constraints and non-goals + +Constraints for the inline lambda API: + +* Require explicit tool name and description. +* Require explicit parameter metadata (at minimum name and type, with optional description/required/default). +* Support both sync and async handlers (`R` and `CompletableFuture`). +* Keep result semantics aligned with existing behavior (`String` passthrough, `void` maps to `"Success"`, non-string objects serialized to JSON). +* Keep override/permission/defer flags available through options, consistent with existing `ToolDefinition` fields. + +Non-goals for this ADR: + +* Replacing `@CopilotTool`/`fromObject` APIs. +* Defining method-reference registration behavior in detail. +* Introducing compile-time code generation for lambda metadata. + +## Consequences + +The SDK now provides an explicit inline path for developers who prefer to keep tool declarations at session creation while preserving high-quality schema metadata. Implemented API families include: + +- `ToolDefinition.from(name, description, [params...], handler)` — sync handlers +- `ToolDefinition.fromAsync(name, description, [params...], asyncHandler)` — async handlers returning `CompletableFuture` +- `ToolDefinition.fromWithToolInvocation(...)` — sync with `ToolInvocation` context injection +- `ToolDefinition.fromAsyncWithToolInvocation(...)` — async with `ToolInvocation` context injection + +Parameter metadata is defined using `Param.of(type, name, description)` for required parameters and `Param.of(type, name, description, required, defaultValue)` for optional parameters with defaults. + +Fluent option modifiers (`.skipPermission(boolean)`, `.defer(ToolDefer)`, `.overridesBuiltInTool(boolean)`) allow post-construction customization. + +The annotation-driven API from [ADR-005](adr-005-tool-definition.md) remains the recommended path for larger tool surfaces where co-locating metadata with method implementations improves maintainability. For usage examples and complete API coverage, see the Java SDK README. + +## Related work items + +* #1682 +* #1792 +* #1810 diff --git a/java/src/main/java/com/github/copilot/rpc/ParamCoercion.java b/java/src/main/java/com/github/copilot/rpc/ParamCoercion.java new file mode 100644 index 000000000..fc8274254 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/ParamCoercion.java @@ -0,0 +1,197 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.Map; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.tool.Param; + +/** + * Internal runtime helper: coerces raw invocation arguments to the typed values + * declared by {@link Param} descriptors. + * + *

+ * Reuses the SDK-configured {@link ObjectMapper} for complex type conversions, + * matching the coercion policy applied by existing ergonomic tooling. No + * bespoke conversion paths are introduced. + * + *

+ * Package-private: not part of the public API. + */ +class ParamCoercion { + + /** Utility class; do not instantiate. */ + private ParamCoercion() { + } + + /** + * Coerces the named argument from an invocation argument map to the Java type + * declared by {@code param}. + * + *

+ * Resolution order: + *

    + *
  1. If the argument is present, convert it to {@code T} via + * {@link ObjectMapper#convertValue}.
  2. + *
  3. If absent and a default value is set, parse the string default via + * {@link #coerceDefault}.
  4. + *
  5. If absent and the parameter is optional ({@code required=false}), return + * an empty Optional variant or {@code null}.
  6. + *
  7. If absent and required, throw {@link IllegalArgumentException} with the + * parameter name.
  8. + *
+ * + * @param + * the target Java type + * @param args + * the invocation argument map; may be {@code null} for zero-argument + * tools + * @param param + * the parameter descriptor + * @param mapper + * the configured {@link ObjectMapper} for complex type conversion + * @return the coerced argument value + * @throws IllegalArgumentException + * if a required parameter is missing or coercion fails + */ + @SuppressWarnings("unchecked") + static T coerce(Map args, Param param, ObjectMapper mapper) { + Object raw = (args != null) ? args.get(param.name()) : null; + + if (raw == null) { + if (param.hasDefaultValue()) { + return coerceDefault(param, mapper); + } else if (!param.required()) { + return (T) emptyOptionalOrNull(param.type()); + } else { + throw new IllegalArgumentException( + "Required parameter '" + param.name() + "' is missing from tool invocation"); + } + } + + Class type = param.type(); + + // Handle Optional* types explicitly before delegating to ObjectMapper + if (type == java.util.OptionalInt.class) { + try { + return (T) java.util.OptionalInt.of(((Number) raw).intValue()); + } catch (ClassCastException ex) { + throw new IllegalArgumentException("Parameter '" + param.name() + + "' expected a numeric value for OptionalInt, got: " + raw.getClass().getSimpleName(), ex); + } + } + if (type == java.util.OptionalLong.class) { + try { + return (T) java.util.OptionalLong.of(((Number) raw).longValue()); + } catch (ClassCastException ex) { + throw new IllegalArgumentException("Parameter '" + param.name() + + "' expected a numeric value for OptionalLong, got: " + raw.getClass().getSimpleName(), ex); + } + } + if (type == java.util.OptionalDouble.class) { + try { + return (T) java.util.OptionalDouble.of(((Number) raw).doubleValue()); + } catch (ClassCastException ex) { + throw new IllegalArgumentException("Parameter '" + param.name() + + "' expected a numeric value for OptionalDouble, got: " + raw.getClass().getSimpleName(), ex); + } + } + + try { + return mapper.convertValue(raw, type); + } catch (IllegalArgumentException ex) { + throw new IllegalArgumentException( + "Failed to coerce parameter '" + param.name() + "' to type " + type.getSimpleName(), ex); + } + } + + /** + * Parses a {@link Param}'s string default value into the declared Java type. + * + *

+ * Handles primitives, boxed types, {@link String}, {@link Boolean}, and enums + * explicitly, mirroring the validation logic in {@link Param}. The + * {@link ObjectMapper#readValue} fallback exists as a safety net but is not + * expected to be reached in practice, since {@link Param} construction rejects + * defaults for non-primitive/boxed/String/Boolean/enum types. + * + * @param + * the target Java type + * @param param + * the parameter descriptor carrying the default value + * @param mapper + * the configured {@link ObjectMapper} used as fallback for complex + * types + * @return the parsed default value + * @throws IllegalArgumentException + * if parsing fails + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + static T coerceDefault(Param param, ObjectMapper mapper) { + String defaultValue = param.defaultValue(); + Class type = param.type(); + try { + if (type == String.class) { + return type.cast(defaultValue); + } + if (type == Integer.class || type == int.class) { + return (T) Integer.valueOf(defaultValue); + } + if (type == Long.class || type == long.class) { + return (T) Long.valueOf(defaultValue); + } + if (type == Double.class || type == double.class) { + return (T) Double.valueOf(defaultValue); + } + if (type == Float.class || type == float.class) { + return (T) Float.valueOf(defaultValue); + } + if (type == Short.class || type == short.class) { + return (T) Short.valueOf(defaultValue); + } + if (type == Byte.class || type == byte.class) { + return (T) Byte.valueOf(defaultValue); + } + if (type == Boolean.class || type == boolean.class) { + return (T) Boolean.valueOf(defaultValue); + } + if (type.isEnum()) { + Class enumType = (Class) type; + return type.cast(Enum.valueOf(enumType, defaultValue)); + } + // Fallback: let ObjectMapper parse the JSON-encoded default string + return mapper.readValue(defaultValue, type); + } catch (IllegalArgumentException ex) { + throw ex; + } catch (Exception ex) { + throw new IllegalArgumentException("Failed to apply default value '" + defaultValue + "' for parameter '" + + param.name() + "' of type " + type.getSimpleName(), ex); + } + } + + /** + * Returns an empty Optional variant for Optional primitive types, or + * {@code null} for all other types. + * + * @param type + * the declared parameter type + * @return {@link java.util.OptionalInt#empty()}, + * {@link java.util.OptionalLong#empty()}, + * {@link java.util.OptionalDouble#empty()}, or {@code null} + */ + static Object emptyOptionalOrNull(Class type) { + if (type == java.util.OptionalInt.class) { + return java.util.OptionalInt.empty(); + } + if (type == java.util.OptionalLong.class) { + return java.util.OptionalLong.empty(); + } + if (type == java.util.OptionalDouble.class) { + return java.util.OptionalDouble.empty(); + } + return null; + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/ParamSchema.java b/java/src/main/java/com/github/copilot/rpc/ParamSchema.java new file mode 100644 index 000000000..ee025eb2c --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/ParamSchema.java @@ -0,0 +1,190 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.tool.Param; + +/** + * Internal runtime helper: maps {@link Param} metadata to JSON Schema + * {@code Map} objects. + * + *

+ * This class is a simplified runtime counterpart to the compile-time + * {@code SchemaGenerator}. It operates on {@code java.lang.reflect.Class} + * values instead of {@code javax.lang.model} mirrors, and produces {@link Map} + * instances rather than Java source-code literals. Unlike + * {@code SchemaGenerator}, it does not inspect generics or object members + * (records/POJOs) and therefore produces flat type mappings only (no + * {@code additionalProperties} or nested object {@code properties}). It does + * produce {@code items} for plain Java arrays via component-type recursion. + * + *

+ * Package-private: not part of the public API. + */ +class ParamSchema { + + /** Utility class; do not instantiate. */ + private ParamSchema() { + } + + /** + * Builds a JSON Schema {@code Map} from zero or more {@link Param} descriptors. + * + *

+ * Validation applied: + *

    + *
  • Each {@link Param} must be non-null.
  • + *
  • Parameter names must be unique; duplicates throw + * {@link IllegalArgumentException} with the tool name and duplicate name.
  • + *
+ * + * @param toolName + * the tool name, included in exception messages for clarity + * @param mapper + * the configured {@link ObjectMapper} used to coerce default values + * into their typed form for the schema + * @param params + * zero or more parameter descriptors + * @return a JSON Schema object map with {@code type=object}, + * {@code properties}, and {@code required} keys + * @throws IllegalArgumentException + * if a null param or duplicate parameter names are found + */ + static Map buildSchema(String toolName, ObjectMapper mapper, Param... params) { + if (params == null || params.length == 0) { + return Map.of("type", "object", "properties", Map.of(), "required", List.of()); + } + + // Validate: no null params, no duplicate names + Set seen = new HashSet<>(); + for (Param param : params) { + if (param == null) { + throw new IllegalArgumentException("A Param descriptor is null for tool '" + toolName + "'"); + } + if (!seen.add(param.name())) { + throw new IllegalArgumentException( + "Duplicate parameter name '" + param.name() + "' in tool '" + toolName + "'"); + } + } + + List requiredNames = new ArrayList<>(); + Map properties = new LinkedHashMap<>(); + + for (Param param : params) { + Map typeSchema = forType(param.type()); + Map enriched = new LinkedHashMap<>(typeSchema); + enriched.put("description", param.description()); + if (param.hasDefaultValue()) { + enriched.put("default", ParamCoercion.coerceDefault(param, mapper)); + } + properties.put(param.name(), Collections.unmodifiableMap(enriched)); + if (param.required()) { + requiredNames.add(param.name()); + } + } + + return Map.of("type", "object", "properties", Collections.unmodifiableMap(properties), "required", + Collections.unmodifiableList(requiredNames)); + } + + /** + * Maps a Java {@link Class} to a flat JSON Schema type descriptor. + * + *

+ * Covers primitives, boxed types, strings, UUIDs, date-time types, enums, + * collections, arrays, and maps. Does not resolve generic type parameters (e.g. + * {@code List} item schemas or {@code Map} additionalProperties) — + * those require the compile-time {@code SchemaGenerator} which operates on + * {@code TypeMirror}. + * + * @param type + * the Java type to map + * @return a JSON Schema type map (e.g. {@code Map.of("type", "string")}) + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + static Map forType(Class type) { + // Integer types + if (type == int.class || type == Integer.class || type == long.class || type == Long.class || type == byte.class + || type == Byte.class || type == short.class || type == Short.class) { + return Map.of("type", "integer"); + } + // Floating-point types + if (type == double.class || type == Double.class || type == float.class || type == Float.class) { + return Map.of("type", "number"); + } + // Boolean + if (type == boolean.class || type == Boolean.class) { + return Map.of("type", "boolean"); + } + // Char → string + if (type == char.class || type == Character.class) { + return Map.of("type", "string"); + } + // String + if (type == String.class) { + return Map.of("type", "string"); + } + // UUID + if (type == java.util.UUID.class) { + return Map.of("type", "string", "format", "uuid"); + } + // Optional primitive types + if (type == java.util.OptionalInt.class || type == java.util.OptionalLong.class) { + return Map.of("type", "integer"); + } + if (type == java.util.OptionalDouble.class) { + return Map.of("type", "number"); + } + // Date-time types + if (type == java.time.OffsetDateTime.class || type == java.time.LocalDateTime.class + || type == java.time.Instant.class || type == java.time.ZonedDateTime.class) { + return Map.of("type", "string", "format", "date-time"); + } + if (type == java.time.LocalDate.class) { + return Map.of("type", "string", "format", "date"); + } + if (type == java.time.LocalTime.class) { + return Map.of("type", "string", "format", "time"); + } + // JsonNode / Object → any (no type constraint) + if (type == com.fasterxml.jackson.databind.JsonNode.class || type == Object.class) { + return Map.of(); + } + // Enum types + if (type.isEnum()) { + Class enumType = (Class) type; + List constants = Arrays.stream(enumType.getEnumConstants()).map(Enum::name) + .collect(Collectors.toList()); + return Map.of("type", "string", "enum", Collections.unmodifiableList(constants)); + } + // List / Collection / Set → array (raw element type) + if (java.util.List.class.isAssignableFrom(type) || java.util.Collection.class.isAssignableFrom(type) + || java.util.Set.class.isAssignableFrom(type)) { + return Map.of("type", "array"); + } + // Plain array → array with items schema derived from component type + if (type.isArray()) { + Map itemsSchema = forType(type.getComponentType()); + return Map.of("type", "array", "items", itemsSchema); + } + // Map → object + if (java.util.Map.class.isAssignableFrom(type)) { + return Map.of("type", "object"); + } + // POJO / record → object + return Map.of("type", "object"); + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java b/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java index b3fa2bc53..8a336c749 100644 --- a/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java +++ b/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java @@ -9,6 +9,10 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -19,6 +23,7 @@ import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.github.copilot.CopilotExperimental; +import com.github.copilot.tool.Param; /** * Defines a tool that can be invoked by the AI assistant. @@ -186,7 +191,7 @@ public static ToolDefinition createWithDefer(String name, String description, Ma * @throws IllegalStateException * if the generated {@code $$CopilotToolMeta} class is not found * (annotation processor did not run) - * @since 1.0.2 + * @since 1.0.6 */ @CopilotExperimental public static List fromObject(Object instance) { @@ -209,7 +214,7 @@ public static List fromObject(Object instance) { * @throws IllegalStateException * if the generated {@code $$CopilotToolMeta} class is not found * (annotation processor did not run) - * @since 1.0.2 + * @since 1.0.6 */ @CopilotExperimental public static List fromClass(Class clazz) { @@ -227,6 +232,570 @@ public static List fromClass(Class clazz) { return loadDefinitions(clazz, null); } + // ------------------------------------------------------------------ + // Fluent copy-style modifier methods for lambda-defined tools + // ------------------------------------------------------------------ + + /** + * Returns a copy with the {@code overridesBuiltInTool} flag set. + * + * @param value + * {@code true} to indicate this tool intentionally overrides a + * built-in CLI tool with the same name + * @return a new {@code ToolDefinition} with the flag applied + * @since 1.0.6 + */ + @CopilotExperimental + public ToolDefinition overridesBuiltInTool(boolean value) { + return new ToolDefinition(name, description, parameters, handler, value, skipPermission, defer); + } + + /** + * Returns a copy with the {@code skipPermission} flag set. + * + * @param value + * {@code true} to skip the permission request for this tool + * invocation + * @return a new {@code ToolDefinition} with the flag applied + * @since 1.0.6 + */ + @CopilotExperimental + public ToolDefinition skipPermission(boolean value) { + return new ToolDefinition(name, description, parameters, handler, overridesBuiltInTool, value, defer); + } + + /** + * Returns a copy with the {@code defer} mode set. + * + * @param value + * the deferral mode; use {@link ToolDefer#AUTO} to allow deferral or + * {@link ToolDefer#NEVER} to force the tool to always be pre-loaded + * @return a new {@code ToolDefinition} with the defer mode applied + * @since 1.0.6 + */ + @CopilotExperimental + public ToolDefinition defer(ToolDefer value) { + return new ToolDefinition(name, description, parameters, handler, overridesBuiltInTool, skipPermission, value); + } + + // ------------------------------------------------------------------ + // from(...) — sync, no ToolInvocation + // ------------------------------------------------------------------ + + /** + * Creates a tool definition with a zero-argument synchronous handler. + * + *

+ * The handler is a {@link Supplier} that returns the tool result. + * + *

Example

+ * + *
{@code
+     * ToolDefinition ping = ToolDefinition.from("ping", "Returns a simple pong response", () -> "pong");
+     * }
+ * + * @param + * the return type of the handler + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param handler + * the zero-argument sync handler + * @return a new tool definition + * @throws IllegalArgumentException + * if {@code name} or {@code description} is blank, or if + * {@code handler} is null + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition from(String name, String description, Supplier handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper); + ToolHandler toolHandler = invocation -> { + R result = handler.get(); + return CompletableFuture.completedFuture(formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + /** + * Creates a tool definition with a one-argument synchronous handler. + * + *

Example

+ * + *
{@code
+     * ToolDefinition greet = ToolDefinition.from("greet", "Greets a user by name",
+     * 		Param.of(String.class, "name", "The user's name"), name -> "Hello, " + name + "!");
+     * }
+ * + * @param + * the type of the first parameter + * @param + * the return type of the handler + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param p1 + * the first parameter descriptor + * @param handler + * the one-argument sync handler + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition from(String name, String description, Param p1, Function handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper, p1); + ToolHandler toolHandler = invocation -> { + T1 arg1 = ParamCoercion.coerce(invocation.getArguments(), p1, mapper); + R result = handler.apply(arg1); + return CompletableFuture.completedFuture(formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + /** + * Creates a tool definition with a two-argument synchronous handler. + * + *

Example

+ * + *
{@code
+     * ToolDefinition add = ToolDefinition.from("add", "Adds two integers", Param.of(Integer.class, "a", "First number"),
+     * 		Param.of(Integer.class, "b", "Second number"), (a, b) -> a + b);
+     * }
+ * + * @param + * the type of the first parameter + * @param + * the type of the second parameter + * @param + * the return type of the handler + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param p1 + * the first parameter descriptor + * @param p2 + * the second parameter descriptor + * @param handler + * the two-argument sync handler + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition from(String name, String description, Param p1, Param p2, + BiFunction handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper, p1, p2); + ToolHandler toolHandler = invocation -> { + T1 arg1 = ParamCoercion.coerce(invocation.getArguments(), p1, mapper); + T2 arg2 = ParamCoercion.coerce(invocation.getArguments(), p2, mapper); + R result = handler.apply(arg1, arg2); + return CompletableFuture.completedFuture(formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + // ------------------------------------------------------------------ + // fromAsync(...) — async, no ToolInvocation + // ------------------------------------------------------------------ + + /** + * Creates a tool definition with a zero-argument asynchronous handler. + * + *

+ * The handler is a {@link Supplier} returning a {@link CompletableFuture}. + * + *

Example

+ * + *
{@code
+     * ToolDefinition ping = ToolDefinition.fromAsync("ping", "Returns a pong response asynchronously",
+     * 		() -> CompletableFuture.completedFuture("pong"));
+     * }
+ * + * @param + * the return type wrapped in {@link CompletableFuture} + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param handler + * the zero-argument async handler + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromAsync(String name, String description, + Supplier> handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper); + ToolHandler toolHandler = invocation -> { + CompletableFuture future = handler.get(); + if (future == null) { + return CompletableFuture.failedFuture( + new NullPointerException("Async handler for tool '" + name + "' returned a null future")); + } + return future.thenApply(result -> formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + /** + * Creates a tool definition with a one-argument asynchronous handler. + * + *

Example

+ * + *
{@code
+     * ToolDefinition greet = ToolDefinition.fromAsync("greet_async", "Greets a user by name asynchronously",
+     * 		Param.of(String.class, "name", "The user's name"),
+     * 		name -> CompletableFuture.completedFuture("Hello, " + name + "!"));
+     * }
+ * + * @param + * the type of the first parameter + * @param + * the return type wrapped in {@link CompletableFuture} + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param p1 + * the first parameter descriptor + * @param handler + * the one-argument async handler + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromAsync(String name, String description, Param p1, + Function> handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper, p1); + ToolHandler toolHandler = invocation -> { + T1 arg1 = ParamCoercion.coerce(invocation.getArguments(), p1, mapper); + CompletableFuture future = handler.apply(arg1); + if (future == null) { + return CompletableFuture.failedFuture( + new NullPointerException("Async handler for tool '" + name + "' returned a null future")); + } + return future.thenApply(result -> formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + /** + * Creates a tool definition with a two-argument asynchronous handler. + * + * @param + * the type of the first parameter + * @param + * the type of the second parameter + * @param + * the return type wrapped in {@link CompletableFuture} + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param p1 + * the first parameter descriptor + * @param p2 + * the second parameter descriptor + * @param handler + * the two-argument async handler + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromAsync(String name, String description, Param p1, Param p2, + BiFunction> handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper, p1, p2); + ToolHandler toolHandler = invocation -> { + T1 arg1 = ParamCoercion.coerce(invocation.getArguments(), p1, mapper); + T2 arg2 = ParamCoercion.coerce(invocation.getArguments(), p2, mapper); + CompletableFuture future = handler.apply(arg1, arg2); + if (future == null) { + return CompletableFuture.failedFuture( + new NullPointerException("Async handler for tool '" + name + "' returned a null future")); + } + return future.thenApply(result -> formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + // ------------------------------------------------------------------ + // fromWithToolInvocation(...) — sync, with ToolInvocation context + // ------------------------------------------------------------------ + + /** + * Creates a tool definition with a zero-argument synchronous handler that + * receives the {@link ToolInvocation} context. + * + *

Example

+ * + *
{@code
+     * ToolDefinition sessionInfo = ToolDefinition.fromWithToolInvocation("session_info", "Return the current session id",
+     * 		invocation -> "sessionId=" + invocation.getSessionId());
+     * }
+ * + * @param + * the return type of the handler + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param handler + * a function accepting the {@link ToolInvocation} context + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromWithToolInvocation(String name, String description, + Function handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper); + ToolHandler toolHandler = invocation -> { + R result = handler.apply(invocation); + return CompletableFuture.completedFuture(formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + /** + * Creates a tool definition with a one-argument synchronous handler that also + * receives the {@link ToolInvocation} context. + * + *

Example

+ * + *
{@code
+     * ToolDefinition reportPhase = ToolDefinition.fromWithToolInvocation("report_phase",
+     * 		"Report the current phase along with invocation context", Param.of(String.class, "phase", "Current phase"),
+     * 		(phase, invocation) -> "phase=" + phase + ", toolCallId=" + invocation.getToolCallId());
+     * }
+ * + * @param + * the type of the first parameter + * @param + * the return type of the handler + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param p1 + * the first parameter descriptor + * @param handler + * a function accepting the typed argument and the + * {@link ToolInvocation} context + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromWithToolInvocation(String name, String description, Param p1, + BiFunction handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper, p1); + ToolHandler toolHandler = invocation -> { + T1 arg1 = ParamCoercion.coerce(invocation.getArguments(), p1, mapper); + R result = handler.apply(arg1, invocation); + return CompletableFuture.completedFuture(formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + // ------------------------------------------------------------------ + // fromAsyncWithToolInvocation(...) — async, with ToolInvocation context + // ------------------------------------------------------------------ + + /** + * Creates a tool definition with a zero-argument asynchronous handler that + * receives the {@link ToolInvocation} context. + * + *

Example

+ * + *
{@code
+     * ToolDefinition sessionInfo = ToolDefinition.fromAsyncWithToolInvocation("session_info_async",
+     * 		"Return the current session id asynchronously",
+     * 		invocation -> CompletableFuture.completedFuture("sessionId=" + invocation.getSessionId()));
+     * }
+ * + * @param + * the return type wrapped in {@link CompletableFuture} + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param handler + * a function accepting the {@link ToolInvocation} context, returning + * a {@link CompletableFuture} + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromAsyncWithToolInvocation(String name, String description, + Function> handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper); + ToolHandler toolHandler = invocation -> { + CompletableFuture future = handler.apply(invocation); + if (future == null) { + return CompletableFuture.failedFuture( + new NullPointerException("Async handler for tool '" + name + "' returned a null future")); + } + return future.thenApply(result -> formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + /** + * Creates a tool definition with a one-argument asynchronous handler that also + * receives the {@link ToolInvocation} context. + * + *

Example

+ * + *
{@code
+     * ToolDefinition reportPhase = ToolDefinition.fromAsyncWithToolInvocation("report_phase_async",
+     * 		"Report the current phase with invocation context asynchronously",
+     * 		Param.of(String.class, "phase", "The current phase"), (phase, invocation) -> CompletableFuture
+     * 				.completedFuture("phase=" + phase + ", toolCallId=" + invocation.getToolCallId()));
+     * }
+ * + * @param + * the type of the first parameter + * @param + * the return type wrapped in {@link CompletableFuture} + * @param name + * the unique name of the tool (must not be blank) + * @param description + * a description of what the tool does (must not be blank) + * @param p1 + * the first parameter descriptor + * @param handler + * a function accepting the typed argument and the + * {@link ToolInvocation} context, returning a + * {@link CompletableFuture} + * @return a new tool definition + * @throws IllegalArgumentException + * if validation fails + * @since 1.0.6 + */ + @CopilotExperimental + public static ToolDefinition fromAsyncWithToolInvocation(String name, String description, Param p1, + BiFunction> handler) { + requireNonBlankToolName(name); + requireNonBlankDescription(description); + requireNonNullHandler(handler, name); + final ObjectMapper mapper = getConfiguredMapper(); + Map schema = ParamSchema.buildSchema(name, mapper, p1); + ToolHandler toolHandler = invocation -> { + T1 arg1 = ParamCoercion.coerce(invocation.getArguments(), p1, mapper); + CompletableFuture future = handler.apply(arg1, invocation); + if (future == null) { + return CompletableFuture.failedFuture( + new NullPointerException("Async handler for tool '" + name + "' returned a null future")); + } + return future.thenApply(result -> formatResult(result, mapper)); + }; + return new ToolDefinition(name, description, schema, toolHandler, null, null, null); + } + + // ------------------------------------------------------------------ + // Internal helpers: result formatting, validation + // ------------------------------------------------------------------ + + /** + * Formats a handler return value according to the tool result contract: + *
    + *
  • {@link String} — returned as-is
  • + *
  • {@code null} — mapped to {@code "Success"} (covers handlers that return + * null to indicate a successful no-value result)
  • + *
  • any other value — JSON-serialized via {@link ObjectMapper}
  • + *
+ */ + private static Object formatResult(Object result, ObjectMapper mapper) { + if (result == null) { + return "Success"; + } + if (result instanceof String) { + return result; + } + if (result instanceof ToolResultObject) { + return result; + } + try { + return mapper.writeValueAsString(result); + } catch (com.fasterxml.jackson.core.JsonProcessingException ex) { + throw new IllegalStateException("Failed to serialize tool result to JSON", ex); + } + } + + // ------------------------------------------------------------------ + // Validation helpers + // ------------------------------------------------------------------ + + private static void requireNonBlankToolName(String name) { + if (name == null || name.isBlank()) { + throw new IllegalArgumentException("Tool name must not be null or blank"); + } + } + + private static void requireNonBlankDescription(String description) { + if (description == null || description.isBlank()) { + throw new IllegalArgumentException("Tool description must not be null or blank"); + } + } + + private static void requireNonNullHandler(Object handler, String toolName) { + if (handler == null) { + throw new IllegalArgumentException("handler must not be null for tool '" + toolName + "'"); + } + } + @SuppressWarnings("unchecked") private static List loadDefinitions(Class clazz, Object instance) { String metaClassName = clazz.getName() + "$$CopilotToolMeta"; diff --git a/java/src/main/java/com/github/copilot/tool/Param.java b/java/src/main/java/com/github/copilot/tool/Param.java new file mode 100644 index 000000000..bbe188ce0 --- /dev/null +++ b/java/src/main/java/com/github/copilot/tool/Param.java @@ -0,0 +1,261 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import java.util.Objects; + +import com.github.copilot.CopilotExperimental; + +/** + * Runtime parameter metadata for lambda-defined tools. + * + *

+ * Each {@code Param} instance describes a single parameter that a tool accepts, + * including its Java type, wire name, description, whether it is required, and + * an optional default value. Instances are immutable; fluent mutators return + * new copies. + * + *

Example Usage

+ * + *
{@code
+ * Param query = Param.of(String.class, "query", "Search query text");
+ *
+ * Param limit = Param.of(Integer.class, "limit", "Max results", false, "10");
+ * }
+ * + * @param + * the Java type of the parameter value + * @since 1.0.6 + */ +@CopilotExperimental +public final class Param { + + private final Class type; + private final String name; + private final String description; + private final boolean required; + private final String defaultValue; + + private Param(Class type, String name, String description, boolean required, String defaultValue) { + this.type = Objects.requireNonNull(type, "type"); + this.name = requireNonBlank(name, "name"); + this.description = requireNonBlank(description, "description"); + this.defaultValue = defaultValue == null ? "" : defaultValue; + this.required = required; + + if (this.required && !this.defaultValue.isEmpty()) { + throw new IllegalArgumentException("required=true cannot be combined with a non-empty defaultValue"); + } + + validateDefaultValue(type, this.defaultValue); + } + + /** + * Creates a required parameter with no default value. + * + * @param + * the parameter type + * @param type + * the Java class of the parameter + * @param name + * the wire name sent to the model (must not be blank) + * @param description + * a human-readable description (must not be blank) + * @return a new {@code Param} instance + * @throws NullPointerException + * if {@code type} is null + * @throws IllegalArgumentException + * if {@code name} or {@code description} is blank + */ + public static Param of(Class type, String name, String description) { + return new Param<>(type, name, description, true, ""); + } + + /** + * Creates a parameter with explicit required/default settings. + * + * @param + * the parameter type + * @param type + * the Java class of the parameter + * @param name + * the wire name sent to the model (must not be blank) + * @param description + * a human-readable description (must not be blank) + * @param required + * whether the parameter is required + * @param defaultValue + * the default value as a string, or {@code null}/empty for none + * @return a new {@code Param} instance + * @throws NullPointerException + * if {@code type} is null + * @throws IllegalArgumentException + * if validation fails + */ + public static Param of(Class type, String name, String description, boolean required, + String defaultValue) { + return new Param<>(type, name, description, required, defaultValue); + } + + /** + * Returns a copy with a different name. + * + * @param name + * the new parameter name + * @return a new {@code Param} with the updated name + */ + public Param name(String name) { + return new Param<>(this.type, name, this.description, this.required, this.defaultValue); + } + + /** + * Returns a copy with a different description. + * + * @param description + * the new description + * @return a new {@code Param} with the updated description + */ + public Param description(String description) { + return new Param<>(this.type, this.name, description, this.required, this.defaultValue); + } + + /** + * Returns a copy with a different required flag. + * + * @param required + * whether the parameter is required + * @return a new {@code Param} with the updated required flag + */ + public Param required(boolean required) { + return new Param<>(this.type, this.name, this.description, required, this.defaultValue); + } + + /** + * Returns an optional copy with the given default value. Setting a default + * implicitly makes the parameter optional ({@code required=false}). + * + * @param defaultValue + * the default value as a string + * @return a new {@code Param} with the default applied and required set to + * false + */ + public Param defaultValue(String defaultValue) { + return new Param<>(this.type, this.name, this.description, false, defaultValue); + } + + /** Returns the Java type of this parameter. */ + public Class type() { + return type; + } + + /** Returns the wire name of this parameter. */ + public String name() { + return name; + } + + /** Returns the human-readable description. */ + public String description() { + return description; + } + + /** Returns whether this parameter is required. */ + public boolean required() { + return required; + } + + /** Returns the default value string, or empty if none. */ + public String defaultValue() { + return defaultValue; + } + + /** Returns {@code true} if a non-empty default value is set. */ + public boolean hasDefaultValue() { + return !defaultValue.isEmpty(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Param other)) { + return false; + } + return required == other.required && Objects.equals(type, other.type) && Objects.equals(name, other.name) + && Objects.equals(description, other.description) && Objects.equals(defaultValue, other.defaultValue); + } + + @Override + public int hashCode() { + return Objects.hash(type, name, description, required, defaultValue); + } + + @Override + public String toString() { + return "Param[name=" + name + ", type=" + type.getSimpleName() + ", required=" + required + "]"; + } + + // ------------------------------------------------------------------ + // Internal validation helpers + // ------------------------------------------------------------------ + + private static String requireNonBlank(String value, String fieldName) { + if (value == null || value.isBlank()) { + throw new IllegalArgumentException(fieldName + " must not be null or blank"); + } + return value; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static void validateDefaultValue(Class type, String defaultValue) { + if (defaultValue == null || defaultValue.isEmpty()) { + return; + } + + try { + if (type == String.class) { + return; + } + if (type == Integer.class || type == int.class) { + Integer.parseInt(defaultValue); + return; + } + if (type == Long.class || type == long.class) { + Long.parseLong(defaultValue); + return; + } + if (type == Double.class || type == double.class) { + Double.parseDouble(defaultValue); + return; + } + if (type == Float.class || type == float.class) { + Float.parseFloat(defaultValue); + return; + } + if (type == Short.class || type == short.class) { + Short.parseShort(defaultValue); + return; + } + if (type == Byte.class || type == byte.class) { + Byte.parseByte(defaultValue); + return; + } + if (type == Boolean.class || type == boolean.class) { + if (!"true".equalsIgnoreCase(defaultValue) && !"false".equalsIgnoreCase(defaultValue)) { + throw new IllegalArgumentException("must be 'true' or 'false'"); + } + return; + } + if (type.isEnum()) { + Class enumType = (Class) type; + Enum.valueOf(enumType, defaultValue); + return; + } + } catch (RuntimeException ex) { + throw new IllegalArgumentException( + "defaultValue '" + defaultValue + "' is not valid for type " + type.getSimpleName(), ex); + } + + throw new IllegalArgumentException( + "defaultValue is not supported for type " + type.getName() + " without a custom coercion policy"); + } +} diff --git a/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java b/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java index c74e94544..df031f354 100644 --- a/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java +++ b/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java @@ -23,6 +23,7 @@ import com.github.copilot.rpc.SessionConfig; import com.github.copilot.rpc.ToolDefinition; import com.github.copilot.rpc.ToolSet; +import com.github.copilot.tool.Param; /** * Failsafe integration test for the ergonomic {@code @CopilotTool} + @@ -82,4 +83,49 @@ void ergonomicToolDefinition() throws Exception { } } } + + @Test + void lambdaToolDefinition() throws Exception { + ctx.configureForTest("tools", "ergonomic_tool_definition"); + + class LambdaTools { + String currentPhase; + } + LambdaTools tools = new LambdaTools(); + + ToolDefinition setCurrentPhase = ToolDefinition.from("set_current_phase", "Sets the current phase of the agent", + Param.of(String.class, "phase", "The phase to transition to"), phase -> { + tools.currentPhase = phase; + return "Phase set to " + phase; + }); + + ToolDefinition searchItems = ToolDefinition.from("search_items", "Search for items by keyword", + Param.of(String.class, "keyword", "Search keyword"), + keyword -> "Found: " + keyword + " -> item_alpha, item_beta"); + + try (CopilotClient client = ctx.createClient()) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setAvailableTools(new ToolSet().addCustom("*").addBuiltIn("web_fetch")) + .setTools(List.of(setCurrentPhase, searchItems))) + .get(30, TimeUnit.SECONDS); + + try { + AssistantMessageEvent response = session.sendAndWait(new MessageOptions().setPrompt( + "First, set the current phase to 'analyzing'. Then search for items with keyword 'copilot'. Report the phase and search results."), + 60_000).get(90, TimeUnit.SECONDS); + + assertNotNull(response, "Expected a response from the assistant"); + String content = response.getData().content().toLowerCase(); + assertTrue(content.contains("analyzing"), + "Response should contain the updated phase: " + response.getData().content()); + assertTrue(content.contains("item_alpha") || content.contains("item_beta"), + "Response should contain search results: " + response.getData().content()); + assertTrue("analyzing".equals(tools.currentPhase), + "Expected currentPhase to be 'analyzing' but was: " + tools.currentPhase); + } finally { + session.close(); + } + } + } } diff --git a/java/src/test/java/com/github/copilot/rpc/ParamCoercionTest.java b/java/src/test/java/com/github/copilot/rpc/ParamCoercionTest.java new file mode 100644 index 000000000..8ad4ee830 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/ParamCoercionTest.java @@ -0,0 +1,362 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; +import java.util.OptionalDouble; +import java.util.OptionalInt; +import java.util.OptionalLong; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.tool.Param; + +/** + * Unit tests for {@link ParamCoercion} — runtime argument coercion from raw + * invocation maps to typed Java values declared by {@link Param} descriptors. + */ +class ParamCoercionTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + // ── coerce: present argument, simple types ─────────────────────────────────── + + @Test + void coerce_stringArg_passedThrough() { + Param p = Param.of(String.class, "msg", "A message"); + String result = ParamCoercion.coerce(Map.of("msg", "hello"), p, MAPPER); + assertEquals("hello", result); + } + + @Test + void coerce_integerArgFromNumber() { + Param p = Param.of(Integer.class, "n", "A number"); + Integer result = ParamCoercion.coerce(Map.of("n", 42), p, MAPPER); + assertEquals(42, result); + } + + @Test + void coerce_longArgFromNumber() { + Param p = Param.of(Long.class, "id", "An identifier"); + Long result = ParamCoercion.coerce(Map.of("id", 123456789L), p, MAPPER); + assertEquals(123456789L, result); + } + + @Test + void coerce_doubleArgFromNumber() { + Param p = Param.of(Double.class, "price", "A price"); + Double result = ParamCoercion.coerce(Map.of("price", 19.99), p, MAPPER); + assertEquals(19.99, result, 0.001); + } + + @Test + void coerce_floatArgFromNumber() { + Param p = Param.of(Float.class, "rate", "A rate"); + Float result = ParamCoercion.coerce(Map.of("rate", 3.14), p, MAPPER); + assertEquals(3.14f, result, 0.01f); + } + + @Test + void coerce_booleanArgFromBoolean() { + Param p = Param.of(Boolean.class, "flag", "A flag"); + Boolean result = ParamCoercion.coerce(Map.of("flag", true), p, MAPPER); + assertEquals(true, result); + } + + // Note: enum coercion via mapper.convertValue requires the enum's package to be + // opened to com.fasterxml.jackson.databind. In the SDK module, + // com.github.copilot.tool + // is not opened to Jackson (only com.github.copilot.rpc is). User-defined enums + // will + // be outside the SDK module and fully accessible. Enum default coercion is + // tested via + // coerceDefault_enum which uses Enum.valueOf directly. + + @Test + void coerce_enumFromString_viaCoerceDefault() { + Param p = Param.of(TestMode.class, "mode", "Mode", false, "FAST"); + TestMode result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(TestMode.FAST, result); + } + + // ── coerce: Optional primitive types ───────────────────────────────────────── + + @Test + void coerce_optionalInt_fromNumber() { + Param p = Param.of(OptionalInt.class, "count", "Count", false, ""); + OptionalInt result = ParamCoercion.coerce(Map.of("count", 7), p, MAPPER); + assertEquals(OptionalInt.of(7), result); + } + + @Test + void coerce_optionalLong_fromNumber() { + Param p = Param.of(OptionalLong.class, "ts", "Timestamp", false, ""); + OptionalLong result = ParamCoercion.coerce(Map.of("ts", 999L), p, MAPPER); + assertEquals(OptionalLong.of(999L), result); + } + + @Test + void coerce_optionalDouble_fromNumber() { + Param p = Param.of(OptionalDouble.class, "ratio", "Ratio", false, ""); + OptionalDouble result = ParamCoercion.coerce(Map.of("ratio", 2.5), p, MAPPER); + assertEquals(OptionalDouble.of(2.5), result); + } + + @Test + void coerce_optionalInt_nonNumeric_throwsIllegalArgument() { + Param p = Param.of(OptionalInt.class, "count", "Count", false, ""); + assertThrows(IllegalArgumentException.class, + () -> ParamCoercion.coerce(Map.of("count", "not_a_number"), p, MAPPER)); + } + + @Test + void coerce_optionalLong_nonNumeric_throwsIllegalArgument() { + Param p = Param.of(OptionalLong.class, "ts", "Timestamp", false, ""); + assertThrows(IllegalArgumentException.class, () -> ParamCoercion.coerce(Map.of("ts", "abc"), p, MAPPER)); + } + + @Test + void coerce_optionalDouble_nonNumeric_throwsIllegalArgument() { + Param p = Param.of(OptionalDouble.class, "ratio", "Ratio", false, ""); + assertThrows(IllegalArgumentException.class, () -> ParamCoercion.coerce(Map.of("ratio", "xyz"), p, MAPPER)); + } + + // ── coerce: missing argument — required ────────────────────────────────────── + + @Test + void coerce_requiredMissing_throwsWithParamName() { + Param p = Param.of(String.class, "query", "Search query"); + var ex = assertThrows(IllegalArgumentException.class, () -> ParamCoercion.coerce(Map.of(), p, MAPPER)); + assertTrue(ex.getMessage().contains("query")); + } + + @Test + void coerce_requiredMissing_nullArgs_throws() { + Param p = Param.of(String.class, "name", "A name"); + var ex = assertThrows(IllegalArgumentException.class, () -> ParamCoercion.coerce(null, p, MAPPER)); + assertTrue(ex.getMessage().contains("name")); + } + + // ── coerce: missing argument — optional with default ───────────────────────── + + @Test + void coerce_optionalWithStringDefault_usesDefault() { + Param p = Param.of(String.class, "mode", "Mode", false, "normal"); + String result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals("normal", result); + } + + @Test + void coerce_optionalWithIntegerDefault_usesDefault() { + Param p = Param.of(Integer.class, "limit", "Limit", false, "25"); + Integer result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(25, result); + } + + @Test + void coerce_optionalWithLongDefault_usesDefault() { + Param p = Param.of(Long.class, "offset", "Offset", false, "100"); + Long result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(100L, result); + } + + @Test + void coerce_optionalWithDoubleDefault_usesDefault() { + Param p = Param.of(Double.class, "threshold", "Threshold", false, "0.75"); + Double result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(0.75, result, 0.001); + } + + @Test + void coerce_optionalWithFloatDefault_usesDefault() { + Param p = Param.of(Float.class, "rate", "Rate", false, "1.5"); + Float result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(1.5f, result, 0.01f); + } + + @Test + void coerce_optionalWithShortDefault_usesDefault() { + Param p = Param.of(Short.class, "level", "Level", false, "3"); + Short result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals((short) 3, result); + } + + @Test + void coerce_optionalWithByteDefault_usesDefault() { + Param p = Param.of(Byte.class, "code", "Code", false, "7"); + Byte result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals((byte) 7, result); + } + + @Test + void coerce_optionalWithBooleanDefault_usesDefault() { + Param p = Param.of(Boolean.class, "verbose", "Verbose", false, "true"); + Boolean result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(true, result); + } + + @Test + void coerce_optionalWithEnumDefault_usesDefault() { + Param p = Param.of(TestMode.class, "mode", "Mode", false, "SLOW"); + TestMode result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(TestMode.SLOW, result); + } + + // ── coerce: missing argument — optional without default ────────────────────── + + @Test + void coerce_optionalNoDefault_returnsNull() { + Param p = Param.of(String.class, "title", "Title", false, ""); + String result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertNull(result); + } + + @Test + void coerce_optionalNoDefault_optionalInt_returnsEmpty() { + Param p = Param.of(OptionalInt.class, "n", "Number", false, ""); + OptionalInt result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(OptionalInt.empty(), result); + } + + @Test + void coerce_optionalNoDefault_optionalLong_returnsEmpty() { + Param p = Param.of(OptionalLong.class, "ts", "Timestamp", false, ""); + OptionalLong result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(OptionalLong.empty(), result); + } + + @Test + void coerce_optionalNoDefault_optionalDouble_returnsEmpty() { + Param p = Param.of(OptionalDouble.class, "ratio", "Ratio", false, ""); + OptionalDouble result = ParamCoercion.coerce(Map.of(), p, MAPPER); + assertEquals(OptionalDouble.empty(), result); + } + + // ── coerce: type conversion via ObjectMapper ───────────────────────────────── + + @Test + void coerce_integerFromStringViaMapper() { + // ObjectMapper can convert "42" string to Integer + Param p = Param.of(Integer.class, "n", "A number"); + Integer result = ParamCoercion.coerce(Map.of("n", "42"), p, MAPPER); + assertEquals(42, result); + } + + @Test + void coerce_booleanFromStringViaMapper() { + Param p = Param.of(Boolean.class, "flag", "A flag"); + Boolean result = ParamCoercion.coerce(Map.of("flag", "true"), p, MAPPER); + assertEquals(true, result); + } + + @Test + void coerce_incompatibleType_throwsWithParamName() { + Param p = Param.of(Integer.class, "count", "Count"); + var ex = assertThrows(IllegalArgumentException.class, + () -> ParamCoercion.coerce(Map.of("count", "not_a_number"), p, MAPPER)); + assertTrue(ex.getMessage().contains("count")); + } + + // ── coerceDefault: direct tests ────────────────────────────────────────────── + + @Test + void coerceDefault_string() { + Param p = Param.of(String.class, "s", "A string", false, "hello"); + assertEquals("hello", ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_integer() { + Param p = Param.of(Integer.class, "n", "A num", false, "99"); + assertEquals(99, ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_long() { + Param p = Param.of(Long.class, "id", "An id", false, "12345"); + assertEquals(12345L, ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_double() { + Param p = Param.of(Double.class, "d", "A double", false, "3.14"); + assertEquals(3.14, ParamCoercion.coerceDefault(p, MAPPER), 0.001); + } + + @Test + void coerceDefault_float() { + Param p = Param.of(Float.class, "f", "A float", false, "2.5"); + assertEquals(2.5f, ParamCoercion.coerceDefault(p, MAPPER), 0.01f); + } + + @Test + void coerceDefault_short() { + Param p = Param.of(Short.class, "s", "A short", false, "10"); + assertEquals((short) 10, ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_byte() { + Param p = Param.of(Byte.class, "b", "A byte", false, "5"); + assertEquals((byte) 5, ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_booleanTrue() { + Param p = Param.of(Boolean.class, "v", "Verbose", false, "true"); + assertEquals(true, ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_booleanFalse() { + Param p = Param.of(Boolean.class, "v", "Verbose", false, "false"); + assertEquals(false, ParamCoercion.coerceDefault(p, MAPPER)); + } + + @Test + void coerceDefault_enum() { + Param p = Param.of(TestMode.class, "m", "Mode", false, "FAST"); + assertEquals(TestMode.FAST, ParamCoercion.coerceDefault(p, MAPPER)); + } + + // ── emptyOptionalOrNull: direct tests ──────────────────────────────────────── + + @Test + void emptyOptionalOrNull_optionalInt_returnsEmpty() { + assertEquals(OptionalInt.empty(), ParamCoercion.emptyOptionalOrNull(OptionalInt.class)); + } + + @Test + void emptyOptionalOrNull_optionalLong_returnsEmpty() { + assertEquals(OptionalLong.empty(), ParamCoercion.emptyOptionalOrNull(OptionalLong.class)); + } + + @Test + void emptyOptionalOrNull_optionalDouble_returnsEmpty() { + assertEquals(OptionalDouble.empty(), ParamCoercion.emptyOptionalOrNull(OptionalDouble.class)); + } + + @Test + void emptyOptionalOrNull_string_returnsNull() { + assertNull(ParamCoercion.emptyOptionalOrNull(String.class)); + } + + @Test + void emptyOptionalOrNull_integer_returnsNull() { + assertNull(ParamCoercion.emptyOptionalOrNull(Integer.class)); + } + + // ── Test helper types ──────────────────────────────────────────────────────── + + enum TestMode { + FAST, SLOW, NORMAL + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/ParamSchemaTest.java b/java/src/test/java/com/github/copilot/rpc/ParamSchemaTest.java new file mode 100644 index 000000000..27d76a91a --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/ParamSchemaTest.java @@ -0,0 +1,436 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.ZonedDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.OptionalDouble; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.UUID; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.tool.Param; + +/** + * Unit tests for {@link ParamSchema} — runtime JSON Schema generation from + * {@link Param} descriptors. + */ +class ParamSchemaTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + // ── buildSchema: empty / zero params ───────────────────────────────────────── + + @Test + void buildSchema_nullParams_returnsEmptySchema() { + Map schema = ParamSchema.buildSchema("tool", MAPPER, (Param[]) null); + assertEquals("object", schema.get("type")); + assertTrue(((Map) schema.get("properties")).isEmpty()); + assertTrue(((List) schema.get("required")).isEmpty()); + } + + @Test + void buildSchema_emptyArray_returnsEmptySchema() { + Map schema = ParamSchema.buildSchema("tool", MAPPER); + assertEquals("object", schema.get("type")); + assertTrue(((Map) schema.get("properties")).isEmpty()); + assertTrue(((List) schema.get("required")).isEmpty()); + } + + // ── buildSchema: validation ────────────────────────────────────────────────── + + @Test + void buildSchema_nullParamElement_throwsWithToolName() { + Param p1 = Param.of(String.class, "a", "First"); + var ex = assertThrows(IllegalArgumentException.class, + () -> ParamSchema.buildSchema("my_tool", MAPPER, p1, null)); + assertTrue(ex.getMessage().contains("my_tool")); + } + + @Test + void buildSchema_duplicateNames_throwsWithToolNameAndParamName() { + Param p1 = Param.of(String.class, "name", "First name"); + Param p2 = Param.of(String.class, "name", "Second name"); + var ex = assertThrows(IllegalArgumentException.class, + () -> ParamSchema.buildSchema("greeting", MAPPER, p1, p2)); + assertTrue(ex.getMessage().contains("name")); + assertTrue(ex.getMessage().contains("greeting")); + } + + // ── buildSchema: required / optional semantics ─────────────────────────────── + + @Test + void buildSchema_requiredParam_appearsInRequiredList() { + Param p = Param.of(String.class, "query", "Search query"); + Map schema = ParamSchema.buildSchema("search", MAPPER, p); + @SuppressWarnings("unchecked") + List required = (List) schema.get("required"); + assertTrue(required.contains("query")); + } + + @Test + void buildSchema_optionalParam_notInRequiredList() { + Param p = Param.of(Integer.class, "limit", "Max results", false, "10"); + Map schema = ParamSchema.buildSchema("list", MAPPER, p); + @SuppressWarnings("unchecked") + List required = (List) schema.get("required"); + assertTrue(required.isEmpty()); + } + + @Test + void buildSchema_mixedRequiredAndOptional_onlyRequiredInList() { + Param pReq = Param.of(String.class, "query", "Search query"); + Param pOpt = Param.of(Integer.class, "limit", "Max", false, "20"); + Map schema = ParamSchema.buildSchema("search", MAPPER, pReq, pOpt); + @SuppressWarnings("unchecked") + List required = (List) schema.get("required"); + assertEquals(1, required.size()); + assertEquals("query", required.get(0)); + } + + // ── buildSchema: description and default in property ───────────────────────── + + @Test + void buildSchema_paramDescription_appearsInPropertySchema() { + Param p = Param.of(String.class, "msg", "A message to send"); + Map schema = ParamSchema.buildSchema("send", MAPPER, p); + @SuppressWarnings("unchecked") + Map props = (Map) schema.get("properties"); + @SuppressWarnings("unchecked") + Map msgSchema = (Map) props.get("msg"); + assertEquals("A message to send", msgSchema.get("description")); + } + + @Test + void buildSchema_paramDefault_appearsInPropertySchema() { + Param p = Param.of(Integer.class, "count", "Item count", false, "5"); + Map schema = ParamSchema.buildSchema("items", MAPPER, p); + @SuppressWarnings("unchecked") + Map props = (Map) schema.get("properties"); + @SuppressWarnings("unchecked") + Map countSchema = (Map) props.get("count"); + assertEquals(5, countSchema.get("default")); + } + + @Test + void buildSchema_stringDefault_appearsAsString() { + Param p = Param.of(String.class, "mode", "Operating mode", false, "fast"); + Map schema = ParamSchema.buildSchema("run", MAPPER, p); + @SuppressWarnings("unchecked") + Map props = (Map) schema.get("properties"); + @SuppressWarnings("unchecked") + Map modeSchema = (Map) props.get("mode"); + assertEquals("fast", modeSchema.get("default")); + } + + @Test + void buildSchema_booleanDefault_appearsAsBoolean() { + Param p = Param.of(Boolean.class, "verbose", "Verbose mode", false, "true"); + Map schema = ParamSchema.buildSchema("run", MAPPER, p); + @SuppressWarnings("unchecked") + Map props = (Map) schema.get("properties"); + @SuppressWarnings("unchecked") + Map verboseSchema = (Map) props.get("verbose"); + assertEquals(true, verboseSchema.get("default")); + } + + // ── buildSchema: multiple params preserve order ────────────────────────────── + + @Test + void buildSchema_multipleParams_orderPreservedInProperties() { + Param p1 = Param.of(String.class, "alpha", "First"); + Param p2 = Param.of(String.class, "beta", "Second"); + Param p3 = Param.of(String.class, "gamma", "Third"); + Map schema = ParamSchema.buildSchema("ordered", MAPPER, p1, p2, p3); + @SuppressWarnings("unchecked") + Map props = (Map) schema.get("properties"); + List keys = List.copyOf(props.keySet()); + assertEquals(List.of("alpha", "beta", "gamma"), keys); + } + + // ── forType: primitive and boxed integer types ─────────────────────────────── + + @Test + void forType_int_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(int.class)); + } + + @Test + void forType_Integer_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(Integer.class)); + } + + @Test + void forType_long_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(long.class)); + } + + @Test + void forType_Long_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(Long.class)); + } + + @Test + void forType_short_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(short.class)); + } + + @Test + void forType_Short_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(Short.class)); + } + + @Test + void forType_byte_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(byte.class)); + } + + @Test + void forType_Byte_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(Byte.class)); + } + + // ── forType: floating-point types ──────────────────────────────────────────── + + @Test + void forType_double_returnsNumber() { + assertEquals(Map.of("type", "number"), ParamSchema.forType(double.class)); + } + + @Test + void forType_Double_returnsNumber() { + assertEquals(Map.of("type", "number"), ParamSchema.forType(Double.class)); + } + + @Test + void forType_float_returnsNumber() { + assertEquals(Map.of("type", "number"), ParamSchema.forType(float.class)); + } + + @Test + void forType_Float_returnsNumber() { + assertEquals(Map.of("type", "number"), ParamSchema.forType(Float.class)); + } + + // ── forType: boolean ───────────────────────────────────────────────────────── + + @Test + void forType_boolean_returnsBoolean() { + assertEquals(Map.of("type", "boolean"), ParamSchema.forType(boolean.class)); + } + + @Test + void forType_Boolean_returnsBoolean() { + assertEquals(Map.of("type", "boolean"), ParamSchema.forType(Boolean.class)); + } + + // ── forType: char / Character ──────────────────────────────────────────────── + + @Test + void forType_char_returnsString() { + assertEquals(Map.of("type", "string"), ParamSchema.forType(char.class)); + } + + @Test + void forType_Character_returnsString() { + assertEquals(Map.of("type", "string"), ParamSchema.forType(Character.class)); + } + + // ── forType: String ────────────────────────────────────────────────────────── + + @Test + void forType_String_returnsString() { + assertEquals(Map.of("type", "string"), ParamSchema.forType(String.class)); + } + + // ── forType: UUID ──────────────────────────────────────────────────────────── + + @Test + void forType_UUID_returnsStringWithUuidFormat() { + Map schema = ParamSchema.forType(UUID.class); + assertEquals("string", schema.get("type")); + assertEquals("uuid", schema.get("format")); + } + + // ── forType: Optional primitive types ──────────────────────────────────────── + + @Test + void forType_OptionalInt_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(OptionalInt.class)); + } + + @Test + void forType_OptionalLong_returnsInteger() { + assertEquals(Map.of("type", "integer"), ParamSchema.forType(OptionalLong.class)); + } + + @Test + void forType_OptionalDouble_returnsNumber() { + assertEquals(Map.of("type", "number"), ParamSchema.forType(OptionalDouble.class)); + } + + // ── forType: date-time types ───────────────────────────────────────────────── + + @Test + void forType_OffsetDateTime_returnsDateTimeFormat() { + Map schema = ParamSchema.forType(OffsetDateTime.class); + assertEquals("string", schema.get("type")); + assertEquals("date-time", schema.get("format")); + } + + @Test + void forType_LocalDateTime_returnsDateTimeFormat() { + Map schema = ParamSchema.forType(LocalDateTime.class); + assertEquals("string", schema.get("type")); + assertEquals("date-time", schema.get("format")); + } + + @Test + void forType_Instant_returnsDateTimeFormat() { + Map schema = ParamSchema.forType(Instant.class); + assertEquals("string", schema.get("type")); + assertEquals("date-time", schema.get("format")); + } + + @Test + void forType_ZonedDateTime_returnsDateTimeFormat() { + Map schema = ParamSchema.forType(ZonedDateTime.class); + assertEquals("string", schema.get("type")); + assertEquals("date-time", schema.get("format")); + } + + @Test + void forType_LocalDate_returnsDateFormat() { + Map schema = ParamSchema.forType(LocalDate.class); + assertEquals("string", schema.get("type")); + assertEquals("date", schema.get("format")); + } + + @Test + void forType_LocalTime_returnsTimeFormat() { + Map schema = ParamSchema.forType(LocalTime.class); + assertEquals("string", schema.get("type")); + assertEquals("time", schema.get("format")); + } + + // ── forType: JsonNode / Object → any ───────────────────────────────────────── + + @Test + void forType_JsonNode_returnsEmptySchema() { + assertTrue(ParamSchema.forType(JsonNode.class).isEmpty()); + } + + @Test + void forType_Object_returnsEmptySchema() { + assertTrue(ParamSchema.forType(Object.class).isEmpty()); + } + + // ── forType: enums ─────────────────────────────────────────────────────────── + + @Test + void forType_enum_returnsStringWithEnumValues() { + Map schema = ParamSchema.forType(TestColor.class); + assertEquals("string", schema.get("type")); + @SuppressWarnings("unchecked") + List values = (List) schema.get("enum"); + assertNotNull(values); + assertEquals(List.of("RED", "GREEN", "BLUE"), values); + } + + // ── forType: collections ───────────────────────────────────────────────────── + + @Test + void forType_List_returnsArray() { + assertEquals(Map.of("type", "array"), ParamSchema.forType(List.class)); + } + + @Test + void forType_Set_returnsArray() { + assertEquals(Map.of("type", "array"), ParamSchema.forType(Set.class)); + } + + @Test + void forType_Collection_returnsArray() { + assertEquals(Map.of("type", "array"), ParamSchema.forType(Collection.class)); + } + + // ── forType: arrays ────────────────────────────────────────────────────────── + + @Test + void forType_stringArray_returnsArrayWithStringItems() { + Map schema = ParamSchema.forType(String[].class); + assertEquals("array", schema.get("type")); + @SuppressWarnings("unchecked") + Map items = (Map) schema.get("items"); + assertEquals("string", items.get("type")); + } + + @Test + void forType_intArray_returnsArrayWithIntegerItems() { + Map schema = ParamSchema.forType(int[].class); + assertEquals("array", schema.get("type")); + @SuppressWarnings("unchecked") + Map items = (Map) schema.get("items"); + assertEquals("integer", items.get("type")); + } + + @Test + void forType_doubleArray_returnsArrayWithNumberItems() { + Map schema = ParamSchema.forType(double[].class); + assertEquals("array", schema.get("type")); + @SuppressWarnings("unchecked") + Map items = (Map) schema.get("items"); + assertEquals("number", items.get("type")); + } + + // ── forType: Map ───────────────────────────────────────────────────────────── + + @Test + void forType_Map_returnsObject() { + assertEquals(Map.of("type", "object"), ParamSchema.forType(Map.class)); + } + + // ── forType: POJO / record fallback ────────────────────────────────────────── + + @Test + void forType_record_returnsObject() { + assertEquals(Map.of("type", "object"), ParamSchema.forType(TestRecord.class)); + } + + @Test + void forType_pojo_returnsObject() { + assertEquals(Map.of("type", "object"), ParamSchema.forType(TestPojo.class)); + } + + // ── Test helper types ──────────────────────────────────────────────────────── + + enum TestColor { + RED, GREEN, BLUE + } + + record TestRecord(String name, int value) { + } + + static class TestPojo { + String field; + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/ToolDefinitionLambdaTest.java b/java/src/test/java/com/github/copilot/rpc/ToolDefinitionLambdaTest.java new file mode 100644 index 000000000..7f9ccaaba --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/ToolDefinitionLambdaTest.java @@ -0,0 +1,613 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.AllowCopilotExperimental; +import com.github.copilot.tool.Param; + +/** + * Unit tests for {@link ToolDefinition#from}, {@link ToolDefinition#fromAsync}, + * {@link ToolDefinition#fromWithToolInvocation}, and + * {@link ToolDefinition#fromAsyncWithToolInvocation} lambda-tool factories, + * plus the fluent option-modifier methods + * ({@link ToolDefinition#overridesBuiltInTool}, + * {@link ToolDefinition#skipPermission}, {@link ToolDefinition#defer}). + * + *

+ * Tests are grouped by the Phase 4.4 contract: + *

    + *
  1. Successful inline definitions for arities 0–2 (sync and async).
  2. + *
  3. ToolInvocation context injection (sync and async).
  4. + *
  5. Option flag propagation.
  6. + *
  7. Required/default semantics.
  8. + *
  9. Error and validation paths.
  10. + *
  11. Schema structure.
  12. + *
  13. Result formatting (String, null, non-String).
  14. + *
  15. Argument coercion.
  16. + *
+ */ +@AllowCopilotExperimental +class ToolDefinitionLambdaTest { + + // ── Helpers ────────────────────────────────────────────────────────────────── + + private static ToolInvocation invocationOf(Map args) { + ObjectNode argsNode = JsonNodeFactory.instance.objectNode(); + for (Map.Entry e : args.entrySet()) { + Object v = e.getValue(); + if (v instanceof String s) { + argsNode.put(e.getKey(), s); + } else if (v instanceof Integer i) { + argsNode.put(e.getKey(), i); + } else if (v instanceof Long l) { + argsNode.put(e.getKey(), l); + } else if (v instanceof Double d) { + argsNode.put(e.getKey(), d); + } else if (v instanceof Boolean b) { + argsNode.put(e.getKey(), b); + } else if (v != null) { + argsNode.put(e.getKey(), v.toString()); + } + } + return new ToolInvocation().setArguments(argsNode); + } + + private static ToolInvocation invocationWithContext(String sessionId, String toolCallId, Map args) { + return invocationOf(args).setSessionId(sessionId).setToolCallId(toolCallId); + } + + @SuppressWarnings("unchecked") + private static Map schemaOf(ToolDefinition tool) { + return (Map) tool.parameters(); + } + + @SuppressWarnings("unchecked") + private static Map propertiesOf(ToolDefinition tool) { + return (Map) schemaOf(tool).get("properties"); + } + + @SuppressWarnings("unchecked") + private static List requiredOf(ToolDefinition tool) { + return (List) schemaOf(tool).get("required"); + } + + // ── Group 1: Successful inline definitions – arity 0, sync ─────────────────── + + @Test + void from_zeroArg_returnsNameAndDescription() { + ToolDefinition tool = ToolDefinition.from("ping", "Returns pong", () -> "pong"); + assertEquals("ping", tool.name()); + assertEquals("Returns pong", tool.description()); + } + + @Test + void from_zeroArg_invokesHandler() throws Exception { + ToolDefinition tool = ToolDefinition.from("ping", "Returns pong", () -> "pong"); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("pong", result); + } + + @Test + void from_zeroArg_emptySchema() { + ToolDefinition tool = ToolDefinition.from("ping", "Returns pong", () -> "pong"); + assertTrue(propertiesOf(tool).isEmpty()); + assertTrue(requiredOf(tool).isEmpty()); + } + + // ── Group 1: Successful inline definitions – arity 1, sync ─────────────────── + + @Test + void from_oneArg_returnsNameAndDescription() { + Param nameParam = Param.of(String.class, "name", "The user's name"); + ToolDefinition tool = ToolDefinition.from("greet", "Greets a user", nameParam, n -> "Hello, " + n + "!"); + assertEquals("greet", tool.name()); + assertEquals("Greets a user", tool.description()); + } + + @Test + void from_oneArg_invokesHandler() throws Exception { + Param nameParam = Param.of(String.class, "name", "The user's name"); + ToolDefinition tool = ToolDefinition.from("greet", "Greets a user", nameParam, n -> "Hello, " + n + "!"); + Object result = tool.handler().invoke(invocationOf(Map.of("name", "Alice"))).get(); + assertEquals("Hello, Alice!", result); + } + + @Test + void from_oneArg_schemaContainsParam() { + Param nameParam = Param.of(String.class, "name", "The user's name"); + ToolDefinition tool = ToolDefinition.from("greet", "Greets a user", nameParam, n -> "Hello, " + n + "!"); + assertTrue(propertiesOf(tool).containsKey("name")); + assertTrue(requiredOf(tool).contains("name")); + } + + // ── Group 1: Successful inline definitions – arity 2, sync ─────────────────── + + @Test + void from_twoArg_invokesHandler() throws Exception { + Param paramA = Param.of(Integer.class, "a", "First number"); + Param paramB = Param.of(Integer.class, "b", "Second number"); + ToolDefinition tool = ToolDefinition.from("add", "Adds two integers", paramA, paramB, + (a, b) -> String.valueOf(a + b)); + Object result = tool.handler().invoke(invocationOf(Map.of("a", 3, "b", 4))).get(); + assertEquals("7", result); + } + + @Test + void from_twoArg_schemaBothParamsPresent() { + Param paramA = Param.of(Integer.class, "a", "First"); + Param paramB = Param.of(Integer.class, "b", "Second"); + ToolDefinition tool = ToolDefinition.from("add", "Adds two integers", paramA, paramB, (a, b) -> a + b); + assertTrue(propertiesOf(tool).containsKey("a")); + assertTrue(propertiesOf(tool).containsKey("b")); + assertTrue(requiredOf(tool).contains("a")); + assertTrue(requiredOf(tool).contains("b")); + } + + // ── Group 2: Async handlers (fromAsync) ────────────────────────────────────── + + @Test + void fromAsync_zeroArg_invokesHandler() throws Exception { + ToolDefinition tool = ToolDefinition.fromAsync("ping_async", "Async ping", + () -> CompletableFuture.completedFuture("pong")); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("pong", result); + } + + @Test + void fromAsync_oneArg_invokesHandler() throws Exception { + Param nameParam = Param.of(String.class, "name", "Name to greet"); + ToolDefinition tool = ToolDefinition.fromAsync("greet_async", "Async greet", nameParam, + n -> CompletableFuture.completedFuture("Hi, " + n + "!")); + Object result = tool.handler().invoke(invocationOf(Map.of("name", "Bob"))).get(); + assertEquals("Hi, Bob!", result); + } + + @Test + void fromAsync_twoArg_invokesHandler() throws Exception { + Param paramA = Param.of(Integer.class, "a", "Left operand"); + Param paramB = Param.of(Integer.class, "b", "Right operand"); + ToolDefinition tool = ToolDefinition.fromAsync("add_async", "Async add", paramA, paramB, + (a, b) -> CompletableFuture.completedFuture(String.valueOf(a + b))); + Object result = tool.handler().invoke(invocationOf(Map.of("a", 10, "b", 5))).get(); + assertEquals("15", result); + } + + // ── Group 3: ToolInvocation context injection (sync) ───────────────────────── + + @Test + void fromWithToolInvocation_zeroArg_receivesContext() throws Exception { + ToolDefinition tool = ToolDefinition.fromWithToolInvocation("ctx_sync", "Returns session id", + inv -> "session=" + inv.getSessionId()); + Object result = tool.handler().invoke(invocationWithContext("sess-1", "call-1", Map.of())).get(); + assertEquals("session=sess-1", result); + } + + @Test + void fromWithToolInvocation_zeroArg_emptySchema() { + ToolDefinition tool = ToolDefinition.fromWithToolInvocation("ctx_sync", "Returns session id", + inv -> "session=" + inv.getSessionId()); + assertTrue(propertiesOf(tool).isEmpty()); + assertTrue(requiredOf(tool).isEmpty()); + } + + @Test + void fromWithToolInvocation_oneArg_receivesArgAndContext() throws Exception { + Param phaseParam = Param.of(String.class, "phase", "Current phase"); + ToolDefinition tool = ToolDefinition.fromWithToolInvocation("report", "Report phase", phaseParam, + (phase, inv) -> "phase=" + phase + ",callId=" + inv.getToolCallId()); + Object result = tool.handler().invoke(invocationWithContext("sess-2", "call-42", Map.of("phase", "analysis"))) + .get(); + assertEquals("phase=analysis,callId=call-42", result); + } + + @Test + void fromWithToolInvocation_oneArg_schemaExcludesInvocationParam() { + Param phaseParam = Param.of(String.class, "phase", "Current phase"); + ToolDefinition tool = ToolDefinition.fromWithToolInvocation("report", "Report phase", phaseParam, + (phase, inv) -> phase); + assertTrue(propertiesOf(tool).containsKey("phase")); + assertFalse(propertiesOf(tool).containsKey("invocation")); + assertEquals(List.of("phase"), requiredOf(tool)); + } + + // ── Group 4: Async ToolInvocation context injection ────────────────────────── + + @Test + void fromAsyncWithToolInvocation_zeroArg_receivesContext() throws Exception { + ToolDefinition tool = ToolDefinition.fromAsyncWithToolInvocation("ctx_async", "Async ctx", + inv -> CompletableFuture.completedFuture("callId=" + inv.getToolCallId())); + Object result = tool.handler().invoke(invocationWithContext("sess-3", "call-99", Map.of())).get(); + assertEquals("callId=call-99", result); + } + + @Test + void fromAsyncWithToolInvocation_oneArg_receivesArgAndContext() throws Exception { + Param phaseParam = Param.of(String.class, "phase", "Phase name"); + ToolDefinition tool = ToolDefinition.fromAsyncWithToolInvocation("report_async", "Async report", phaseParam, + (phase, inv) -> CompletableFuture.completedFuture("phase=" + phase + ",sess=" + inv.getSessionId())); + Object result = tool.handler().invoke(invocationWithContext("sess-4", "call-7", Map.of("phase", "planning"))) + .get(); + assertEquals("phase=planning,sess=sess-4", result); + } + + // ── Group 5: Option flag propagation ───────────────────────────────────────── + + @Test + void overridesBuiltInTool_setsFlag() { + ToolDefinition base = ToolDefinition.from("grep", "Custom grep", () -> "ok"); + assertNull(base.overridesBuiltInTool()); + ToolDefinition withOverride = base.overridesBuiltInTool(true); + assertEquals(Boolean.TRUE, withOverride.overridesBuiltInTool()); + } + + @Test + void overridesBuiltInTool_doesNotMutateOriginal() { + ToolDefinition base = ToolDefinition.from("grep", "Custom grep", () -> "ok"); + base.overridesBuiltInTool(true); + assertNull(base.overridesBuiltInTool(), "original must remain unchanged"); + } + + @Test + void skipPermission_setsFlag() { + ToolDefinition base = ToolDefinition.from("read_file", "Reads a file", () -> "contents"); + assertNull(base.skipPermission()); + ToolDefinition withSkip = base.skipPermission(true); + assertEquals(Boolean.TRUE, withSkip.skipPermission()); + } + + @Test + void skipPermission_doesNotMutateOriginal() { + ToolDefinition base = ToolDefinition.from("read_file", "Reads a file", () -> "contents"); + base.skipPermission(true); + assertNull(base.skipPermission(), "original must remain unchanged"); + } + + @Test + void defer_setsAutoMode() { + ToolDefinition base = ToolDefinition.from("search", "Searches things", () -> "results"); + assertNull(base.defer()); + ToolDefinition deferred = base.defer(ToolDefer.AUTO); + assertEquals(ToolDefer.AUTO, deferred.defer()); + } + + @Test + void defer_setsNeverMode() { + ToolDefinition base = ToolDefinition.from("must_preload", "Always preloaded", () -> "ok"); + ToolDefinition neverDeferred = base.defer(ToolDefer.NEVER); + assertEquals(ToolDefer.NEVER, neverDeferred.defer()); + } + + @Test + void defer_doesNotMutateOriginal() { + ToolDefinition base = ToolDefinition.from("search", "Searches things", () -> "results"); + base.defer(ToolDefer.AUTO); + assertNull(base.defer(), "original must remain unchanged"); + } + + @Test + void fluentModifiers_canBeChained() { + ToolDefinition tool = ToolDefinition.from("override_tool", "Overrides built-in", () -> "ok") + .overridesBuiltInTool(true).skipPermission(true).defer(ToolDefer.AUTO); + assertEquals(Boolean.TRUE, tool.overridesBuiltInTool()); + assertEquals(Boolean.TRUE, tool.skipPermission()); + assertEquals(ToolDefer.AUTO, tool.defer()); + } + + @Test + void fluentModifiers_preserveHandlerAndSchema() throws Exception { + Param p = Param.of(String.class, "msg", "A message"); + ToolDefinition tool = ToolDefinition.from("echo", "Echoes message", p, msg -> msg).skipPermission(true) + .overridesBuiltInTool(false); + assertNotNull(tool.handler()); + Object result = tool.handler().invoke(invocationOf(Map.of("msg", "hello"))).get(); + assertEquals("hello", result); + } + + // ── Group 6: Required/default semantics ────────────────────────────────────── + + @Test + void requiredParam_passedValue_usesProvidedValue() throws Exception { + Param p = Param.of(String.class, "word", "A word"); + ToolDefinition tool = ToolDefinition.from("echo", "Echoes", p, w -> w); + Object result = tool.handler().invoke(invocationOf(Map.of("word", "hello"))).get(); + assertEquals("hello", result); + } + + @Test + void requiredParam_missingFromInvocation_throwsIllegalArgumentException() { + Param p = Param.of(String.class, "word", "A required word"); + ToolDefinition tool = ToolDefinition.from("echo", "Echoes", p, w -> w); + var ex = assertThrows(IllegalArgumentException.class, () -> tool.handler().invoke(invocationOf(Map.of()))); + assertTrue(ex.getMessage().contains("word"), "Exception message should mention the missing parameter name"); + } + + @Test + void optionalParamWithDefault_absent_usesDefault() throws Exception { + Param p = Param.of(Integer.class, "limit", "Max results", false, "10"); + ToolDefinition tool = ToolDefinition.from("list", "Lists items", p, lim -> "limit=" + lim); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("limit=10", result); + } + + @Test + void optionalParamWithDefault_provided_usesProvidedValue() throws Exception { + Param p = Param.of(Integer.class, "limit", "Max results", false, "10"); + ToolDefinition tool = ToolDefinition.from("list", "Lists items", p, lim -> "limit=" + lim); + Object result = tool.handler().invoke(invocationOf(Map.of("limit", 25))).get(); + assertEquals("limit=25", result); + } + + @Test + void optionalParamWithDefault_schemaNotInRequired() { + Param p = Param.of(Integer.class, "limit", "Max results", false, "10"); + ToolDefinition tool = ToolDefinition.from("list", "Lists items", p, lim -> "limit=" + lim); + assertFalse(requiredOf(tool).contains("limit")); + assertTrue(propertiesOf(tool).containsKey("limit")); + } + + @Test + void optionalParam_absent_noDefaultYieldsNull() throws Exception { + Param p = Param.of(String.class, "title", "Optional title", false, ""); + ToolDefinition tool = ToolDefinition.from("greet", "Greets", p, t -> t == null ? "(no title)" : t); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("(no title)", result); + } + + @Test + void defaultValueAppearsInSchema() { + Param p = Param.of(Integer.class, "limit", "Max results", false, "5"); + ToolDefinition tool = ToolDefinition.from("list", "Lists items", p, lim -> lim.toString()); + @SuppressWarnings("unchecked") + Map limitPropSchema = (Map) propertiesOf(tool).get("limit"); + assertNotNull(limitPropSchema, "Schema must include 'limit' property"); + assertEquals(5, limitPropSchema.get("default"), "Default value must appear in schema"); + } + + // ── Group 7: Error / validation paths ──────────────────────────────────────── + + @Test + void from_nullName_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.from(null, "desc", () -> "ok")); + } + + @Test + void from_blankName_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.from(" ", "desc", () -> "ok")); + } + + @Test + void from_nullDescription_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.from("tool", null, () -> "ok")); + } + + @Test + void from_blankDescription_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.from("tool", "", () -> "ok")); + } + + @Test + void from_nullHandler_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, + () -> ToolDefinition.from("tool", "desc", (java.util.function.Supplier) null)); + } + + @Test + void from_oneArg_nullParam_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, + () -> ToolDefinition.from("tool", "desc", (Param) null, s -> s)); + } + + @Test + void from_twoArg_nullFirstParam_throwsIllegalArgumentException() { + Param p2 = Param.of(String.class, "b", "B param"); + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.from("tool", "desc", null, p2, (a, b) -> a)); + } + + @Test + void from_twoArg_nullSecondParam_throwsIllegalArgumentException() { + Param p1 = Param.of(String.class, "a", "A param"); + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.from("tool", "desc", p1, null, (a, b) -> a)); + } + + @Test + void from_twoArg_duplicateParamNames_throwsIllegalArgumentException() { + Param p1 = Param.of(String.class, "name", "Name 1"); + Param p2 = Param.of(String.class, "name", "Name 2"); + var ex = assertThrows(IllegalArgumentException.class, + () -> ToolDefinition.from("tool", "desc", p1, p2, (a, b) -> a + b)); + assertTrue(ex.getMessage().contains("name"), "error must mention the duplicate param name"); + assertTrue(ex.getMessage().contains("tool"), "error must mention the tool name"); + } + + @Test + void fromAsync_nullName_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, + () -> ToolDefinition.fromAsync(null, "desc", () -> CompletableFuture.completedFuture("ok"))); + } + + @Test + void fromAsync_nullHandler_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.fromAsync("tool", "desc", + (java.util.function.Supplier>) null)); + } + + @Test + void fromWithToolInvocation_nullName_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, + () -> ToolDefinition.fromWithToolInvocation(null, "desc", inv -> "ok")); + } + + @Test + void fromAsyncWithToolInvocation_nullDescription_throwsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> ToolDefinition.fromAsyncWithToolInvocation("tool", null, + inv -> CompletableFuture.completedFuture("ok"))); + } + + // ── Group 8: Schema structure + // ───────────────────────────────────────────────── + + @Test + void schema_zeroArg_hasTypeObjectAndEmptyMaps() { + ToolDefinition tool = ToolDefinition.from("noop", "No-op", () -> "done"); + Map schema = schemaOf(tool); + assertEquals("object", schema.get("type")); + assertTrue(((Map) schema.get("properties")).isEmpty()); + assertTrue(((List) schema.get("required")).isEmpty()); + } + + @Test + void schema_oneArg_hasCorrectTypeForString() { + Param p = Param.of(String.class, "query", "Search query"); + ToolDefinition tool = ToolDefinition.from("search", "Searches", p, q -> q); + @SuppressWarnings("unchecked") + Map querySchema = (Map) propertiesOf(tool).get("query"); + assertNotNull(querySchema); + assertEquals("string", querySchema.get("type")); + assertEquals("Search query", querySchema.get("description")); + } + + @Test + void schema_oneArg_hasCorrectTypeForInteger() { + Param p = Param.of(Integer.class, "count", "Item count"); + ToolDefinition tool = ToolDefinition.from("count_items", "Counts items", p, c -> c.toString()); + @SuppressWarnings("unchecked") + Map countSchema = (Map) propertiesOf(tool).get("count"); + assertNotNull(countSchema); + assertEquals("integer", countSchema.get("type")); + } + + @Test + void schema_oneArg_hasCorrectTypeForBoolean() { + Param p = Param.of(Boolean.class, "enabled", "Whether enabled"); + ToolDefinition tool = ToolDefinition.from("toggle", "Toggles", p, e -> e.toString()); + @SuppressWarnings("unchecked") + Map enabledSchema = (Map) propertiesOf(tool).get("enabled"); + assertNotNull(enabledSchema); + assertEquals("boolean", enabledSchema.get("type")); + } + + @Test + void schema_oneArg_enumTypeHasStringAndEnumValues() { + Param p = Param.of(Color.class, "color", "A color"); + ToolDefinition tool = ToolDefinition.from("paint", "Paints with a color", p, c -> c.name()); + @SuppressWarnings("unchecked") + Map colorSchema = (Map) propertiesOf(tool).get("color"); + assertNotNull(colorSchema); + assertEquals("string", colorSchema.get("type")); + @SuppressWarnings("unchecked") + List enumValues = (List) colorSchema.get("enum"); + assertNotNull(enumValues); + assertTrue(enumValues.contains("RED")); + assertTrue(enumValues.contains("GREEN")); + assertTrue(enumValues.contains("BLUE")); + } + + // ── Group 9: Result formatting + // ──────────────────────────────────────────────── + + @Test + void resultFormatting_stringReturnedAsIs() throws Exception { + ToolDefinition tool = ToolDefinition.from("echo", "Echoes", () -> "plain text"); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("plain text", result); + } + + @Test + void resultFormatting_nullMappedToSuccess() throws Exception { + ToolDefinition tool = ToolDefinition.from("noop", "No-op", () -> null); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("Success", result); + } + + @Test + void resultFormatting_nonStringSerializedToJson() throws Exception { + Param p = Param.of(String.class, "key", "Key name"); + ToolDefinition tool = ToolDefinition.from("to_map", "Wraps in map", p, k -> Map.of("key", k, "value", 42)); + Object result = tool.handler().invoke(invocationOf(Map.of("key", "x"))).get(); + assertNotNull(result); + assertTrue(result instanceof String, "Non-String should be JSON-serialized to String"); + String json = (String) result; + ObjectMapper mapper = new ObjectMapper(); + JsonNode node = mapper.readTree(json); + assertTrue(node.isObject(), "Result should be a JSON object"); + assertEquals("x", node.get("key").asText(), "JSON must contain key field with value 'x'"); + assertEquals(42, node.get("value").asInt(), "JSON must contain value field with value 42"); + } + + @Test + void resultFormatting_integerSerializedToJson() throws Exception { + ToolDefinition tool = ToolDefinition.from("forty_two", "Returns 42", () -> 42); + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("42", result); + } + + // ── Group 10: Argument coercion + // ─────────────────────────────────────────────── + + @Test + void coercion_stringArgPassedThrough() throws Exception { + Param p = Param.of(String.class, "msg", "A message"); + ToolDefinition tool = ToolDefinition.from("echo", "Echoes message", p, m -> m); + Object result = tool.handler().invoke(invocationOf(Map.of("msg", "hello world"))).get(); + assertEquals("hello world", result); + } + + @Test + void coercion_integerArgFromJsonNumber() throws Exception { + Param p = Param.of(Integer.class, "n", "An integer"); + ToolDefinition tool = ToolDefinition.from("double_it", "Doubles n", p, n -> String.valueOf(n * 2)); + Object result = tool.handler().invoke(invocationOf(Map.of("n", 7))).get(); + assertEquals("14", result); + } + + @Test + void coercion_booleanArg() throws Exception { + Param p = Param.of(Boolean.class, "flag", "A flag"); + ToolDefinition tool = ToolDefinition.from("flagged", "Reports flag", p, f -> f ? "yes" : "no"); + Object result = tool.handler().invoke(invocationOf(Map.of("flag", true))).get(); + assertEquals("yes", result); + } + + @Test + void coercion_enumArgFromString() throws Exception { + Param p = Param.of(Color.class, "color", "A color"); + ToolDefinition tool = ToolDefinition.from("paint", "Paints", p, c -> c.name().toLowerCase()); + Object result = tool.handler().invoke(invocationOf(Map.of("color", "GREEN"))).get(); + assertEquals("green", result); + } + + @Test + void coercion_defaultIntegerParsedCorrectly() throws Exception { + Param p = Param.of(Integer.class, "limit", "Max count", false, "99"); + ToolDefinition tool = ToolDefinition.from("bounded", "Bounded list", p, lim -> "got=" + lim); + // No argument provided — should use default 99 + Object result = tool.handler().invoke(invocationOf(Map.of())).get(); + assertEquals("got=99", result); + } + + // ── Inner types for test helpers + // ────────────────────────────────────────────── + + enum Color { + RED, GREEN, BLUE + } +} diff --git a/java/src/test/java/com/github/copilot/tool/ParamTest.java b/java/src/test/java/com/github/copilot/tool/ParamTest.java new file mode 100644 index 000000000..75f6e4422 --- /dev/null +++ b/java/src/test/java/com/github/copilot/tool/ParamTest.java @@ -0,0 +1,262 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link Param} runtime parameter metadata. + */ +public class ParamTest { + + // ------------------------------------------------------------------ + // Factory method: of(type, name, description) + // ------------------------------------------------------------------ + + @Test + void ofCreatesRequiredParamWithNoDefault() { + Param p = Param.of(String.class, "query", "Search query"); + assertEquals(String.class, p.type()); + assertEquals("query", p.name()); + assertEquals("Search query", p.description()); + assertTrue(p.required()); + assertEquals("", p.defaultValue()); + assertFalse(p.hasDefaultValue()); + } + + @Test + void ofFullFactoryCreatesOptionalParamWithDefault() { + Param p = Param.of(Integer.class, "limit", "Max results", false, "10"); + assertEquals(Integer.class, p.type()); + assertEquals("limit", p.name()); + assertEquals("Max results", p.description()); + assertFalse(p.required()); + assertEquals("10", p.defaultValue()); + assertTrue(p.hasDefaultValue()); + } + + // ------------------------------------------------------------------ + // Validation: blank name/description rejected + // ------------------------------------------------------------------ + + @Test + void rejectsNullName() { + var ex = assertThrows(IllegalArgumentException.class, () -> Param.of(String.class, null, "desc")); + assertTrue(ex.getMessage().contains("name")); + } + + @Test + void rejectsBlankName() { + var ex = assertThrows(IllegalArgumentException.class, () -> Param.of(String.class, " ", "desc")); + assertTrue(ex.getMessage().contains("name")); + } + + @Test + void rejectsNullDescription() { + var ex = assertThrows(IllegalArgumentException.class, () -> Param.of(String.class, "n", null)); + assertTrue(ex.getMessage().contains("description")); + } + + @Test + void rejectsBlankDescription() { + var ex = assertThrows(IllegalArgumentException.class, () -> Param.of(String.class, "n", "")); + assertTrue(ex.getMessage().contains("description")); + } + + // ------------------------------------------------------------------ + // Validation: required=true with non-empty default rejected + // ------------------------------------------------------------------ + + @Test + void rejectsRequiredWithNonEmptyDefault() { + var ex = assertThrows(IllegalArgumentException.class, () -> Param.of(String.class, "x", "desc", true, "val")); + assertTrue(ex.getMessage().contains("required=true")); + } + + @Test + void allowsRequiredWithEmptyDefault() { + Param p = Param.of(String.class, "x", "desc", true, ""); + assertTrue(p.required()); + assertFalse(p.hasDefaultValue()); + } + + @Test + void allowsRequiredWithNullDefault() { + Param p = Param.of(String.class, "x", "desc", true, null); + assertTrue(p.required()); + assertEquals("", p.defaultValue()); + } + + // ------------------------------------------------------------------ + // Validation: default value type checking + // ------------------------------------------------------------------ + + @Test + void validatesIntegerDefault() { + // valid + Param p = Param.of(Integer.class, "n", "num", false, "42"); + assertEquals("42", p.defaultValue()); + + // invalid + assertThrows(IllegalArgumentException.class, () -> Param.of(Integer.class, "n", "num", false, "abc")); + } + + @Test + void validatesLongDefault() { + Param p = Param.of(Long.class, "n", "num", false, "999999999999"); + assertEquals("999999999999", p.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(Long.class, "n", "num", false, "notlong")); + } + + @Test + void validatesDoubleDefault() { + Param p = Param.of(Double.class, "d", "decimal", false, "3.14"); + assertEquals("3.14", p.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(Double.class, "d", "decimal", false, "xyz")); + } + + @Test + void validatesFloatDefault() { + Param p = Param.of(Float.class, "f", "float val", false, "1.5"); + assertEquals("1.5", p.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(Float.class, "f", "float val", false, "notfloat")); + } + + @Test + void validatesShortDefault() { + Param p = Param.of(Short.class, "s", "short val", false, "100"); + assertEquals("100", p.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(Short.class, "s", "short val", false, "99999")); + } + + @Test + void validatesByteDefault() { + Param p = Param.of(Byte.class, "b", "byte val", false, "127"); + assertEquals("127", p.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(Byte.class, "b", "byte val", false, "999")); + } + + @Test + void validatesBooleanDefault() { + Param p1 = Param.of(Boolean.class, "b", "flag", false, "true"); + assertEquals("true", p1.defaultValue()); + + Param p2 = Param.of(Boolean.class, "b", "flag", false, "FALSE"); + assertEquals("FALSE", p2.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(Boolean.class, "b", "flag", false, "yes")); + } + + @Test + void validatesEnumDefault() { + Param p = Param.of(TestEnum.class, "e", "enum val", false, "ALPHA"); + assertEquals("ALPHA", p.defaultValue()); + + assertThrows(IllegalArgumentException.class, () -> Param.of(TestEnum.class, "e", "enum val", false, "INVALID")); + } + + @Test + void rejectsUnsupportedTypeWithDefault() { + assertThrows(IllegalArgumentException.class, () -> Param.of(Object.class, "o", "object", false, "something")); + } + + @Test + void allowsStringDefault() { + Param p = Param.of(String.class, "s", "string", false, "hello"); + assertEquals("hello", p.defaultValue()); + } + + // ------------------------------------------------------------------ + // Fluent mutators return new instances + // ------------------------------------------------------------------ + + @Test + void nameMutatorReturnsNewInstance() { + Param original = Param.of(String.class, "a", "desc"); + Param renamed = original.name("b"); + assertEquals("a", original.name()); + assertEquals("b", renamed.name()); + } + + @Test + void descriptionMutatorReturnsNewInstance() { + Param original = Param.of(String.class, "a", "desc1"); + Param updated = original.description("desc2"); + assertEquals("desc1", original.description()); + assertEquals("desc2", updated.description()); + } + + @Test + void requiredMutatorReturnsNewInstance() { + Param original = Param.of(String.class, "a", "desc"); + Param optional = original.required(false); + assertTrue(original.required()); + assertFalse(optional.required()); + } + + @Test + void defaultValueMutatorSetsOptional() { + Param original = Param.of(String.class, "a", "desc"); + Param withDefault = original.defaultValue("val"); + assertTrue(original.required()); + assertFalse(withDefault.required()); + assertEquals("val", withDefault.defaultValue()); + assertTrue(withDefault.hasDefaultValue()); + } + + // ------------------------------------------------------------------ + // equals / hashCode / toString + // ------------------------------------------------------------------ + + @Test + void equalParamsAreEqual() { + Param a = Param.of(String.class, "x", "desc"); + Param b = Param.of(String.class, "x", "desc"); + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + } + + @Test + void differentParamsAreNotEqual() { + Param a = Param.of(String.class, "x", "desc"); + Param b = Param.of(String.class, "y", "desc"); + assertNotEquals(a, b); + } + + @Test + void toStringContainsName() { + Param p = Param.of(String.class, "query", "Search"); + assertTrue(p.toString().contains("query")); + assertTrue(p.toString().contains("String")); + } + + // ------------------------------------------------------------------ + // Null type rejected + // ------------------------------------------------------------------ + + @Test + void rejectsNullType() { + assertThrows(NullPointerException.class, () -> Param.of(null, "n", "desc")); + } + + // ------------------------------------------------------------------ + // Test enum for validation tests + // ------------------------------------------------------------------ + + enum TestEnum { + ALPHA, BETA + } +}