Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1802,7 +1802,17 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio
_ => null,
};
var connectResponse = await InvokeRpcAsync<ConnectResult>(
connection.Rpc, "connect", [new ConnectRequest { Token = token }], connection.StderrBuffer, cancellationToken);
connection.Rpc,
"connect",
[new ConnectHandshakeRequest(
token,
// Opt in to GitHub telemetry forwarding at the connection level when a
// handler is registered (mirrors the runtime, which reads this flag on the
// `connect` handshake so the first session's un-replayable `session.start`
// event is forwarded). Also sent on session.create/resume for older CLIs.
_options.OnGitHubTelemetry != null ? true : null)],
connection.StderrBuffer,
cancellationToken);
serverVersion = (int)connectResponse.ProtocolVersion;
}
catch (IOException ex) when (ex.InnerException is RemoteRpcException remoteEx && IsUnsupportedConnectMethod(remoteEx))
Expand Down Expand Up @@ -2639,6 +2649,10 @@ internal record GetSessionMetadataRequest(
internal record GetSessionMetadataResponse(
SessionMetadata? Session);

internal record ConnectHandshakeRequest(
string? Token,
[property: JsonPropertyName("enableGitHubTelemetryForwarding")] bool? EnableGitHubTelemetryForwarding = null);

internal record SetForegroundSessionRequest(
string SessionId);

Expand Down Expand Up @@ -2673,6 +2687,7 @@ internal record HooksInvokeResponse(
[JsonSerializable(typeof(ListSessionsResponse))]
[JsonSerializable(typeof(GetSessionMetadataRequest))]
[JsonSerializable(typeof(GetSessionMetadataResponse))]
[JsonSerializable(typeof(ConnectHandshakeRequest))]
[JsonSerializable(typeof(McpOAuthTokenStorageMode))]
[JsonSerializable(typeof(EmbeddingCacheStorageMode))]
[JsonSerializable(typeof(ModelCapabilitiesOverride))]
Expand Down
53 changes: 47 additions & 6 deletions dotnet/test/Unit/GitHubTelemetryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,39 @@ public async Task ResumeSession_Opts_Into_Forwarding_When_Handler_Provided()
Assert.True(flag.GetBoolean());
}

[Fact]
public async Task Connect_Opts_Into_Forwarding_When_Handler_Provided()
{
await using var server = await FakeTelemetryServer.StartAsync();
await using var client = new CopilotClient(new CopilotClientOptions
{
Connection = RuntimeConnection.ForUri(server.Url),
OnGitHubTelemetry = _ => Task.CompletedTask,
});
await client.StartAsync();

var connectParams = server.LastConnectParams ?? throw new InvalidOperationException("connect was not captured.");
Assert.True(connectParams.TryGetProperty("enableGitHubTelemetryForwarding", out var flag));
Assert.True(flag.GetBoolean());
}

[Fact]
public async Task Connect_Does_Not_Opt_In_Without_Handler()
{
await using var server = await FakeTelemetryServer.StartAsync();
await using var client = new CopilotClient(new CopilotClientOptions
{
Connection = RuntimeConnection.ForUri(server.Url),
});
await client.StartAsync();

var connectParams = server.LastConnectParams ?? throw new InvalidOperationException("connect was not captured.");
var present = connectParams.TryGetProperty("enableGitHubTelemetryForwarding", out var flag);
Assert.True(
!present || flag.ValueKind == JsonValueKind.Null,
"connect request should omit enableGitHubTelemetryForwarding (or send null) when no handler is registered");
}

[Fact]
public async Task CreateSession_Does_Not_Opt_In_Without_Handler()
{
Expand Down Expand Up @@ -187,6 +220,8 @@ public string Url

public JsonElement? LastResumeParams { get; private set; }

public JsonElement? LastConnectParams { get; private set; }

public static Task<FakeTelemetryServer> StartAsync()
{
var listener = new TcpListener(IPAddress.Loopback, 0);
Expand Down Expand Up @@ -267,12 +302,7 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel

object? result = method switch
{
"connect" => new Dictionary<string, object?>
{
["ok"] = true,
["protocolVersion"] = 3,
["version"] = "test",
},
"connect" => CaptureConnect(request),
"session.create" => CaptureCreate(request),
"session.resume" => CaptureResume(request),
"session.send" => new Dictionary<string, object?> { ["messageId"] = "message-1" },
Expand All @@ -289,6 +319,17 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel
}, cancellationToken);
}

private Dictionary<string, object?> CaptureConnect(JsonElement request)
{
LastConnectParams = request.TryGetProperty("params", out var p) ? p.Clone() : null;
return new Dictionary<string, object?>
{
["ok"] = true,
["protocolVersion"] = 3,
["version"] = "test",
};
}

private Dictionary<string, object?> CaptureCreate(JsonElement request)
{
LastCreateParams = request.TryGetProperty("params", out var p) ? p.Clone() : null;
Expand Down
19 changes: 18 additions & 1 deletion go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1685,7 +1685,15 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error {
t := c.effectiveConnectionToken
tokenPtr = &t
}
connectResult, err := c.internalRPC.Connect(ctx, &rpc.ConnectRequest{Token: tokenPtr})
connectReq := &connectHandshakeRequest{Token: tokenPtr}
// Opt in to GitHub telemetry forwarding at the connection level when a handler is
// registered (mirrors the runtime, which reads this flag on the `connect` handshake
// so the first session's un-replayable `session.start` event is forwarded). Also
// sent on session.create/resume for older CLIs.
if c.options.OnGitHubTelemetry != nil {
connectReq.EnableGitHubTelemetryForwarding = Bool(true)
}
rawConnectResult, err := c.client.Request(ctx, "connect", connectReq)
if err != nil {
var rpcErr *jsonrpc2.Error
if errors.As(err, &rpcErr) && (rpcErr.Code == jsonrpc2.ErrMethodNotFound.Code || rpcErr.Message == "Unhandled method connect") {
Expand All @@ -1700,6 +1708,10 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error {
return err
}
} else {
var connectResult rpc.ConnectResult
if err := json.Unmarshal(rawConnectResult, &connectResult); err != nil {
return err
}
v := int(connectResult.ProtocolVersion)
serverVersion = &v
}
Expand All @@ -1716,6 +1728,11 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error {
return nil
}

type connectHandshakeRequest struct {
Token *string `json:"token,omitempty"`
EnableGitHubTelemetryForwarding *bool `json:"enableGitHubTelemetryForwarding,omitempty"`
}

// stderrBufferSize is the maximum number of bytes kept from the CLI process's
// stderr. Only the tail is retained so that memory stays bounded even when the
// process produces a large amount of diagnostic output.
Expand Down
46 changes: 46 additions & 0 deletions go/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,52 @@ func assertForwardingFlagAbsent(t *testing.T, params json.RawMessage) {
}
}

func TestClient_ForwardsGitHubTelemetryForwardingOnConnect(t *testing.T) {
rpcClient, server, _ := newRuntimeShutdownRpcPair(t)
t.Cleanup(server.Stop)
client := &Client{
client: rpcClient,
RPC: rpc.NewServerRPC(rpcClient),
internalRPC: rpc.NewInternalServerRPC(rpcClient),
sessions: make(map[string]*Session),
options: ClientOptions{OnGitHubTelemetry: func(*rpc.GitHubTelemetryNotification) {}},
}

connectParams := make(chan json.RawMessage, 1)
server.SetRequestHandler("connect", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {
connectParams <- append(json.RawMessage(nil), params...)
return []byte(`{"ok":true,"protocolVersion":3,"version":"test"}`), nil
})

if err := client.verifyProtocolVersion(t.Context()); err != nil {
t.Fatalf("verifyProtocolVersion failed: %v", err)
}
assertForwardingFlagTrue(t, <-connectParams)
}

func TestClient_OmitsGitHubTelemetryForwardingOnConnectWhenNoHandler(t *testing.T) {
rpcClient, server, _ := newRuntimeShutdownRpcPair(t)
t.Cleanup(server.Stop)
client := &Client{
client: rpcClient,
RPC: rpc.NewServerRPC(rpcClient),
internalRPC: rpc.NewInternalServerRPC(rpcClient),
sessions: make(map[string]*Session),
options: ClientOptions{},
}

connectParams := make(chan json.RawMessage, 1)
server.SetRequestHandler("connect", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {
connectParams <- append(json.RawMessage(nil), params...)
return []byte(`{"ok":true,"protocolVersion":3,"version":"test"}`), nil
})

if err := client.verifyProtocolVersion(t.Context()); err != nil {
t.Fatalf("verifyProtocolVersion failed: %v", err)
}
assertForwardingFlagAbsent(t, <-connectParams)
}

func TestGitHubTelemetryNotificationRoutesToCallback(t *testing.T) {
// The runtime forwards telemetry via a JSON-RPC *notification* (no id).
// Drive a real Content-Length-framed notification through the transport and
Expand Down
21 changes: 15 additions & 6 deletions java/src/main/java/com/github/copilot/CopilotClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import com.github.copilot.rpc.CreateSessionResponse;
import com.github.copilot.generated.rpc.SessionOptionsUpdateParams;
import com.github.copilot.generated.rpc.SessionInstalledPlugin;
import com.github.copilot.generated.rpc.ConnectParams;
import com.github.copilot.generated.rpc.ConnectResult;
import com.github.copilot.generated.rpc.GitHubTelemetryNotification;
import com.github.copilot.generated.rpc.ServerRpc;
import com.github.copilot.generated.rpc.SessionEventLogRegisterInterestParams;
Expand Down Expand Up @@ -306,11 +306,20 @@ private void verifyProtocolVersion(Connection connection) throws Exception {
Integer serverVersion;

try {
// Try the new 'connect' RPC which supports connection tokens
var connectParams = new ConnectParams(effectiveConnectionToken);
var connectResponse = connection.rpc
.invoke("connect", connectParams, com.github.copilot.generated.rpc.ConnectResult.class)
.get(30, TimeUnit.SECONDS);
// Try the new 'connect' RPC which supports connection tokens.
var connectParams = new HashMap<String, Object>();
if (effectiveConnectionToken != null) {
connectParams.put("token", effectiveConnectionToken);
}
// Opt into GitHub telemetry forwarding at the connection level when a handler
// is registered, so the runtime can forward the first session's un-replayable
// start event. Also sent on session create/resume for backward compatibility
// with servers that read the flag there instead.
if (this.options.getOnGitHubTelemetry() != null) {
connectParams.put("enableGitHubTelemetryForwarding", true);
}
var connectResponse = connection.rpc.invoke("connect", connectParams, ConnectResult.class).get(30,
TimeUnit.SECONDS);
serverVersion = connectResponse.protocolVersion() != null
? connectResponse.protocolVersion().intValue()
: null;
Expand Down
27 changes: 23 additions & 4 deletions java/src/test/java/com/github/copilot/GitHubTelemetryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
/**
* Exercises the hand-written GitHub telemetry forwarding surface: the
* {@code gitHubTelemetry.event} notification adapter, the
* {@code enableGitHubTelemetryForwarding} capability flag on the create/resume
* requests, and the {@code onGitHubTelemetry} client option.
* {@code enableGitHubTelemetryForwarding} capability flag on the connect
* handshake and the create/resume requests, and the {@code onGitHubTelemetry}
* client option.
*/
@AllowCopilotExperimental
class GitHubTelemetryTest {
Expand Down Expand Up @@ -146,6 +147,12 @@ void clientOptsSessionsIntoForwardingAndReceivesEvents() throws Exception {

client.start().get(15, TimeUnit.SECONDS);

// Connecting must opt into telemetry forwarding at the connection level so
// the runtime can forward the first session's un-replayable start event.
JsonNode connectParams = server.awaitConnect();
assertTrue(connectParams.path("enableGitHubTelemetryForwarding").asBoolean(),
"connect request should carry enableGitHubTelemetryForwarding=true");

// Creating a session must opt it into telemetry forwarding.
client.createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(15,
TimeUnit.SECONDS);
Expand Down Expand Up @@ -178,6 +185,10 @@ void clientOmitsForwardingWhenNoHandler() throws Exception {

client.start().get(15, TimeUnit.SECONDS);

JsonNode connectParams = server.awaitConnect();
assertFalse(connectParams.has("enableGitHubTelemetryForwarding"),
"connect request should omit the flag when no handler is registered");

client.createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(15,
TimeUnit.SECONDS);
JsonNode createParams = server.awaitCreate();
Expand Down Expand Up @@ -214,6 +225,7 @@ private static final class FakeRuntimeServer implements AutoCloseable {
private final ServerSocket serverSocket;
private final Thread acceptThread;
private final CompletableFuture<JsonRpcClient> ready = new CompletableFuture<>();
private final CompletableFuture<JsonNode> connectParams = new CompletableFuture<>();
private final CompletableFuture<JsonNode> createParams = new CompletableFuture<>();
private final CompletableFuture<JsonNode> resumeParams = new CompletableFuture<>();

Expand All @@ -228,6 +240,10 @@ String url() {
return "127.0.0.1:" + serverSocket.getLocalPort();
}

JsonNode awaitConnect() throws Exception {
return connectParams.get(15, TimeUnit.SECONDS);
}

JsonNode awaitCreate() throws Exception {
return createParams.get(15, TimeUnit.SECONDS);
}
Expand All @@ -244,8 +260,10 @@ private void acceptLoop() {
try {
Socket socket = serverSocket.accept();
JsonRpcClient server = JsonRpcClient.fromSocket(socket);
server.registerMethodHandler("connect",
(id, params) -> respond(server, id, Map.of("protocolVersion", 2)));
server.registerMethodHandler("connect", (id, params) -> {
connectParams.complete(params);
respond(server, id, Map.of("protocolVersion", 2));
});
server.registerMethodHandler("session.create", (id, params) -> {
createParams.complete(params);
respond(server, id, Map.of("sessionId", params.path("sessionId").asText("created"), "workspacePath",
Expand All @@ -261,6 +279,7 @@ private void acceptLoop() {
ready.complete(server);
} catch (IOException e) {
ready.completeExceptionally(e);
connectParams.completeExceptionally(e);
createParams.completeExceptionally(e);
resumeParams.completeExceptionally(e);
}
Expand Down
15 changes: 12 additions & 3 deletions nodejs/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1846,9 +1846,18 @@ export class CopilotClient {

let serverVersion: number | undefined;
try {
const result = await raceAgainstExit(
this.internalRpc.connect({ token: this.effectiveConnectionToken })
);
const connectParams: {
token?: string;
enableGitHubTelemetryForwarding?: boolean;
} = { token: this.effectiveConnectionToken };
// Opt in to GitHub telemetry forwarding at the connection level when a
// handler is registered (mirrors the runtime, which reads this flag on the
// `connect` handshake so the first session's un-replayable `session.start`
// event is forwarded). Also sent on session.create/resume for older CLIs.
if (this.onGitHubTelemetry != null) {
connectParams.enableGitHubTelemetryForwarding = true;
}
const result = await raceAgainstExit(this.internalRpc.connect(connectParams));
serverVersion = result.protocolVersion;
} catch (err) {
if (
Expand Down
34 changes: 34 additions & 0 deletions nodejs/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,40 @@ describe("CopilotClient", () => {
expect(resumePayload.enableGitHubTelemetryForwarding).toBe(true);
});

it("opts into GitHub telemetry forwarding on the connect handshake when a handler is provided", async () => {
const client = new CopilotClient({ onGitHubTelemetry: () => {} });
onTestFinished(() => client.forceStop());

const sendRequest = vi.fn(async (method: string) => {
if (method === "connect") return { ok: true, protocolVersion: 3, version: "test" };
throw new Error(`Unexpected method: ${method}`);
});
(client as any).connection = { sendRequest };

await (client as any).verifyProtocolVersion();

const connectCall = sendRequest.mock.calls.find(([method]) => method === "connect");
expect(connectCall).toBeDefined();
expect((connectCall![1] as any).enableGitHubTelemetryForwarding).toBe(true);
});

it("does not opt into GitHub telemetry forwarding on the connect handshake without a handler", async () => {
const client = new CopilotClient();
onTestFinished(() => client.forceStop());

const sendRequest = vi.fn(async (method: string) => {
if (method === "connect") return { ok: true, protocolVersion: 3, version: "test" };
throw new Error(`Unexpected method: ${method}`);
});
(client as any).connection = { sendRequest };

await (client as any).verifyProtocolVersion();

const connectCall = sendRequest.mock.calls.find(([method]) => method === "connect");
expect(connectCall).toBeDefined();
expect((connectCall![1] as any).enableGitHubTelemetryForwarding).toBeUndefined();
});

it("does not opt into GitHub telemetry forwarding without a handler", async () => {
const client = new CopilotClient();
await client.start();
Expand Down
Loading
Loading