diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..f0a4f54 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,14 @@ +.git +.github +.env +.env.* +!.env.example + +/gateway +/bin +/dist +/coverage.out +/backups +/models + +.DS_Store diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ccf3586 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + name: Build and test gateway + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.26.4" + + - name: Download dependencies + run: go mod download + + - name: Secret scan + run: go run github.com/zricethezav/gitleaks/v8@v8.30.1 detect --no-git --source . --redact --verbose + + - name: Build gateway + run: go build ./apps/gateway + + - name: Run tests + run: go test ./... + + docker: + name: Build Docker image + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Build gateway image + run: docker build -f apps/gateway/Dockerfile -t dappnode-nexus-gateway:ci . diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bfcc008 --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +.env +.env.* +!.env.example + +# Build outputs +/gateway +/bin/ +/dist/ +/coverage.out + +# Local data and generated artifacts +/backups/ +/models/ + +# Go/tool caches +/.cache/ + +# OS/editor noise +.DS_Store +.vscode/* +!.vscode/extensions.json diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..64c817c --- /dev/null +++ b/LICENSE @@ -0,0 +1,157 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. 
+ +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, +made available under the License, as indicated by a copyright notice that is +included in or attached to the work. + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. 
+ +"Contribution" shall mean any work of authorship, including the original +version of the Work and any modifications or additions to that Work or +Derivative Works thereof, that is intentionally submitted to Licensor for +inclusion in the Work by the copyright owner or by an individual or Legal Entity +authorized to submit on behalf of the copyright owner. For the purposes of this +definition, "submitted" means any form of electronic, verbal, or written +communication sent to the Licensor or its representatives, including but not +limited to communication on electronic mailing lists, source code control +systems, and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but excluding +communication that is conspicuously marked or otherwise designated in writing by +the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this +License, each Contributor hereby grants to You a perpetual, worldwide, +non-exclusive, no-charge, royalty-free, irrevocable copyright license to +reproduce, prepare Derivative Works of, publicly display, publicly perform, +sublicense, and distribute the Work and such Derivative Works in Source or +Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of this License, +each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) patent +license to make, have made, use, offer to sell, sell, import, and otherwise +transfer the Work, where such license applies only to those patent claims +licensable by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) with the Work +to which such Contribution(s) was submitted. If You institute patent litigation +against any entity (including a cross-claim or counterclaim in a lawsuit) +alleging that the Work or a Contribution incorporated within the Work +constitutes direct or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate as of the date +such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or +Derivative Works thereof in any medium, with or without modifications, and in +Source or Object form, provided that You meet the following conditions: + +(a) You must give any other recipients of the Work or Derivative Works a copy of +this License; and + +(b) You must cause any modified files to carry prominent notices stating that +You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works that You +distribute, all copyright, patent, trademark, and attribution notices from the +Source form of the Work, excluding those notices that do not pertain to any part +of the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its distribution, then +any Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE 
text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. + +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any +Contribution intentionally submitted for inclusion in the Work by You to the +Licensor shall be under the terms and conditions of this License, without any +additional terms or conditions. Notwithstanding the above, nothing herein shall +supersede or modify the terms of any separate license agreement you may have +executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, +trademarks, service marks, or product names of the Licensor, except as required +for reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in +writing, Licensor provides the Work (and each Contributor provides its +Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied, including, without limitation, any warranties or +conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any risks +associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in +tort (including negligence), contract, or otherwise, unless required by +applicable law (such as deliberate and grossly negligent acts) or agreed to in +writing, shall any Contributor be liable to You for damages, including any +direct, indirect, special, incidental, or consequential damages of any character +arising as a result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, work stoppage, +computer failure or malfunction, or any and all other commercial damages or +losses), even if such Contributor has been advised of the possibility of such +damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or +Derivative Works thereof, You may choose to offer, and charge a fee for, +acceptance of support, warranty, indemnity, or other liability obligations +and/or rights consistent with this License. However, in accepting such +obligations, You may act only on Your own behalf and on Your sole +responsibility, not on behalf of any other Contributor, and only if You agree to +indemnify, defend, and hold each Contributor harmless for any liability incurred +by, or claims asserted against, such Contributor by reason of your accepting any +such warranty or additional liability. 
diff --git a/README.md b/README.md index 8641584..90d7930 100644 --- a/README.md +++ b/README.md @@ -1 +1,118 @@ -# dappnode-nexus-gateway \ No newline at end of file +# Dappnode Nexus Gateway + +Dappnode Nexus Gateway is a Go HTTP service that exposes an OpenAI-compatible +API for model inference. + +It authenticates Nexus API keys, resolves public model configuration from +Postgres, forwards requests to configured upstream providers, records usage, and +stores transport proof metadata for supported providers. + +## Features + +- OpenAI-compatible `/v1/models` and `/v1/chat/completions` endpoints. +- API-key authentication backed by Postgres. +- Provider selection and model configuration from database state. +- Support for OpenAI-compatible providers, Anthropic, and Tinfoil. +- Streaming and non-streaming chat completions. +- Usage accounting for successful and failed requests. +- Optional PII masking through Microsoft Presidio. +- Prometheus metrics endpoint. +- Tinfoil transport proof lookup endpoint. + +## Requirements + +- Go 1.26.4 or newer. +- Postgres with the gateway tables, API keys, model catalog, provider + configuration, pricing, routing, usage, and proof tables already present. +- Provider API keys for the upstreams you enable. +- A routing endpoint reachable through `ROUTER_URL` when router-backed models + are enabled. +- Presidio Analyzer if PII filtering is enabled. + +## Quick Start + +```sh +go test ./... +go run ./apps/gateway +``` + +By default the gateway listens on `:8080`, exposes metrics on `:9090`, and +connects to: + +```text +postgres://nexus:nexus@localhost:5432/nexus?sslmode=disable +``` + +Build the container image with: + +```sh +docker build -f apps/gateway/Dockerfile -t dappnode-nexus-gateway:local . +``` + +## Configuration + +| Variable | Default | Description | +| --- | --- | --- | +| `PORT` | `8080` | HTTP API port. | +| `METRICS_PORT` | `9090` | Prometheus metrics port. 
| +| `DATABASE_URL` | `postgres://nexus:nexus@localhost:5432/nexus?sslmode=disable` | Postgres connection string. | +| `ROUTER_URL` | `http://localhost:8083` | Routing service base URL. | +| `LOG_LEVEL` | `info` | Logger level. | +| `EUR_TO_USD_FALLBACK_RATE` | `1.08` | Fallback exchange rate for model pricing display. | +| `OPENAI_API_KEY` | empty | OpenAI provider credential. | +| `ANTHROPIC_API_KEY` | empty | Anthropic provider credential. | +| `NOVITA_API_KEY` | empty | Novita provider credential. | +| `PHALA_API_KEY` | empty | Phala provider credential. | +| `MINIMAX_API_KEY` | empty | MiniMax provider credential. | +| `DEEPSEEK_API_KEY` | empty | DeepSeek provider credential. | +| `TINFOIL_API_KEY` | empty | Tinfoil provider credential. | +| `TINFOIL_PROXY_BASE_URL` | empty | Optional Tinfoil proxy base URL. | +| `PII_FILTER_ENABLED` | `true` | Enables Presidio-backed PII masking for keys configured to use it. | +| `PRESIDIO_ANALYZER_URL` | `http://presidio-analyzer:3000` | Presidio Analyzer base URL. | +| `PII_FILTER_LANGUAGE` | `en` | Language passed to the PII analyzer. | +| `PII_FILTER_SCORE_THRESHOLD` | `0.4` | Minimum analyzer score for masking. | +| `PII_FILTER_TIMEOUT_MS` | `1500` | PII analyzer timeout in milliseconds. | +| `PII_FILTER_FAIL_OPEN` | `false` | Allows requests to continue if PII masking is unavailable. | + +Provider records in Postgres point to the environment variable that contains +the provider credential. This lets model/provider configuration change without +putting secrets in the database. + +## API + +### `GET /healthz` + +Returns service health. + +### `GET /v1/models` + +Returns the public model catalog in an OpenAI-compatible shape. + +### `POST /v1/chat/completions` + +Runs a chat completion request. Both streaming and non-streaming responses are +supported. + +### `GET /v1/tinfoil/proofs/{response_id}` + +Returns stored Tinfoil proof metadata for a gateway response when available. + +## Development + +```sh +go test ./... 
+go build ./apps/gateway +``` + +Run a local image build before publishing container changes: + +```sh +docker build -f apps/gateway/Dockerfile -t dappnode-nexus-gateway:local . +``` + +Do not commit real `.env` files, provider API keys, database credentials, or +generated runtime configuration. + +## License + +Apache-2.0. See [LICENSE](LICENSE). diff --git a/apps/gateway/Dockerfile b/apps/gateway/Dockerfile new file mode 100644 index 0000000..c88d2de --- /dev/null +++ b/apps/gateway/Dockerfile @@ -0,0 +1,23 @@ +# ---- Build stage ---- +FROM golang:1.26.4-alpine AS builder + +WORKDIR /app + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /gateway ./apps/gateway + +# ---- Runtime stage ---- +FROM alpine:3.20 + +RUN apk --no-cache add ca-certificates tzdata + +WORKDIR /app + +COPY --from=builder /gateway . + +EXPOSE 8080 9090 + +ENTRYPOINT ["/app/gateway"] diff --git a/apps/gateway/internal/adapters/auth/apikeys/service.go b/apps/gateway/internal/adapters/auth/apikeys/service.go new file mode 100644 index 0000000..3e6fd48 --- /dev/null +++ b/apps/gateway/internal/adapters/auth/apikeys/service.go @@ -0,0 +1,64 @@ +package apikeys + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/storage/postgres" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// Service implements the AuthService port using API key hash lookup. 
+type Service struct { + repo *postgres.APIKeyRepo + logger ports.Logger +} + +func NewService(repo *postgres.APIKeyRepo, logger ports.Logger) *Service { + return &Service{repo: repo, logger: logger} +} + +func (s *Service) AuthenticateAPIKey(ctx context.Context, rawKey string) (domain.AuthContext, error) { + if rawKey == "" { + return domain.AuthContext{}, domain.ErrInvalidAPIKey("missing API key") + } + + key, account, err := s.repo.FindByHash(ctx, rawKey) + if err != nil { + s.logger.Warn("auth lookup failed", + "error", err.Error(), + "key_prefix", safePrefix(rawKey), + ) + return domain.AuthContext{}, domain.ErrInvalidAPIKey("invalid API key") + } + + if !key.Active { + s.logger.Warn("inactive API key used", + "key_id", key.ID, + "account_id", key.AccountID, + "key_prefix", key.KeyPrefix, + ) + return domain.AuthContext{}, domain.ErrInactiveAPIKey() + } + + if !account.IsActive() { + s.logger.Warn("inactive account attempted access", + "account_id", account.ID, + "key_id", key.ID, + ) + return domain.AuthContext{}, domain.ErrInactiveAccount() + } + + return domain.AuthContext{ + Account: account, + APIKey: key, + }, nil +} + +// safePrefix returns a short key prefix for log correlation: at most 8 bytes and +// never more than half of the key, so short keys are not logged almost in full. +func safePrefix(key string) string { + // len(key)/2 is 0 for empty or single-byte keys, which therefore log as "...". + n := min(8, len(key)/2) + return key[:n] + "..." +} diff --git a/apps/gateway/internal/adapters/billing/reserver.go b/apps/gateway/internal/adapters/billing/reserver.go new file mode 100644 index 0000000..73b8685 --- /dev/null +++ b/apps/gateway/internal/adapters/billing/reserver.go @@ -0,0 +1,65 @@ +package billing + +import ( + "sync" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/google/uuid" +) + +type reservation struct { + accountID string + amountMicrocents int64 +} + +// InMemoryReserver tracks in-flight cost reservations using an in-memory map. +// It is safe for concurrent use. 
+type InMemoryReserver struct { + mu sync.Mutex + reservations map[string]reservation + accountTotals map[string]int64 +} + +func NewInMemoryReserver() *InMemoryReserver { + return &InMemoryReserver{ + reservations: make(map[string]reservation), + accountTotals: make(map[string]int64), + } +} + +// TryReserve holds maxCostMicrocents for accountID if it fits in availableBalanceMicrocents minus the +// account's in-flight reservations, returning the reservation ID. Negative costs are always rejected. +func (r *InMemoryReserver) TryReserve(accountID string, availableBalanceMicrocents int64, maxCostMicrocents int64) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + // Compare by subtraction: currentReserved+maxCostMicrocents could overflow int64 and wrap negative. + currentReserved := r.accountTotals[accountID] + if maxCostMicrocents < 0 || maxCostMicrocents > availableBalanceMicrocents-currentReserved { + return "", domain.ErrInsufficientBalance() + } + + id := uuid.New().String() + r.reservations[id] = reservation{ + accountID: accountID, + amountMicrocents: maxCostMicrocents, + } + r.accountTotals[accountID] += maxCostMicrocents + return id, nil +} + +// Release removes a reservation, freeing the held amount. Safe to call multiple times. 
+func (r *InMemoryReserver) Release(reservationID string) { + r.mu.Lock() + defer r.mu.Unlock() + + res, ok := r.reservations[reservationID] + if !ok { + return + } + + delete(r.reservations, reservationID) + r.accountTotals[res.accountID] -= res.amountMicrocents + if r.accountTotals[res.accountID] <= 0 { + delete(r.accountTotals, res.accountID) + } +} diff --git a/apps/gateway/internal/adapters/billing/reserver_test.go b/apps/gateway/internal/adapters/billing/reserver_test.go new file mode 100644 index 0000000..457c559 --- /dev/null +++ b/apps/gateway/internal/adapters/billing/reserver_test.go @@ -0,0 +1,165 @@ +package billing + +import ( + "sync" + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func TestInMemoryReserver_BasicReserveAndRelease(t *testing.T) { + r := NewInMemoryReserver() + + id, err := r.TryReserve("acc1", 1000, 400) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if id == "" { + t.Fatal("expected non-empty reservation ID") + } + + r.Release(id) + + // After release, the full balance should be available again. 
+ id2, err := r.TryReserve("acc1", 1000, 1000) + if err != nil { + t.Fatalf("unexpected error after release: %v", err) + } + r.Release(id2) +} + +func TestInMemoryReserver_ConcurrentReservationsExhaustBalance(t *testing.T) { + r := NewInMemoryReserver() + balance := int64(1000) + + // First reservation: 600 out of 1000 + id1, err := r.TryReserve("acc1", balance, 600) + if err != nil { + t.Fatalf("first reserve: %v", err) + } + + // Second reservation: 500 out of 1000 — should fail (600 + 500 > 1000) + _, err = r.TryReserve("acc1", balance, 500) + if err == nil { + t.Fatal("expected insufficient balance error") + } + gwErr, ok := err.(*domain.GatewayError) + if !ok { + t.Fatalf("expected GatewayError, got %T", err) + } + if gwErr.Code != domain.ErrCodeInsufficientBalance { + t.Fatalf("error code = %s, want %s", gwErr.Code, domain.ErrCodeInsufficientBalance) + } + + // Second reservation: 400 out of 1000 — should succeed (600 + 400 = 1000) + id2, err := r.TryReserve("acc1", balance, 400) + if err != nil { + t.Fatalf("second reserve (400): %v", err) + } + + r.Release(id1) + r.Release(id2) +} + +func TestInMemoryReserver_ReleaseFreesCapacity(t *testing.T) { + r := NewInMemoryReserver() + balance := int64(1000) + + id1, _ := r.TryReserve("acc1", balance, 900) + + // Can't fit another 200 + _, err := r.TryReserve("acc1", balance, 200) + if err == nil { + t.Fatal("expected insufficient balance error") + } + + // Release first reservation + r.Release(id1) + + // Now 200 should fit + id2, err := r.TryReserve("acc1", balance, 200) + if err != nil { + t.Fatalf("reserve after release: %v", err) + } + r.Release(id2) +} + +func TestInMemoryReserver_DoubleReleaseIsSafe(t *testing.T) { + r := NewInMemoryReserver() + + id, _ := r.TryReserve("acc1", 1000, 500) + r.Release(id) + r.Release(id) // should not panic or corrupt state + + // Verify state is clean — full balance available + id2, err := r.TryReserve("acc1", 1000, 1000) + if err != nil { + t.Fatalf("reserve after double 
release: %v", err) + } + r.Release(id2) +} + +func TestInMemoryReserver_IndependentAccounts(t *testing.T) { + r := NewInMemoryReserver() + + // Two accounts with different balances should not interfere + id1, err := r.TryReserve("acc1", 500, 500) + if err != nil { + t.Fatalf("acc1 reserve: %v", err) + } + + id2, err := r.TryReserve("acc2", 300, 300) + if err != nil { + t.Fatalf("acc2 reserve: %v", err) + } + + r.Release(id1) + r.Release(id2) +} + +func TestInMemoryReserver_ConcurrentGoroutines(t *testing.T) { + r := NewInMemoryReserver() + balance := int64(10_000_000) + costPerRequest := int64(100_000) + numGoroutines := 50 + + var wg sync.WaitGroup + successCh := make(chan string, numGoroutines) + failCh := make(chan struct{}, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + id, err := r.TryReserve("acc1", balance, costPerRequest) + if err != nil { + failCh <- struct{}{} + return + } + successCh <- id + }() + } + + wg.Wait() + close(successCh) + close(failCh) + + successes := 0 + var ids []string + for id := range successCh { + successes++ + ids = append(ids, id) + } + failures := len(failCh) + + // balance / costPerRequest = 100, and we only launched 50 goroutines, + // so all should succeed. + if successes != numGoroutines { + t.Errorf("successes = %d, want %d (failures = %d)", successes, numGoroutines, failures) + } + + // Clean up + for _, id := range ids { + r.Release(id) + } +} diff --git a/apps/gateway/internal/adapters/http/dto/chat_completion_request.go b/apps/gateway/internal/adapters/http/dto/chat_completion_request.go new file mode 100644 index 0000000..2f06762 --- /dev/null +++ b/apps/gateway/internal/adapters/http/dto/chat_completion_request.go @@ -0,0 +1,75 @@ +package dto + +import "encoding/json" + +// ChatCompletionRequest is the DTO for POST /v1/chat/completions. 
+type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Stream *bool `json:"stream,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + N *int `json:"n,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + ReasoningEffort *string `json:"reasoning_effort,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` + Tools []ToolDefinition `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + User *string `json:"user,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + ProviderOptions map[string]any `json:"provider_options,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + Seed *int `json:"seed,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + Store *bool `json:"store,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` +} + +// ChatMessage is a message in the chat completions format. +type ChatMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content,omitempty"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` + Name *string `json:"name,omitempty"` +} + +// ChatToolCall is a tool call in assistant messages. +type ChatToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function ChatFunctionCall `json:"function"` +} + +// ChatFunctionCall holds the function call details. 
+type ChatFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ResponseFormat for structured outputs. +type ResponseFormat struct { + Type string `json:"type"` + JSONSchema map[string]any `json:"json_schema,omitempty"` +} + +// ToolDefinition describes a tool in the chat completions format. +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +// ToolFunction holds the function details inside a ToolDefinition. +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} diff --git a/apps/gateway/internal/adapters/http/dto/chat_completion_response.go b/apps/gateway/internal/adapters/http/dto/chat_completion_response.go new file mode 100644 index 0000000..80183f0 --- /dev/null +++ b/apps/gateway/internal/adapters/http/dto/chat_completion_response.go @@ -0,0 +1,66 @@ +package dto + +// ChatCompletionResponse is the non-streaming response for POST /v1/chat/completions. +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage *ChatCompletionUsage `json:"usage,omitempty"` +} + +// ChatCompletionChoice is a single choice in the response. +type ChatCompletionChoice struct { + Index int `json:"index"` + Message ChatChoiceMessage `json:"message"` + FinishReason *string `json:"finish_reason"` +} + +// ChatChoiceMessage is the assistant message in a choice. +type ChatChoiceMessage struct { + Role string `json:"role"` + Content *string `json:"content"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` +} + +// ChatCompletionUsage holds token usage for chat completions. 
+type ChatCompletionUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// ChatCompletionChunk is a streaming chunk for chat completions. +type ChatCompletionChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChunkChoice `json:"choices"` + Usage *ChatCompletionUsage `json:"usage,omitempty"` +} + +// ChatCompletionChunkChoice is a streaming choice delta. +type ChatCompletionChunkChoice struct { + Index int `json:"index"` + Delta ChatChunkDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` +} + +// ChatChunkDelta holds the delta content in a streaming chunk. +type ChatChunkDelta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ChatToolCallChunk `json:"tool_calls,omitempty"` +} + +// ChatToolCallChunk is a partial tool call in a stream chunk. +type ChatToolCallChunk struct { + Index int `json:"index"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function ChatFunctionCall `json:"function"` +} diff --git a/apps/gateway/internal/adapters/http/dto/error_response.go b/apps/gateway/internal/adapters/http/dto/error_response.go new file mode 100644 index 0000000..689ad79 --- /dev/null +++ b/apps/gateway/internal/adapters/http/dto/error_response.go @@ -0,0 +1,13 @@ +package dto + +// ErrorResponse is the standard error envelope. +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +// ErrorDetail contains the error classification fields. 
+type ErrorDetail struct {
+	Type    string `json:"type"`
+	Code    string `json:"code"`
+	Message string `json:"message"`
+}
diff --git a/apps/gateway/internal/adapters/http/dto/models_response.go b/apps/gateway/internal/adapters/http/dto/models_response.go
new file mode 100644
index 0000000..2251fa2
--- /dev/null
+++ b/apps/gateway/internal/adapters/http/dto/models_response.go
@@ -0,0 +1,34 @@
+package dto
+
+// ModelsResponse is the response envelope for GET /v1/models.
+type ModelsResponse struct {
+	Object string      `json:"object"`
+	Data   []ModelData `json:"data"`
+}
+
+// ModelData is a single model entry.
+type ModelData struct {
+	ID                              string   `json:"id"`
+	Object                          string   `json:"object"`
+	Kind                            string   `json:"kind"`
+	BaseModel                       *string  `json:"base_model,omitempty"`
+	OwnedBy                         string   `json:"owned_by"`
+	Created                         int64    `json:"created"`
+	DisplayName                     string   `json:"display_name"`
+	Description                     *string  `json:"description,omitempty"`
+	ContextSize                     *int     `json:"context_size,omitempty"`
+	MaxOutputTokens                 *int     `json:"max_output_tokens,omitempty"`
+	Currency                        *string  `json:"currency,omitempty"`
+	InputPricePer1MTokensCents      *int64   `json:"input_price_per_1m_tokens_cents,omitempty"`
+	OutputPricePer1MTokensCents     *int64   `json:"output_price_per_1m_tokens_cents,omitempty"`
+	CacheReadPricePer1MTokensCents  *int64   `json:"cache_read_price_per_1m_tokens_cents,omitempty"`
+	CacheWritePricePer1MTokensCents *int64   `json:"cache_write_price_per_1m_tokens_cents,omitempty"`
+	InputPricePer1MTokensUSD        *float64 `json:"input_price_per_1m_tokens_usd,omitempty"`
+	OutputPricePer1MTokensUSD       *float64 `json:"output_price_per_1m_tokens_usd,omitempty"`
+	CacheReadPricePer1MTokensUSD    *float64 `json:"cache_read_price_per_1m_tokens_usd,omitempty"`
+	CacheWritePricePer1MTokensUSD   *float64 `json:"cache_write_price_per_1m_tokens_usd,omitempty"`
+	Features                        []string `json:"features"`
+	Endpoints                       []string `json:"endpoints"`
+	ProofMode                       string   `json:"proof_mode"`
+	ProofsEnabled                   bool     `json:"proofs_enabled"`
+}
diff --git a/apps/gateway/internal/adapters/http/handlers/chat_completions_handler.go b/apps/gateway/internal/adapters/http/handlers/chat_completions_handler.go
new file mode 100644
index 0000000..f6734d9
--- /dev/null
+++ b/apps/gateway/internal/adapters/http/handlers/chat_completions_handler.go
@@ -0,0 +1,178 @@
+package handlers
+
+import (
+	"io"
+	"net/http"
+	"sync"
+	"time"
+
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/mapper"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/middleware"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/sse"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/services"
+	"github.com/dappnode/dappnode-nexus-gateway/pkg/domain"
+	"github.com/google/uuid"
+)
+
+// ChatCompletionsHandler handles POST /v1/chat/completions.
+type ChatCompletionsHandler struct {
+	service *services.ChatCompletionsService
+	logger  ports.Logger
+}
+
+// NewChatCompletionsHandler returns a handler backed by service that logs
+// through logger.
+func NewChatCompletionsHandler(service *services.ChatCompletionsService, logger ports.Logger) *ChatCompletionsHandler {
+	return &ChatCompletionsHandler{service: service, logger: logger}
+}
+
+// Handle serves one chat completion request. It extracts the bearer token,
+// reads at most maxRequestBodySize bytes of the body, maps it to a domain
+// request and then either streams the result (when genReq.Stream is set) or
+// writes a single JSON response. Unknown request fields are logged at warn
+// level before the body is mapped.
+func (h *ChatCompletionsHandler) Handle(w http.ResponseWriter, r *http.Request) {
+	token, err := ExtractBearerToken(r)
+	if err != nil {
+		WriteErrorWithLog(w, r, h.logger, err)
+		return
+	}
+
+	body, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodySize))
+	if err != nil {
+		WriteErrorWithLog(w, r, h.logger, domain.ErrInvalidField("failed to read request body"))
+		return
+	}
+
+	if fields := mapper.UnknownChatCompletionFields(body); len(fields) > 0 {
+		h.logger.Warn("chat completion request ignored unknown fields",
+			"request_id", middleware.GetRequestID(r.Context()),
+			"path", r.URL.Path,
+			"fields", fields,
+		)
+	}
+
+	genReq, err := mapper.ChatCompletionRequestToDomain(body)
+	if err != nil {
+		WriteErrorWithLog(w, r, h.logger, err)
+		return
+	}
+
+	if genReq.Stream {
+		h.handleStream(w, r, genReq, token)
+		return
+	}
+
+	result, err := h.service.Execute(r.Context(), genReq, token)
+	if err != nil {
+		WriteErrorWithLog(w, r, h.logger, err)
+		return
+	}
+
+	resp := mapper.DomainToChatCompletionResponse(result)
+	WriteJSON(w, http.StatusOK, resp)
+}
+
+// handleStream relays the service's event stream to the client through an SSE
+// writer. Failures before the writer exists are reported as ordinary JSON
+// errors; a failure mid-stream is written with WriteData as an "error" object,
+// after which the handler returns without calling WriteDone.
+func (h *ChatCompletionsHandler) handleStream(w http.ResponseWriter, r *http.Request, genReq domain.GenerateRequest, token string) {
+	stream, model, err := h.service.ExecuteStream(r.Context(), genReq, token)
+	if err != nil {
+		WriteErrorWithLog(w, r, h.logger, err)
+		return
+	}
+	defer stream.Close()
+	responseModelID := genReq.PublicModelID
+	if model != nil {
+		responseModelID = model.PublicModelID
+	}
+
+	sw, err := sse.NewWriter(w)
+	if err != nil {
+		WriteError(w, domain.ErrInternal("streaming not supported"))
+		return
+	}
+
+	responseID := uuid.New().String()[:12]
+	createdAt := time.Now().Unix()
+
+	// Keep-alive: write SSE comments periodically to prevent proxy timeouts.
+	// writeMu serialises the keep-alive goroutine and the relay loop below.
+	var writeMu sync.Mutex
+	stopKeepAlive := make(chan struct{})
+	keepAliveDone := make(chan struct{})
+	// The ResponseWriter must not be used after this handler returns, so wait
+	// for the goroutine to exit rather than only signalling it: with both
+	// channels ready, select may still pick one last tick.
+	defer func() {
+		close(stopKeepAlive)
+		<-keepAliveDone
+	}()
+	go func() {
+		defer close(keepAliveDone)
+		ticker := time.NewTicker(15 * time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-stopKeepAlive:
+				return
+			case <-ticker.C:
+				writeMu.Lock()
+				sw.WriteComment("keepalive")
+				writeMu.Unlock()
+			}
+		}
+	}()
+
+	for {
+		event, err := stream.Recv()
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			// As in WriteErrorWithLog, only a GatewayError's text is meant
+			// for clients; anything else is logged and replaced.
+			message := err.Error()
+			if _, ok := err.(*domain.GatewayError); !ok {
+				h.logger.Error("untyped internal error",
+					"request_id", middleware.GetRequestID(r.Context()),
+					"path", r.URL.Path,
+					"original_error", message,
+				)
+				message = "an internal error occurred"
+			}
+			writeMu.Lock()
+			sw.WriteData(map[string]any{
+				"error": map[string]any{"type": "internal_error", "message": message},
+			})
+			writeMu.Unlock()
+			return
+		}
+		if event.ProviderResponseID != "" {
+			responseID = event.ProviderResponseID
+		}
+
+		chunk, done := mapper.DomainStreamEventToChatChunk(event, responseModelID, responseID, createdAt)
+		if chunk != nil {
+			writeMu.Lock()
+			sw.WriteData(chunk)
+			writeMu.Unlock()
+		}
+		if done {
+			writeMu.Lock()
+			sw.WriteDone()
+			writeMu.Unlock()
+			// Keep draining so the usage-tracking wrapper can accumulate
+			// the final usage chunk before io.EOF triggers recording.
+			for {
+				if _, err := stream.Recv(); err != nil {
+					break
+				}
+			}
+			break
+		}
+	}
+}
diff --git a/apps/gateway/internal/adapters/http/handlers/common.go b/apps/gateway/internal/adapters/http/handlers/common.go
new file mode 100644
index 0000000..b8e53a9
--- /dev/null
+++ b/apps/gateway/internal/adapters/http/handlers/common.go
@@ -0,0 +1,93 @@
+package handlers
+
+import (
+	"encoding/json"
+	"net/http"
+
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/dto"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/middleware"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports"
+	"github.com/dappnode/dappnode-nexus-gateway/pkg/domain"
+)
+
+const maxRequestBodySize = 10 * 1024 * 1024 // 10 MB
+
+// WriteJSON writes a JSON response.
+// The value is marshalled before any header is sent, so an encoding failure
+// is reported as an internal error instead of a success status and no body.
+func WriteJSON(w http.ResponseWriter, status int, v any) {
+	body, err := json.Marshal(v)
+	if err != nil {
+		// The error envelope is plain strings, so the nested WriteJSON call succeeds.
+		WriteError(w, err)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(status)
+	// A failed write means the client is gone; nothing useful can be done.
+	_, _ = w.Write(append(body, '\n'))
+}
+
+// WriteError writes a GatewayError as a JSON error response. Any other error
+// is replaced by a generic internal error so its text never reaches the client.
+func WriteError(w http.ResponseWriter, err error) {
+	gwErr, ok := err.(*domain.GatewayError)
+	if !ok {
+		gwErr = domain.ErrInternal("an internal error occurred")
+	}
+	WriteJSON(w, gwErr.HTTPStatus, dto.ErrorResponse{
+		Error: dto.ErrorDetail{
+			Type:    gwErr.Type,
+			Code:    gwErr.Code,
+			Message: gwErr.Message,
+		},
+	})
+}
+
+// WriteErrorWithLog writes a GatewayError as a JSON error response and logs it
+// server-side: 5xx at error level (with the error's extra log fields), 4xx at
+// warn level. Untyped errors are logged in full and replaced by a generic one.
+func WriteErrorWithLog(w http.ResponseWriter, r *http.Request, logger ports.Logger, err error) {
+	requestID := middleware.GetRequestID(r.Context())
+	gwErr, ok := err.(*domain.GatewayError)
+	if !ok {
+		logger.Error("untyped internal error",
+			"request_id", requestID,
+			"path", r.URL.Path,
+			"original_error", err.Error(),
+		)
+		gwErr = domain.ErrInternal("an internal error occurred")
+	}
+
+	fields := []any{
+		"request_id", requestID,
+		"status", gwErr.HTTPStatus,
+		"gateway_status", gwErr.HTTPStatus,
+		"error_code", gwErr.Code,
+		"error", gwErr.Message,
+	}
+	switch {
+	case gwErr.HTTPStatus >= 500:
+		logger.Error("request failed", append(fields, gwErr.LogFields()...)...)
+	case gwErr.HTTPStatus >= 400:
+		logger.Warn("request error", fields...)
+	}
+	// gwErr is typed here, so WriteError emits it unchanged.
+	WriteError(w, gwErr)
+}
+
+// ExtractBearerToken extracts the bearer token from the Authorization header.
+func ExtractBearerToken(r *http.Request) (string, error) {
+	const scheme = "Bearer "
+	auth := r.Header.Get("Authorization")
+	switch {
+	case auth == "":
+		return "", domain.ErrInvalidAPIKey("missing Authorization header")
+	case len(auth) < len(scheme) || auth[:len(scheme)] != scheme:
+		return "", domain.ErrInvalidAPIKey("malformed Authorization header, expected 'Bearer '")
+	case len(auth) == len(scheme):
+		// Scheme present but nothing after it.
+		return "", domain.ErrInvalidAPIKey("empty bearer token")
+	}
+	return auth[len(scheme):], nil
+}
diff --git a/apps/gateway/internal/adapters/http/handlers/handlers_test.go b/apps/gateway/internal/adapters/http/handlers/handlers_test.go
new file mode 100644
index 0000000..0b8f46f
--- /dev/null
+++ b/apps/gateway/internal/adapters/http/handlers/handlers_test.go
@@ -0,0 +1,350 @@
+package handlers_test
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/handlers"
+	"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/services"
+	"github.com/dappnode/dappnode-nexus-gateway/pkg/domain"
+	"github.com/shopspring/decimal"
+)
+
+// --- Mock implementations ---
+
+type mockAuthService struct {
+	authCtx domain.AuthContext
+	err     error
+}
+
+func (m *mockAuthService) AuthenticateAPIKey(_ context.Context, _ string) (domain.AuthContext, error) {
+	return m.authCtx, m.err
+}
+
+type mockModelCatalog struct {
+	models []domain.PublicModel
+	model  domain.PublicModel
+	err    error
+}
+
+func (m *mockModelCatalog) ListPublicModels(_ context.Context) ([]domain.PublicModel, error) {
+	return m.models, m.err
+}
+
+func (m *mockModelCatalog) GetPublicModel(_ context.Context, _ string) (domain.PublicModel, error) {
+	return m.model, m.err
+}
+
+func (m *mockModelCatalog) ListRouters(_ context.Context) ([]domain.RouterEntry, error) {
+	return nil, nil
+}
+
+func (m *mockModelCatalog) GetRouter(_ context.Context, routerID string) (domain.RouterEntry, error) {
+	return domain.RouterEntry{}, domain.ErrNotFound("router", routerID)
+}
+
+type mockBalanceChecker struct {
+	balance int64
+	err     error
+}
+
+func (m *mockBalanceChecker) GetSpendableBalance(_ context.Context, _ string) (int64, error) {
+	return m.balance, m.err
+}
+
+type mockUsageRecorder struct{}
+
+func (m *mockUsageRecorder) RecordSuccess(_ context.Context, _ domain.AuthContext, _ string, _ domain.GenerateRequest, _ domain.GenerateResult, _ domain.PublicModel, _ int64) error {
+	return nil
+}
+
+func (m *mockUsageRecorder) RecordFailure(_ context.Context, _ *domain.AuthContext, _ string, _ *domain.GenerateRequest, _ *domain.PublicModel, _ error, _ *domain.Usage, _ int64) error {
+	return nil
+}
+
+type mockCostReserver struct{}
+
+func (m *mockCostReserver) TryReserve(_ string, _ int64, _ int64) (string, error) {
+	return "mock-reservation", nil
+}
+
+func (m *mockCostReserver) Release(_ string) {}
+
+type mockLogger struct{}
+
+func (m *mockLogger) Debug(_ string, _ ...any) {}
+func (m *mockLogger) Info(_ string, _ ...any)  {}
+func (m *mockLogger) Warn(_ string, _ ...any)  {}
+func (m *mockLogger) Error(_ string, _ ...any) {}
+
+// --- Tests ---
+
+func TestHealthHandler(t *testing.T) {
+	h := handlers.NewHealthHandler()
+	req := httptest.NewRequest("GET", "/healthz", nil)
+	w := httptest.NewRecorder()
+	h.Handle(w, req)
+
+	resp := w.Result()
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("status = %d, want 200", resp.StatusCode)
+	}
+
+	var body map[string]string
+	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if body["status"] != "ok" {
+		t.Errorf("status = %s, want ok", body["status"])
+	}
+}
+
+func TestModelsHandler_NoAuthRequired(t *testing.T) {
+	catalog := &mockModelCatalog{}
+	logger := &mockLogger{}
+
+	svc := services.NewListModelsService(catalog, logger)
+	h := handlers.NewModelsHandler(svc)
+
+	req := httptest.NewRequest("GET", "/v1/models", nil)
+	w := httptest.NewRecorder()
+	h.Handle(w, req)
+
+	resp := w.Result()
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("status = %d, want 200 (endpoint is public)", resp.StatusCode)
+	}
+}
+
+func TestModelsHandler_ListsModels(t *testing.T) {
+	desc := "GPT-4.1 Mini model"
+	catalog := &mockModelCatalog{
+		models: []domain.PublicModel{
+			{
+				PublicModelID:                 "openai/gpt-4.1-mini",
+				DisplayName:                   "GPT-4.1 Mini",
+				Description:                   &desc,
+				ProviderConfig:                domain.ProviderConfig{ProviderName: "openai"},
+				Active:                        true,
+				SupportsChatCompletions:       true,
+				SupportsChatCompletionsStream: true,
+				SupportsTools:                 true,
+				SupportsStructuredOutput:      true,
+				MaxContextWindow:              128000,
+				MaxOutputTokens:               16000,
+				Currency:                      "EUR",
+				InputPricePerMillion:          decimal.NewFromFloat(0.4),
+				OutputPricePerMillion:         decimal.NewFromFloat(1.6),
+			},
+			{
+				PublicModelID:           "anthropic/claude-sonnet-4",
+				DisplayName:             "Claude Sonnet 4",
+				ProviderConfig:          domain.ProviderConfig{ProviderName: "anthropic"},
+				Active:                  true,
+				SupportsChatCompletions: true,
+				Currency:                "EUR",
+				InputPricePerMillion:    decimal.NewFromFloat(3.0),
+				OutputPricePerMillion:   decimal.NewFromFloat(15.0),
+			},
+		},
+	}
+	logger := &mockLogger{}
+
+	svc := services.NewListModelsService(catalog, logger)
+	h := handlers.NewModelsHandler(svc)
+
+	req := httptest.NewRequest("GET", "/v1/models", nil)
+	w := httptest.NewRecorder()
+	h.Handle(w, req)
+
+	resp := w.Result()
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("status = %d, want 200", resp.StatusCode)
+	}
+
+	var body struct {
+		Object string `json:"object"`
+		Data   []struct {
+			ID                          string   `json:"id"`
+			Object                      string   `json:"object"`
+			OwnedBy                     string   `json:"owned_by"`
+			DisplayName                 string   `json:"display_name"`
+			Description                 *string  `json:"description,omitempty"`
+			ContextSize                 int      `json:"context_size"`
+			MaxOutputTokens             int      `json:"max_output_tokens"`
+			Currency                    string   `json:"currency"`
+			InputPricePer1MTokensCents  int64    `json:"input_price_per_1m_tokens_cents"`
+			OutputPricePer1MTokensCents int64    `json:"output_price_per_1m_tokens_cents"`
+			InputPricePer1MTokensUSD    float64  `json:"input_price_per_1m_tokens_usd"`
+			OutputPricePer1MTokensUSD   float64  `json:"output_price_per_1m_tokens_usd"`
+			Features                    []string `json:"features"`
+			Endpoints                   []string `json:"endpoints"`
+		} `json:"data"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+
+	if body.Object != "list" {
+		t.Errorf("object = %s, want list", body.Object)
+	}
+	if len(body.Data) != 2 {
+		t.Fatalf("data len = %d, want 2", len(body.Data))
+	}
+	first := body.Data[0]
+	if first.ID != "openai/gpt-4.1-mini" {
+		t.Errorf("data[0].id = %s, want openai/gpt-4.1-mini", first.ID)
+	}
+	if first.OwnedBy != "nexus" {
+		t.Errorf("data[0].owned_by = %s, want nexus", first.OwnedBy)
+	}
+	if first.DisplayName != "GPT-4.1 Mini" {
+		t.Errorf("data[0].display_name = %s, want GPT-4.1 Mini", first.DisplayName)
+	}
+	if first.Description == nil || *first.Description != "GPT-4.1 Mini model" {
+		t.Errorf("data[0].description mismatch: %+v", first.Description)
+	}
+	if first.ContextSize != 128000 {
+		t.Errorf("data[0].context_size = %d, want 128000", first.ContextSize)
+	}
+	if first.MaxOutputTokens != 16000 {
+		t.Errorf("data[0].max_output_tokens = %d, want 16000", first.MaxOutputTokens)
+	}
+	if first.Currency != "EUR" {
+		t.Errorf("data[0].currency = %s, want EUR", first.Currency)
+	}
+	if first.InputPricePer1MTokensCents != 40 {
+		t.Errorf("data[0].input_price = %d, want 40", first.InputPricePer1MTokensCents)
+	}
+	if first.OutputPricePer1MTokensCents != 160 {
+		t.Errorf("data[0].output_price = %d, want 160", first.OutputPricePer1MTokensCents)
+	}
+	if first.InputPricePer1MTokensUSD != 0.432 {
+		t.Errorf("data[0].input_price_usd = %v, want 0.432", first.InputPricePer1MTokensUSD)
+	}
+	if first.OutputPricePer1MTokensUSD != 1.728 {
+		t.Errorf("data[0].output_price_usd = %v, want 1.728", first.OutputPricePer1MTokensUSD)
+	}
+	wantFeatures := map[string]bool{
+		"streaming":          true,
+		"function-calling":   true,
+		"structured-outputs": true,
+	}
+	if len(first.Features) != len(wantFeatures) {
+		t.Errorf("data[0].features = %v, want %v", first.Features, wantFeatures)
+	}
+	for _, f := range first.Features {
+		if !wantFeatures[f] {
+			t.Errorf("unexpected feature %q in data[0].features", f)
+		}
+	}
+	if len(first.Endpoints) != 1 || first.Endpoints[0] != "chat/completions" {
+		t.Errorf("data[0].endpoints = %v, want [chat/completions]", first.Endpoints)
+	}
+}
+
+func TestModelsHandler_PrivateModelIncludesBaseModel(t *testing.T) {
+	catalog := &mockModelCatalog{
+		models: []domain.PublicModel{
+			{
+				PublicModelID:           "private/minimax",
+				UpstreamModelName:       "minimax/minimax-m2.7",
+				DisplayName:             "Private MiniMax",
+				SupportsChatCompletions: true,
+				Currency:                "EUR",
+			},
+		},
+	}
+	logger := &mockLogger{}
+
+	svc := services.NewListModelsService(catalog, logger)
+	h := handlers.NewModelsHandler(svc)
+
+	req := httptest.NewRequest("GET", "/v1/models", nil)
+	w := httptest.NewRecorder()
+	h.Handle(w, req)
+
+	resp := w.Result()
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("status = %d, want 200", resp.StatusCode)
+	}
+
+	var body struct {
+		Data []struct {
+			ID        string  `json:"id"`
+			BaseModel *string `json:"base_model,omitempty"`
+		} `json:"data"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+		t.Fatalf("decode response: %v", err)
+	}
+	if len(body.Data) != 1 {
+		t.Fatalf("data len = %d, want 1", len(body.Data))
+	}
+	if body.Data[0].ID != "private/minimax" {
+		t.Fatalf("data[0].id = %s, want private/minimax", body.Data[0].ID)
+	}
+	if body.Data[0].BaseModel == nil || *body.Data[0].BaseModel != "minimax/minimax-m2.7" {
+		t.Fatalf("data[0].base_model = %v, want minimax/minimax-m2.7", body.Data[0].BaseModel)
+	}
+}
+
+func TestExtractBearerToken(t *testing.T) {
+	tests := []struct {
+		name       string
+		authHeader string
+		wantToken  string
+		wantErr    bool
+	}{
+		{"valid", "Bearer sk-test-123", "sk-test-123", false},
+		{"missing", "", "", true},
+		{"malformed", "Basic abc", "", true},
+		{"empty_bearer", "Bearer ", "", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			req := httptest.NewRequest("GET", "/", nil)
+			if tt.authHeader != "" {
+				req.Header.Set("Authorization", tt.authHeader)
+			}
+			token, err := handlers.ExtractBearerToken(req)
+			if tt.wantErr && err == nil {
+				t.Fatal("expected error")
+			}
+			if !tt.wantErr && err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+			if token != tt.wantToken {
+				t.Errorf("token = %s, want %s", token, tt.wantToken)
+			}
+		})
+	}
+}
+
+func TestChatCompletionsHandler_MissingAuth(t *testing.T) {
+	auth := &mockAuthService{err: domain.ErrInvalidAPIKey("bad")}
+	catalog := &mockModelCatalog{}
+	logger := &mockLogger{}
+	usage := &mockUsageRecorder{}
+	balances := &mockBalanceChecker{balance: 100_000_000_000}
+
+	genSvc := services.NewGenerateService(auth, balances, catalog, nil, nil, usage, &mockCostReserver{}, nil, logger)
+	chatSvc := services.NewChatCompletionsService(genSvc, logger)
+	h := handlers.NewChatCompletionsHandler(chatSvc, logger)
+
+	body := `{"model": "test", "messages": [{"role": "user", "content": "hi"}]}`
+	req := httptest.NewRequest("POST", "/v1/chat/completions", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	h.Handle(w, req)
+
+	resp := w.Result()
+	if resp.StatusCode != http.StatusUnauthorized {
+		t.Fatalf("status = %d, want 401", resp.StatusCode)
+	}
+}
diff --git a/apps/gateway/internal/adapters/http/handlers/health_handler.go b/apps/gateway/internal/adapters/http/handlers/health_handler.go
new file mode 100644
index 0000000..1d7c573
--- /dev/null
+++ b/apps/gateway/internal/adapters/http/handlers/health_handler.go
@@ -0,0 +1,14 @@
+package handlers
+
+import "net/http"
+
+// HealthHandler handles GET /healthz.
+type HealthHandler struct{} + +func NewHealthHandler() *HealthHandler { + return &HealthHandler{} +} + +func (h *HealthHandler) Handle(w http.ResponseWriter, r *http.Request) { + WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} diff --git a/apps/gateway/internal/adapters/http/handlers/models_handler.go b/apps/gateway/internal/adapters/http/handlers/models_handler.go new file mode 100644 index 0000000..d56f781 --- /dev/null +++ b/apps/gateway/internal/adapters/http/handlers/models_handler.go @@ -0,0 +1,169 @@ +package handlers + +import ( + "net/http" + "strings" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/dto" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/services" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/shopspring/decimal" +) + +// ModelsHandler handles GET /v1/models. +type ModelsHandler struct { + service *services.ListModelsService +} + +func NewModelsHandler(service *services.ListModelsService) *ModelsHandler { + return &ModelsHandler{service: service} +} + +func (h *ModelsHandler) Handle(w http.ResponseWriter, r *http.Request) { + models, err := h.service.Execute(r.Context()) + if err != nil { + WriteError(w, err) + return + } + + data := make([]dto.ModelData, 0, len(models)) + for _, entry := range models { + data = append(data, toModelData(entry)) + } + + w.Header().Set("Cache-Control", "public, max-age=60") + WriteJSON(w, http.StatusOK, dto.ModelsResponse{ + Object: "list", + Data: data, + }) +} + +func toModelData(entry domain.ModelCatalogEntry) dto.ModelData { + if entry.Kind == domain.CatalogKindRouter && entry.Router != nil { + return toRouterData(*entry.Router) + } + if entry.PublicModel == nil { + return dto.ModelData{} + } + return toPublicModelData(*entry.PublicModel, entry.EURToUSDRate) +} + +func toPublicModelData(m domain.PublicModel, eurToUSDRate float64) dto.ModelData { + inputPrice := eurToCents(m.InputPricePerMillion) + 
outputPrice := eurToCents(m.OutputPricePerMillion) + proofMode := m.EffectiveProofMode() + return dto.ModelData{ + ID: m.PublicModelID, + Object: "model", + Kind: string(domain.CatalogKindPublicModel), + BaseModel: privateBaseModel(m.PublicModelID, m.UpstreamModelName), + OwnedBy: "nexus", + Created: 0, + DisplayName: m.DisplayName, + Description: m.Description, + ContextSize: &m.MaxContextWindow, + MaxOutputTokens: &m.MaxOutputTokens, + Currency: &m.Currency, + InputPricePer1MTokensCents: &inputPrice, + OutputPricePer1MTokensCents: &outputPrice, + CacheReadPricePer1MTokensCents: eurPtrToCentsPtr(m.CacheReadPricePerMillion), + CacheWritePricePer1MTokensCents: eurPtrToCentsPtr(m.CacheWritePricePerMillion), + InputPricePer1MTokensUSD: priceToUSDPtr(m.Currency, m.InputPricePerMillion, eurToUSDRate), + OutputPricePer1MTokensUSD: priceToUSDPtr(m.Currency, m.OutputPricePerMillion, eurToUSDRate), + CacheReadPricePer1MTokensUSD: pricePtrToUSDPtr(m.Currency, m.CacheReadPricePerMillion, eurToUSDRate), + CacheWritePricePer1MTokensUSD: pricePtrToUSDPtr(m.Currency, m.CacheWritePricePerMillion, eurToUSDRate), + Features: modelFeatures(m), + Endpoints: modelEndpoints(m), + ProofMode: proofMode, + ProofsEnabled: domain.ProofModeEnabled(proofMode), + } +} + +func toRouterData(r domain.RouterEntry) dto.ModelData { + return dto.ModelData{ + ID: r.RouterID, + Object: "model", + Kind: string(domain.CatalogKindRouter), + OwnedBy: "nexus", + Created: r.CreatedAt.Unix(), + DisplayName: r.DisplayName, + Description: r.Description, + Features: []string{"routing"}, + Endpoints: []string{"chat/completions"}, + ProofMode: domain.ProofModeNone, + } +} + +func modelFeatures(m domain.PublicModel) []string { + features := make([]string, 0, 5) + if m.SupportsChatCompletionsStream { + features = append(features, "streaming") + } + if m.SupportsTools { + features = append(features, "function-calling") + } + if m.SupportsParallelToolCalls { + features = append(features, "parallel-tool-calls") + } + if 
m.SupportsStructuredOutput { + features = append(features, "structured-outputs") + } + if m.SupportsReasoning { + features = append(features, "reasoning") + } + switch m.EffectiveProofMode() { + case domain.ProofModeTinfoilAttestedTransport: + features = append(features, "tinfoil-attested-transport") + } + return features +} + +func modelEndpoints(m domain.PublicModel) []string { + endpoints := make([]string, 0, 1) + if m.SupportsChatCompletions { + endpoints = append(endpoints, "chat/completions") + } + return endpoints +} + +func eurToCents(price decimal.Decimal) int64 { + return price.Mul(decimal.NewFromInt(100)).Round(0).IntPart() +} + +func eurPtrToCentsPtr(price *decimal.Decimal) *int64 { + if price == nil { + return nil + } + v := eurToCents(*price) + return &v +} + +func pricePtrToUSDPtr(currency string, price *decimal.Decimal, eurToUSDRate float64) *float64 { + if price == nil { + return nil + } + return priceToUSDPtr(currency, *price, eurToUSDRate) +} + +func priceToUSDPtr(currency string, price decimal.Decimal, eurToUSDRate float64) *float64 { + factor := 0.0 + switch strings.ToUpper(strings.TrimSpace(currency)) { + case "USD": + factor = 1 + case "EUR": + factor = eurToUSDRate + } + if factor <= 0 { + return nil + } + usd, _ := price.Mul(decimal.NewFromFloat(factor)).Round(4).Float64() + return &usd +} + +func privateBaseModel(publicModelID, upstreamModelName string) *string { + baseModel := strings.TrimSpace(upstreamModelName) + if !strings.HasPrefix(publicModelID, "private/") || baseModel == "" { + return nil + } + return &baseModel +} diff --git a/apps/gateway/internal/adapters/http/handlers/tinfoil_handler.go b/apps/gateway/internal/adapters/http/handlers/tinfoil_handler.go new file mode 100644 index 0000000..54e9e0d --- /dev/null +++ b/apps/gateway/internal/adapters/http/handlers/tinfoil_handler.go @@ -0,0 +1,100 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + 
"github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// TinfoilHandler serves safe Tinfoil attested-transport proof APIs. +type TinfoilHandler struct { + auth ports.AuthService + proofs ports.TinfoilTransportProofRepository + logger ports.Logger +} + +func NewTinfoilHandler(auth ports.AuthService, proofs ports.TinfoilTransportProofRepository, logger ports.Logger) *TinfoilHandler { + return &TinfoilHandler{auth: auth, proofs: proofs, logger: logger} +} + +func (h *TinfoilHandler) HandleGetProof(w http.ResponseWriter, r *http.Request) { + authCtx, err := h.authenticate(r) + if err != nil { + WriteErrorWithLog(w, r, h.logger, err) + return + } + responseID := r.PathValue("response_id") + if responseID == "" { + WriteErrorWithLog(w, r, h.logger, domain.ErrInvalidField("response_id is required")) + return + } + if h.proofs == nil { + WriteErrorWithLog(w, r, h.logger, domain.ErrInternal("Tinfoil proof repository is not configured")) + return + } + proof, err := h.proofs.GetTinfoilTransportProof(r.Context(), authCtx.Account.ID, responseID) + if err != nil { + WriteErrorWithLog(w, r, h.logger, err) + return + } + WriteJSON(w, http.StatusOK, tinfoilProofResponseFromDomain(proof)) +} + +func (h *TinfoilHandler) authenticate(r *http.Request) (domain.AuthContext, error) { + token, err := ExtractBearerToken(r) + if err != nil { + return domain.AuthContext{}, err + } + return h.auth.AuthenticateAPIKey(r.Context(), token) +} + +type tinfoilProofResponse struct { + Provider string `json:"provider"` + PublicModelID string `json:"public_model_id"` + UpstreamModelID string `json:"upstream_model_id"` + ProviderResponseID string `json:"provider_response_id"` + EnclaveHost *string `json:"enclave_host"` + ConfigRepo *string `json:"config_repo"` + Digest *string `json:"digest"` + CodeFingerprint *string `json:"code_fingerprint"` + EnclaveFingerprint *string `json:"enclave_fingerprint"` + 
TLSPublicKey *string `json:"tls_public_key"` + HPKEPublicKey *string `json:"hpke_public_key"` + TransportMode *string `json:"transport_mode"` + SDKVersion *string `json:"sdk_version"` + Status string `json:"status"` + FailureReason *string `json:"failure_reason"` + VerificationEvidence json.RawMessage `json:"verification_evidence,omitempty"` + CreatedAt string `json:"created_at"` + VerifiedAt *string `json:"verified_at"` +} + +func tinfoilProofResponseFromDomain(proof domain.TinfoilTransportProof) tinfoilProofResponse { + var verifiedAt *string + if proof.VerifiedAt != nil { + v := proof.VerifiedAt.Format(http.TimeFormat) + verifiedAt = &v + } + return tinfoilProofResponse{ + Provider: proof.Provider, + PublicModelID: proof.PublicModelID, + UpstreamModelID: proof.UpstreamModelID, + ProviderResponseID: proof.ProviderResponseID, + EnclaveHost: proof.EnclaveHost, + ConfigRepo: proof.ConfigRepo, + Digest: proof.Digest, + CodeFingerprint: proof.CodeFingerprint, + EnclaveFingerprint: proof.EnclaveFingerprint, + TLSPublicKey: proof.TLSPublicKey, + HPKEPublicKey: proof.HPKEPublicKey, + TransportMode: proof.TransportMode, + SDKVersion: proof.SDKVersion, + Status: proof.Status, + FailureReason: proof.FailureReason, + VerificationEvidence: proof.VerificationEvidenceJSON, + CreatedAt: proof.CreatedAt.Format(http.TimeFormat), + VerifiedAt: verifiedAt, + } +} diff --git a/apps/gateway/internal/adapters/http/handlers/tinfoil_handler_test.go b/apps/gateway/internal/adapters/http/handlers/tinfoil_handler_test.go new file mode 100644 index 0000000..076705c --- /dev/null +++ b/apps/gateway/internal/adapters/http/handlers/tinfoil_handler_test.go @@ -0,0 +1,91 @@ +package handlers_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/handlers" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func 
TestTinfoilHandler_GetProofReturnsSafeEvidence(t *testing.T) { + now := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC) + value := func(v string) *string { return &v } + proofRepo := &handlerTinfoilRepo{proofs: map[string]domain.TinfoilTransportProof{ + "cmpl-tinfoil-safe": { + AccountID: "acc1", + APIKeyID: "key1", + Provider: "tinfoil", + PublicModelID: "tinfoil/kimi-k2", + UpstreamModelID: "kimi-k2", + ProviderResponseID: "cmpl-tinfoil-safe", + EnclaveHost: value("inference.tinfoil.sh"), + ConfigRepo: value("tinfoilsh/confidential-model-router"), + Digest: value("sha256:abc"), + CodeFingerprint: value("code-fp"), + EnclaveFingerprint: value("enclave-fp"), + TLSPublicKey: value("tls-key"), + HPKEPublicKey: value("hpke-key"), + TransportMode: value("ehbp"), + SDKVersion: value("github.com/tinfoilsh/tinfoil-go v0.13.1"), + Status: domain.ProofStatusVerified, + VerificationEvidenceJSON: json.RawMessage(`{"ground_truth":{"digest":"sha256:abc"}}`), + CreatedAt: now, + VerifiedAt: &now, + }, + }} + h := handlers.NewTinfoilHandler( + &mockAuthService{authCtx: domain.AuthContext{Account: domain.Account{ID: "acc1"}}}, + proofRepo, + &mockLogger{}, + ) + + req := httptest.NewRequest("GET", "/v1/tinfoil/proofs/cmpl-tinfoil-safe", nil) + req.SetPathValue("response_id", "cmpl-tinfoil-safe") + req.Header.Set("Authorization", "Bearer test") + rr := httptest.NewRecorder() + + h.HandleGetProof(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", rr.Code, rr.Body.String()) + } + + var body map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body["provider_response_id"] != "cmpl-tinfoil-safe" || body["status"] != domain.ProofStatusVerified { + t.Fatalf("unexpected proof response: %#v", body) + } + if body["digest"] != "sha256:abc" || body["transport_mode"] != "ehbp" { + t.Fatalf("missing safe Tinfoil evidence: %#v", body) + } + if _, ok := body["verification_evidence"].(map[string]any); 
!ok { + t.Fatalf("verification evidence missing or wrong type: %#v", body["verification_evidence"]) + } +} + +type handlerTinfoilRepo struct { + proofs map[string]domain.TinfoilTransportProof +} + +func (r *handlerTinfoilRepo) UpsertTinfoilTransportProof(_ context.Context, proof domain.TinfoilTransportProof) error { + if r.proofs == nil { + r.proofs = make(map[string]domain.TinfoilTransportProof) + } + r.proofs[proof.ProviderResponseID] = proof + return nil +} + +func (r *handlerTinfoilRepo) GetTinfoilTransportProof(_ context.Context, accountID, providerResponseID string) (domain.TinfoilTransportProof, error) { + proof, ok := r.proofs[providerResponseID] + if !ok || proof.AccountID != accountID { + return domain.TinfoilTransportProof{}, domain.ErrNotFound("tinfoil proof", providerResponseID) + } + return proof, nil +} diff --git a/apps/gateway/internal/adapters/http/mapper/chat_mapper_test.go b/apps/gateway/internal/adapters/http/mapper/chat_mapper_test.go new file mode 100644 index 0000000..ca9d14a --- /dev/null +++ b/apps/gateway/internal/adapters/http/mapper/chat_mapper_test.go @@ -0,0 +1,318 @@ +package mapper_test + +import ( + "encoding/json" + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/mapper" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func TestChatCompletionRequestToDomain_Basic(t *testing.T) { + raw := json.RawMessage(`{ + "model": "openai/gpt-4.1-mini", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`) + + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if req.PublicModelID != "openai/gpt-4.1-mini" { + t.Errorf("model = %s, want openai/gpt-4.1-mini", req.PublicModelID) + } + if len(req.Input) != 1 { + t.Fatalf("input len = %d, want 1", len(req.Input)) + } + if *req.Input[0].Content != "Hello" { + t.Errorf("content = %s, want Hello", *req.Input[0].Content) + } + if *req.Input[0].Role != "user" { + 
t.Errorf("role = %s, want user", *req.Input[0].Role) + } +} + +func TestChatCompletionRequestToDomain_DeveloperRole(t *testing.T) { + raw := json.RawMessage(`{ + "model": "openai/gpt-5", + "messages": [ + {"role": "developer", "content": "Follow these instructions"} + ] + }`) + + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(req.Input) != 1 { + t.Fatalf("input len = %d, want 1", len(req.Input)) + } + if *req.Input[0].Role != "developer" { + t.Fatalf("role = %s, want developer", *req.Input[0].Role) + } + if *req.Input[0].Content != "Follow these instructions" { + t.Fatalf("content = %s, want developer content", *req.Input[0].Content) + } +} + +func TestChatCompletionRequestToDomain_EmptyMessages(t *testing.T) { + raw := json.RawMessage(`{"model": "test", "messages": []}`) + _, err := mapper.ChatCompletionRequestToDomain(raw) + if err == nil { + t.Fatal("expected error for empty messages") + } +} + +func TestChatCompletionRequestToDomain_BothMaxTokens(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 100, + "max_completion_tokens": 200 + }`) + _, err := mapper.ChatCompletionRequestToDomain(raw) + if err == nil { + t.Fatal("expected error when both max_tokens and max_completion_tokens are set") + } +} + +func TestChatCompletionRequestToDomain_MaxTokensOnly(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 100 + }`) + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.MaxOutputTokens == nil || *req.MaxOutputTokens != 100 { + t.Errorf("max_output_tokens = %v, want 100", req.MaxOutputTokens) + } +} + +func TestChatCompletionRequestToDomain_MaxCompletionTokensOnly(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", 
"content": "hi"}], + "max_completion_tokens": 200 + }`) + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.MaxOutputTokens == nil || *req.MaxOutputTokens != 200 { + t.Errorf("max_output_tokens = %v, want 200", req.MaxOutputTokens) + } +} + +func TestChatCompletionRequestToDomain_ReasoningEffort(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "reasoning_effort": "low" + }`) + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.ReasoningEffort == nil || *req.ReasoningEffort != "low" { + t.Errorf("reasoning_effort = %v, want low", req.ReasoningEffort) + } + if fields := mapper.UnknownChatCompletionFields(raw); len(fields) != 0 { + t.Fatalf("unknown fields = %v, want none", fields) + } +} + +func TestChatCompletionRequestToDomain_UnknownField(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "foobar": true + }`) + _, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unknown fields should be silently ignored, got: %v", err) + } +} + +func TestUnknownChatCompletionFields(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "zeta": true, + "n": 1, + "best_of": 2, + "alpha": "ignored" + }`) + fields := mapper.UnknownChatCompletionFields(raw) + if len(fields) != 2 || fields[0] != "alpha" || fields[1] != "zeta" { + t.Fatalf("fields = %v, want [alpha zeta]", fields) + } +} + +func TestChatCompletionRequestToDomain_NOne(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "n": 1 + }`) + _, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error for n=1: %v", err) + } +} + +func 
TestChatCompletionRequestToDomain_NMultipleUnsupported(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "n": 2 + }`) + _, err := mapper.ChatCompletionRequestToDomain(raw) + if err == nil { + t.Fatal("expected error for n > 1") + } +} + +func TestChatCompletionRequestToDomain_ToolMessage(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [ + {"role": "user", "content": "What is the weather?"}, + {"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"NYC\"}"}}]}, + {"role": "tool", "tool_call_id": "call_1", "content": "72F sunny"} + ] + }`) + + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(req.Input) != 3 { + t.Fatalf("input len = %d, want 3", len(req.Input)) + } + if len(req.Input[1].ToolCalls) != 1 { + t.Fatalf("assistant tool_calls len = %d, want 1", len(req.Input[1].ToolCalls)) + } + if *req.Input[2].ToolCallID != "call_1" { + t.Errorf("tool_call_id = %s, want call_1", *req.Input[2].ToolCallID) + } +} + +func TestChatCompletionRequestToDomain_AssistantReasoningContent(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [ + {"role": "assistant", "reasoning_content": "thoughts", "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}}]} + ] + }`) + + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Input[0].ReasoningContent == nil || *req.Input[0].ReasoningContent != "thoughts" { + t.Fatalf("reasoning_content = %v, want thoughts", req.Input[0].ReasoningContent) + } +} + +func TestChatCompletionRequestToDomain_ToolMessageMissingID(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [ + {"role": "tool", "content": "result"} + ] + }`) + _, err := 
mapper.ChatCompletionRequestToDomain(raw) + if err == nil { + t.Fatal("expected error for tool message without tool_call_id") + } +} + +func TestChatCompletionRequestToDomain_SystemWithToolCalls(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [ + {"role": "system", "content": "You are helpful", "tool_calls": [{"id": "x", "type": "function", "function": {"name": "f", "arguments": "{}"}}]} + ] + }`) + _, err := mapper.ChatCompletionRequestToDomain(raw) + if err == nil { + t.Fatal("expected error for system message with tool_calls") + } +} + +func TestChatCompletionRequestToDomain_Stop(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "stop": ["END", "STOP"] + }`) + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(req.Stop) != 2 { + t.Errorf("stop len = %d, want 2", len(req.Stop)) + } +} + +func TestChatCompletionRequestToDomain_StopString(t *testing.T) { + raw := json.RawMessage(`{ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "stop": "END" + }`) + req, err := mapper.ChatCompletionRequestToDomain(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(req.Stop) != 1 || req.Stop[0] != "END" { + t.Errorf("stop = %v, want [END]", req.Stop) + } +} + +func TestDomainToChatCompletionResponse(t *testing.T) { + content := "Hello!" 
+ reasoning := "thinking" + role := "assistant" + finishReason := "stop" + result := domain.GenerateResult{ + ID: "test-id-123", + CreatedUnix: 1700000000, + PublicModelID: "openai/gpt-4.1-mini", + Output: []domain.OutputItem{ + {Type: "message", Role: &role, Content: &content, ReasoningContent: &reasoning}, + }, + FinishReason: &finishReason, + Usage: &domain.Usage{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + } + + resp := mapper.DomainToChatCompletionResponse(result) + if resp.Object != "chat.completion" { + t.Errorf("object = %s, want chat.completion", resp.Object) + } + if len(resp.Choices) != 1 { + t.Fatalf("choices len = %d, want 1", len(resp.Choices)) + } + if *resp.Choices[0].Message.Content != "Hello!" { + t.Errorf("content = %s, want Hello!", *resp.Choices[0].Message.Content) + } + if resp.Choices[0].Message.ReasoningContent == nil || *resp.Choices[0].Message.ReasoningContent != "thinking" { + t.Errorf("reasoning_content = %v, want thinking", resp.Choices[0].Message.ReasoningContent) + } + if *resp.Choices[0].FinishReason != "stop" { + t.Errorf("finish_reason = %s, want stop", *resp.Choices[0].FinishReason) + } + if resp.Usage.TotalTokens != 15 { + t.Errorf("total_tokens = %d, want 15", resp.Usage.TotalTokens) + } +} diff --git a/apps/gateway/internal/adapters/http/mapper/chat_to_domain.go b/apps/gateway/internal/adapters/http/mapper/chat_to_domain.go new file mode 100644 index 0000000..971c2b8 --- /dev/null +++ b/apps/gateway/internal/adapters/http/mapper/chat_to_domain.go @@ -0,0 +1,334 @@ +package mapper + +import ( + "encoding/json" + "fmt" + "sort" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/dto" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// unsupportedChatFields are known top-level fields that must be rejected +// because they change semantics in ways the gateway cannot support. 
+var unsupportedChatFields = map[string]bool{ + "best_of": true, + "function_call": true, "functions": true, +} + +// knownChatFields are allowed top-level fields. +var knownChatFields = map[string]bool{ + "model": true, "messages": true, "stream": true, + "max_tokens": true, "max_completion_tokens": true, "n": true, + "temperature": true, "reasoning_effort": true, "top_p": true, "stop": true, + "tools": true, "tool_choice": true, "parallel_tool_calls": true, + "user": true, "metadata": true, "response_format": true, + "provider_options": true, "stream_options": true, + // Passed through / silently ignored: + "presence_penalty": true, "frequency_penalty": true, "logit_bias": true, + "seed": true, "logprobs": true, "top_logprobs": true, + "suffix": true, "echo": true, "service_tier": true, "store": true, +} + +// ChatCompletionRequestToDomain maps a /v1/chat/completions request DTO to the canonical GenerateRequest. +func ChatCompletionRequestToDomain(raw json.RawMessage) (domain.GenerateRequest, error) { + var rawMap map[string]json.RawMessage + if err := json.Unmarshal(raw, &rawMap); err != nil { + return domain.GenerateRequest{}, domain.ErrInvalidField("invalid JSON body") + } + + for key := range rawMap { + if unsupportedChatFields[key] { + return domain.GenerateRequest{}, domain.ErrInvalidField(fmt.Sprintf("field '%s' is not supported on /v1/chat/completions in this version", key)) + } + // Unknown fields are silently ignored for client compatibility + } + + var req dto.ChatCompletionRequest + if err := json.Unmarshal(raw, &req); err != nil { + return domain.GenerateRequest{}, domain.ErrInvalidField("invalid request body: " + err.Error()) + } + + if req.Model == "" { + return domain.GenerateRequest{}, domain.ErrInvalidField("model is required") + } + + if len(req.Messages) == 0 { + return domain.GenerateRequest{}, domain.ErrInvalidField("messages is required and cannot be empty") + } + + if req.MaxTokens != nil && req.MaxCompletionTokens != nil { + return 
domain.GenerateRequest{}, domain.ErrInvalidField("cannot provide both max_tokens and max_completion_tokens") + } + if req.N != nil && *req.N != 1 { + return domain.GenerateRequest{}, domain.ErrInvalidField("field 'n' only supports value 1 on /v1/chat/completions in this version") + } + + input, err := mapChatMessages(req.Messages) + if err != nil { + return domain.GenerateRequest{}, err + } + + gen := domain.GenerateRequest{ + PublicModelID: req.Model, + Input: input, + Temperature: req.Temperature, + ReasoningEffort: req.ReasoningEffort, + TopP: req.TopP, + Stream: req.Stream != nil && *req.Stream, + User: req.User, + Metadata: req.Metadata, + ProviderOptions: req.ProviderOptions, + PresencePenalty: req.PresencePenalty, + FrequencyPenalty: req.FrequencyPenalty, + LogitBias: req.LogitBias, + Seed: req.Seed, + Logprobs: req.Logprobs, + TopLogprobs: req.TopLogprobs, + Store: req.Store, + ServiceTier: req.ServiceTier, + } + + // Normalize max tokens + if req.MaxTokens != nil { + gen.MaxOutputTokens = req.MaxTokens + } else if req.MaxCompletionTokens != nil { + gen.MaxOutputTokens = req.MaxCompletionTokens + } + + if req.ParallelToolCalls != nil { + gen.ParallelToolCalls = req.ParallelToolCalls + } + + // Parse stop + if len(req.Stop) > 0 { + stops, err := parseStop(req.Stop) + if err != nil { + return domain.GenerateRequest{}, err + } + gen.Stop = stops + } + + // Map tools + for _, t := range req.Tools { + td, err := mapToolDefinition(t) + if err != nil { + return domain.GenerateRequest{}, err + } + gen.Tools = append(gen.Tools, td) + } + + // Map tool_choice + if len(req.ToolChoice) > 0 { + tc, err := parseToolChoice(req.ToolChoice) + if err != nil { + return domain.GenerateRequest{}, err + } + gen.ToolChoice = tc + } + + // Map response_format -> TextConfig + if req.ResponseFormat != nil { + gen.TextConfig = &domain.ResponseTextConfig{ + FormatType: &req.ResponseFormat.Type, + JSONSchema: req.ResponseFormat.JSONSchema, + } + } + + return gen, nil +} + +// 
UnknownChatCompletionFields returns top-level request fields the gateway does +// not currently understand. The mapper ignores these fields for client +// compatibility, but handlers can log them for visibility. +func UnknownChatCompletionFields(raw json.RawMessage) []string { + var rawMap map[string]json.RawMessage + if err := json.Unmarshal(raw, &rawMap); err != nil { + return nil + } + + fields := make([]string, 0) + for key := range rawMap { + if knownChatFields[key] || unsupportedChatFields[key] { + continue + } + fields = append(fields, key) + } + sort.Strings(fields) + return fields +} + +func mapChatMessages(messages []dto.ChatMessage) ([]domain.InputItem, error) { + var result []domain.InputItem + + for i, msg := range messages { + switch msg.Role { + case "system", "developer", "user": + if msg.ToolCalls != nil { + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: %s message must not contain tool_calls", i, msg.Role)) + } + if msg.ToolCallID != nil { + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: %s message must not contain tool_call_id", i, msg.Role)) + } + content, err := extractStringContent(msg.Content) + if err != nil { + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: %s", i, err.Error())) + } + role := msg.Role + result = append(result, domain.InputItem{ + Type: domain.InputItemTypeMessage, + Role: &role, + Content: &content, + }) + + case "assistant": + if msg.ToolCallID != nil { + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: assistant message must not contain tool_call_id", i)) + } + role := msg.Role + item := domain.InputItem{ + Type: domain.InputItemTypeMessage, + Role: &role, + } + // May have content + if len(msg.Content) > 0 { + content, err := extractStringContent(msg.Content) + if err == nil { + item.Content = &content + } + } + if msg.ReasoningContent != nil { + item.ReasoningContent = msg.ReasoningContent + } + // May have tool_calls + for _, tc := range msg.ToolCalls { + 
item.ToolCalls = append(item.ToolCalls, domain.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + ArgumentsJSON: tc.Function.Arguments, + }) + } + result = append(result, item) + + case "tool": + if msg.ToolCallID == nil || *msg.ToolCallID == "" { + return nil, domain.ErrToolMessageInvalid(fmt.Sprintf("message[%d]: tool message requires tool_call_id", i)) + } + if msg.ToolCalls != nil { + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: tool message must not contain tool_calls", i)) + } + content, err := extractStringContent(msg.Content) + if err != nil { + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: %s", i, err.Error())) + } + role := msg.Role + result = append(result, domain.InputItem{ + Type: domain.InputItemTypeMessage, + Role: &role, + Content: &content, + ToolCallID: msg.ToolCallID, + }) + + default: + return nil, domain.ErrInvalidField(fmt.Sprintf("message[%d]: unsupported role '%s'", i, msg.Role)) + } + } + + return result, nil +} + +func extractStringContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", fmt.Errorf("content is required") + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s, nil + } + + // Try array of content parts + var parts []struct { + Type string `json:"type"` + Text string `json:"text"` + } + if err := json.Unmarshal(raw, &parts); err != nil { + return "", fmt.Errorf("content must be a string or array of text content parts") + } + + var combined string + for _, p := range parts { + switch p.Type { + case "text", "input_text", "output_text": + combined += p.Text + default: + // Skip non-text content parts for now + } + } + return combined, nil +} + +func parseStop(raw json.RawMessage) ([]string, error) { + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return []string{s}, nil + } + + var arr []string + if err := json.Unmarshal(raw, &arr); err != nil { + return nil, domain.ErrInvalidField("stop must be a string or array of strings") + 
} + return arr, nil +} + +func mapToolDefinition(t dto.ToolDefinition) (domain.ToolDefinition, error) { + if t.Type != "" && t.Type != "function" { + return domain.ToolDefinition{}, domain.ErrInvalidField(fmt.Sprintf("unsupported tool type: %s", t.Type)) + } + if t.Function.Name == "" { + return domain.ToolDefinition{}, domain.ErrInvalidField("tool function name is required") + } + strict := false + if t.Function.Strict != nil { + strict = *t.Function.Strict + } + return domain.ToolDefinition{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + Strict: strict, + }, nil +} + +func parseToolChoice(raw json.RawMessage) (*domain.ToolChoice, error) { + var s string + if err := json.Unmarshal(raw, &s); err == nil { + switch s { + case "none", "auto", "required": + return &domain.ToolChoice{Mode: s}, nil + default: + return nil, domain.ErrInvalidField(fmt.Sprintf("invalid tool_choice value: %s", s)) + } + } + + var obj struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, domain.ErrInvalidField("tool_choice must be a string or object") + } + + if obj.Type != "function" { + return nil, domain.ErrInvalidField(fmt.Sprintf("unsupported tool_choice type: %s", obj.Type)) + } + if obj.Function.Name == "" { + return nil, domain.ErrInvalidField("tool_choice function name is required") + } + return &domain.ToolChoice{ + Mode: domain.ToolChoiceFunction, + FunctionName: &obj.Function.Name, + }, nil +} diff --git a/apps/gateway/internal/adapters/http/mapper/domain_to_chat.go b/apps/gateway/internal/adapters/http/mapper/domain_to_chat.go new file mode 100644 index 0000000..1c1dccd --- /dev/null +++ b/apps/gateway/internal/adapters/http/mapper/domain_to_chat.go @@ -0,0 +1,158 @@ +package mapper + +import ( + "strings" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/dto" + 
"github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// DomainToChatCompletionResponse maps a canonical GenerateResult to a chat completions response DTO. +func DomainToChatCompletionResponse(result domain.GenerateResult) dto.ChatCompletionResponse { + resp := dto.ChatCompletionResponse{ + ID: chatCompletionID(result.ID), + Object: "chat.completion", + Created: result.CreatedUnix, + Model: result.PublicModelID, + } + + message := dto.ChatChoiceMessage{ + Role: "assistant", + } + + for _, out := range result.Output { + if out.Content != nil && *out.Content != "" { + message.Content = out.Content + } + if out.ReasoningContent != nil && *out.ReasoningContent != "" { + message.ReasoningContent = out.ReasoningContent + } + for _, tc := range out.ToolCalls { + message.ToolCalls = append(message.ToolCalls, dto.ChatToolCall{ + ID: tc.ID, + Type: "function", + Function: dto.ChatFunctionCall{ + Name: tc.Name, + Arguments: tc.ArgumentsJSON, + }, + }) + } + } + + choice := dto.ChatCompletionChoice{ + Index: 0, + Message: message, + FinishReason: result.FinishReason, + } + + resp.Choices = []dto.ChatCompletionChoice{choice} + + if result.Usage != nil { + resp.Usage = &dto.ChatCompletionUsage{ + PromptTokens: result.Usage.PromptTokens, + CompletionTokens: result.Usage.CompletionTokens, + TotalTokens: result.Usage.TotalTokens, + } + } + + return resp +} + +// DomainStreamEventToChatChunk maps a canonical StreamEvent to a chat completions streaming chunk. 
+func DomainStreamEventToChatChunk(event domain.StreamEvent, model string, responseID string, createdAt int64) (data *dto.ChatCompletionChunk, done bool) { + chunk := &dto.ChatCompletionChunk{ + ID: chatCompletionID(responseID), + Object: "chat.completion.chunk", + Created: createdAt, + Model: model, + } + + switch event.Type { + case domain.StreamEventOutputTextDelta: + choice := dto.ChatCompletionChunkChoice{ + Index: 0, + Delta: dto.ChatChunkDelta{ + Content: event.ContentDelta, + ReasoningContent: event.ReasoningDelta, + }, + } + if event.Role != nil { + choice.Delta.Role = *event.Role + } + chunk.Choices = []dto.ChatCompletionChunkChoice{choice} + return chunk, false + + case domain.StreamEventOutputMessageDelta: + choice := dto.ChatCompletionChunkChoice{ + Index: 0, + Delta: dto.ChatChunkDelta{}, + } + if event.Role != nil { + choice.Delta.Role = *event.Role + } + chunk.Choices = []dto.ChatCompletionChunkChoice{choice} + return chunk, false + + case domain.StreamEventToolCallDelta: + choice := dto.ChatCompletionChunkChoice{ + Index: 0, + Delta: dto.ChatChunkDelta{}, + } + if event.ToolCallDelta != nil { + tc := dto.ChatToolCallChunk{ + Index: event.ToolCallDelta.Index, + } + if event.ToolCallDelta.ID != nil { + tc.ID = *event.ToolCallDelta.ID + tc.Type = "function" + } + fn := dto.ChatFunctionCall{} + if event.ToolCallDelta.Name != nil { + fn.Name = *event.ToolCallDelta.Name + } + if event.ToolCallDelta.ArgumentsDelta != nil { + fn.Arguments = *event.ToolCallDelta.ArgumentsDelta + } + tc.Function = fn + choice.Delta.ToolCalls = []dto.ChatToolCallChunk{tc} + } + chunk.Choices = []dto.ChatCompletionChunkChoice{choice} + return chunk, false + + case domain.StreamEventCompleted: + delta := dto.ChatChunkDelta{} + if event.ContentDelta != nil { + delta.Content = event.ContentDelta + } + if event.ReasoningDelta != nil { + delta.ReasoningContent = event.ReasoningDelta + } + choice := dto.ChatCompletionChunkChoice{ + Index: 0, + Delta: delta, + FinishReason: 
event.FinishReason, + } + chunk.Choices = []dto.ChatCompletionChunkChoice{choice} + if event.Usage != nil { + chunk.Usage = &dto.ChatCompletionUsage{ + PromptTokens: event.Usage.PromptTokens, + CompletionTokens: event.Usage.CompletionTokens, + TotalTokens: event.Usage.TotalTokens, + } + } + return chunk, true + + case domain.StreamEventError: + return nil, true + + default: + return nil, false + } +} + +func chatCompletionID(id string) string { + if strings.HasPrefix(id, "chatcmpl-") { + return id + } + return "chatcmpl-" + id +} diff --git a/apps/gateway/internal/adapters/http/middleware/cors.go b/apps/gateway/internal/adapters/http/middleware/cors.go new file mode 100644 index 0000000..6c244aa --- /dev/null +++ b/apps/gateway/internal/adapters/http/middleware/cors.go @@ -0,0 +1,29 @@ +package middleware + +import "net/http" + +// CORS adds permissive CORS headers so browser SPAs can call the gateway +// directly — both the public landing page hitting the unauthenticated +// GET /v1/models endpoint and the user-ui chat playground hitting +// /v1/chat/completions with a Bearer API key. Origin is reflected only when +// present, with `Vary: Origin` to keep intermediate caches correct. 
+func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + } + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Request-ID") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Max-Age", "86400") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/apps/gateway/internal/adapters/http/middleware/logging.go b/apps/gateway/internal/adapters/http/middleware/logging.go new file mode 100644 index 0000000..fa382a3 --- /dev/null +++ b/apps/gateway/internal/adapters/http/middleware/logging.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "net/http" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" +) + +// Logging logs request details. +func Logging(logger ports.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sw := &statusWriter{ResponseWriter: w, status: 200} + next.ServeHTTP(sw, r) + + durationMs := time.Since(start).Milliseconds() + fields := []any{ + "method", r.Method, + "path", r.URL.Path, + "status", sw.status, + "duration_ms", durationMs, + "request_id", GetRequestID(r.Context()), + } + if ua := r.Header.Get("User-Agent"); ua != "" { + fields = append(fields, "user_agent", ua) + } + + if sw.status >= 500 { + logger.Error("request", fields...) + } else if sw.status >= 400 { + logger.Warn("request", fields...) + } else { + logger.Info("request", fields...) 
+ } + }) + } +} + +type statusWriter struct { + http.ResponseWriter + status int + written bool +} + +func (w *statusWriter) WriteHeader(code int) { + if !w.written { + w.status = code + w.written = true + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} diff --git a/apps/gateway/internal/adapters/http/middleware/recovery.go b/apps/gateway/internal/adapters/http/middleware/recovery.go new file mode 100644 index 0000000..8c0725d --- /dev/null +++ b/apps/gateway/internal/adapters/http/middleware/recovery.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "encoding/json" + "net/http" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/dto" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" +) + +// Recovery catches panics and returns a 500 error. +func Recovery(logger ports.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + logger.Error("panic recovered", + "error", rec, + "request_id", GetRequestID(r.Context()), + "path", r.URL.Path, + ) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(dto.ErrorResponse{ + Error: dto.ErrorDetail{ + Type: "internal_error", + Code: "internal_error", + Message: "an internal error occurred", + }, + }) + } + }() + next.ServeHTTP(w, r) + }) + } +} diff --git a/apps/gateway/internal/adapters/http/middleware/request_id.go b/apps/gateway/internal/adapters/http/middleware/request_id.go new file mode 100644 index 0000000..d1b6167 --- /dev/null +++ b/apps/gateway/internal/adapters/http/middleware/request_id.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/google/uuid" +) + +type contextKey 
string + +const RequestIDKey contextKey = "request_id" + +// RequestID adds a unique request ID to each request. +func RequestID(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := r.Header.Get("X-Request-ID") + if id == "" { + id = uuid.New().String() + } + w.Header().Set("X-Request-ID", id) + ctx := context.WithValue(r.Context(), RequestIDKey, id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetRequestID retrieves the request ID from context. +func GetRequestID(ctx context.Context) string { + if id, ok := ctx.Value(RequestIDKey).(string); ok { + return id + } + return "" +} diff --git a/apps/gateway/internal/adapters/http/router.go b/apps/gateway/internal/adapters/http/router.go new file mode 100644 index 0000000..c2ac5aa --- /dev/null +++ b/apps/gateway/internal/adapters/http/router.go @@ -0,0 +1,35 @@ +package http + +import ( + "net/http" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/handlers" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/middleware" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" +) + +// NewRouter builds the HTTP mux with all routes and middleware. 
+func NewRouter( + health *handlers.HealthHandler, + models *handlers.ModelsHandler, + chatCompletions *handlers.ChatCompletionsHandler, + tinfoil *handlers.TinfoilHandler, + logger ports.Logger, +) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("GET /healthz", health.Handle) + mux.HandleFunc("GET /v1/models", models.Handle) + mux.HandleFunc("POST /v1/chat/completions", chatCompletions.Handle) + if tinfoil != nil { + mux.HandleFunc("GET /v1/tinfoil/proofs/{response_id}", tinfoil.HandleGetProof) + } + + var handler http.Handler = mux + handler = middleware.CORS(handler) + handler = middleware.Logging(logger)(handler) + handler = middleware.Recovery(logger)(handler) + handler = middleware.RequestID(handler) + + return handler +} diff --git a/apps/gateway/internal/adapters/http/sse/writer.go b/apps/gateway/internal/adapters/http/sse/writer.go new file mode 100644 index 0000000..4cf106a --- /dev/null +++ b/apps/gateway/internal/adapters/http/sse/writer.go @@ -0,0 +1,81 @@ +package sse + +import ( + "encoding/json" + "fmt" + "net/http" +) + +// Writer writes Server-Sent Events to an HTTP response. +type Writer struct { + w http.ResponseWriter + flusher http.Flusher +} + +// NewWriter creates a new SSE writer and sets appropriate headers. +func NewWriter(w http.ResponseWriter) (*Writer, error) { + flusher, ok := w.(http.Flusher) + if !ok { + return nil, fmt.Errorf("streaming not supported") + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(http.StatusOK) + flusher.Flush() + + return &Writer{w: w, flusher: flusher}, nil +} + +// WriteEvent writes a named SSE event with JSON data. 
+func (sw *Writer) WriteEvent(eventType string, data any) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal SSE data: %w", err) + } + + if eventType != "" { + if _, err := fmt.Fprintf(sw.w, "event: %s\n", eventType); err != nil { + return err + } + } + if _, err := fmt.Fprintf(sw.w, "data: %s\n\n", jsonData); err != nil { + return err + } + sw.flusher.Flush() + return nil +} + +// WriteData writes a data-only SSE event with JSON content. +func (sw *Writer) WriteData(data any) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal SSE data: %w", err) + } + + if _, err := fmt.Fprintf(sw.w, "data: %s\n\n", jsonData); err != nil { + return err + } + sw.flusher.Flush() + return nil +} + +// WriteDone writes the [DONE] terminal marker. +func (sw *Writer) WriteDone() error { + if _, err := fmt.Fprint(sw.w, "data: [DONE]\n\n"); err != nil { + return err + } + sw.flusher.Flush() + return nil +} + +// WriteComment writes an SSE comment (ignored by clients). Useful as a keep-alive. 
+func (sw *Writer) WriteComment(text string) error { + if _, err := fmt.Fprintf(sw.w, ": %s\n\n", text); err != nil { + return err + } + sw.flusher.Flush() + return nil +} diff --git a/apps/gateway/internal/adapters/observability/metrics/prometheus.go b/apps/gateway/internal/adapters/observability/metrics/prometheus.go new file mode 100644 index 0000000..99400c2 --- /dev/null +++ b/apps/gateway/internal/adapters/observability/metrics/prometheus.go @@ -0,0 +1,53 @@ +package metrics + +import ( + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var ( + RequestsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gateway_requests_total", + Help: "Total number of requests by endpoint and status", + }, + []string{"endpoint", "status", "model"}, + ) + + RequestDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "gateway_request_duration_seconds", + Help: "Request duration in seconds", + Buckets: prometheus.DefBuckets, + }, + []string{"endpoint", "model"}, + ) + + UpstreamLatency = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "gateway_upstream_latency_seconds", + Help: "Upstream provider latency in seconds", + Buckets: prometheus.DefBuckets, + }, + []string{"provider", "model"}, + ) + + TokensTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gateway_tokens_total", + Help: "Total tokens by direction", + }, + []string{"direction", "model", "provider"}, + ) +) + +func init() { + prometheus.MustRegister(RequestsTotal, RequestDuration, UpstreamLatency, TokensTotal) +} + +// Handler returns the Prometheus metrics HTTP handler. 
+func Handler() http.Handler { + return promhttp.Handler() +} diff --git a/apps/gateway/internal/adapters/pii/presidio/adapter.go b/apps/gateway/internal/adapters/pii/presidio/adapter.go new file mode 100644 index 0000000..9b4ab7d --- /dev/null +++ b/apps/gateway/internal/adapters/pii/presidio/adapter.go @@ -0,0 +1,242 @@ +// Package presidio adapts Microsoft Presidio's analyzer HTTP API to the +// ports.PIIFilter interface. Only the analyzer is used — masking is performed +// in-process by domain.ApplyMask so that PII values never leave the gateway. +package presidio + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "time" + "unicode/utf8" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// Config controls adapter behavior. +type Config struct { + // BaseURL is the analyzer root, e.g. "http://presidio-analyzer:3000". + BaseURL string + // DefaultLanguage is used when callers pass an empty language code. + DefaultLanguage string + // ScoreThreshold filters out detections below this confidence (0..1). + ScoreThreshold float64 + // Timeout caps each Analyze call. + Timeout time.Duration + // Logger receives debug / warn events. Required. + Logger ports.Logger +} + +// Adapter implements ports.PIIFilter against Presidio's `/analyze` endpoint. +type Adapter struct { + cfg Config + client *http.Client +} + +// NewAdapter constructs an Adapter. Sensible defaults are filled in for any +// zero-valued config fields. +func NewAdapter(cfg Config) *Adapter { + if cfg.DefaultLanguage == "" { + cfg.DefaultLanguage = "en" + } + if cfg.Timeout <= 0 { + cfg.Timeout = 1500 * time.Millisecond + } + if cfg.ScoreThreshold < 0 { + cfg.ScoreThreshold = 0 + } + return &Adapter{ + cfg: cfg, + client: &http.Client{Timeout: cfg.Timeout}, + } +} + +// Enabled always returns true for the real adapter. 
The Noop adapter (see +// noop.go) is used to bypass detection. +func (a *Adapter) Enabled() bool { return true } + +// analyzeRequest mirrors Presidio's POST /analyze body. +type analyzeRequest struct { + Text string `json:"text"` + Language string `json:"language"` + ScoreThreshold float64 `json:"score_threshold,omitempty"` + Entities []string `json:"entities,omitempty"` +} + +// analyzeResponseItem mirrors a single result from POST /analyze. Offsets are +// character offsets (Python `len()` on a unicode string). +type analyzeResponseItem struct { + EntityType string `json:"entity_type"` + Start int `json:"start"` + End int `json:"end"` + Score float64 `json:"score"` +} + +// Analyze sends `text` to Presidio and converts character offsets to byte +// offsets before returning. +func (a *Adapter) Analyze(ctx context.Context, text string, opts ports.PIIAnalyzeOptions) ([]domain.PIIEntity, error) { + if text == "" { + return nil, nil + } + language := opts.Language + if language == "" { + language = a.cfg.DefaultLanguage + } + + body, err := json.Marshal(analyzeRequest{ + Text: text, + Language: language, + ScoreThreshold: a.cfg.ScoreThreshold, + Entities: presidioEntitiesForMode(opts.Mode), + }) + if err != nil { + return nil, fmt.Errorf("presidio: marshal request: %w", err) + } + + url := a.cfg.BaseURL + "/analyze" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("presidio: build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("presidio: call analyzer: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + // Drain a small prefix for diagnostics without leaking large bodies. 
+ preview, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("presidio: analyzer http %d: %s", resp.StatusCode, bytes.TrimSpace(preview)) + } + + var items []analyzeResponseItem + if err := json.NewDecoder(resp.Body).Decode(&items); err != nil { + return nil, fmt.Errorf("presidio: decode response: %w", err) + } + + return convertEntities(text, items), nil +} + +func presidioEntitiesForMode(mode string) []string { + switch mode { + case domain.APIKeyPIIModeLow: + return cloneEntities(lowProfileEntities) + case domain.APIKeyPIIModeBalanced: + entities := cloneEntities(lowProfileEntities) + entities = append(entities, "PERSON") + return entities + case domain.APIKeyPIIModeHigh, "": + return nil + default: + return cloneEntities(lowProfileEntities) + } +} + +func cloneEntities(entities []string) []string { + return append([]string(nil), entities...) +} + +// lowProfileEntities is the Presidio-specific allowlist behind Nexus' low PII +// masking level. It is intentionally explicit: if Presidio adds a new broad +// semantic recognizer later, low mode should not start masking it by accident. +// Keep this profile limited to stable identifiers whose exact value +// is usually less important to prompt meaning than names, places, and dates. +// Every entry must also be supported by the pinned official Presidio analyzer +// image for English; unsupported names make Presidio log a warning per request. 
+var lowProfileEntities = flattenEntityGroups( + contactAndNetworkEntities, + paymentAndAccountEntities, + genericCredentialEntities, + usIdentifierEntities, + ukIdentifierEntities, +) + +var contactAndNetworkEntities = []string{ + "EMAIL_ADDRESS", + "IP_ADDRESS", + "MAC_ADDRESS", + "PHONE_NUMBER", +} + +var paymentAndAccountEntities = []string{ + "CREDIT_CARD", + "CRYPTO", + "IBAN_CODE", +} + +var genericCredentialEntities = []string{ + "MEDICAL_LICENSE", +} + +var usIdentifierEntities = []string{ + "US_BANK_NUMBER", + "US_DRIVER_LICENSE", + "US_ITIN", + "US_PASSPORT", + "US_SSN", +} + +var ukIdentifierEntities = []string{ + "UK_NHS", +} + +func flattenEntityGroups(groups ...[]string) []string { + seen := make(map[string]struct{}) + entities := make([]string, 0) + for _, group := range groups { + for _, entity := range group { + if _, ok := seen[entity]; ok { + continue + } + seen[entity] = struct{}{} + entities = append(entities, entity) + } + } + sort.Strings(entities) + return entities +} + +// convertEntities turns Presidio's character offsets into byte offsets and +// drops any spans that fall outside `text`. +func convertEntities(text string, items []analyzeResponseItem) []domain.PIIEntity { + if len(items) == 0 { + return nil + } + + // Build a runeIndex -> byteIndex lookup table. The final entry maps the + // past-the-end rune index to len(text), so we can convert exclusive ends. + runes := make([]int, 0, len(text)+1) + for i := range text { // i is the byte index at each rune start + runes = append(runes, i) + } + runes = append(runes, len(text)) + + // totalRunes lets us bounds-check character offsets cheaply. 
+ totalRunes := utf8.RuneCountInString(text) + + out := make([]domain.PIIEntity, 0, len(items)) + for _, it := range items { + if it.Start < 0 || it.End < it.Start || it.End > totalRunes { + continue + } + out = append(out, domain.PIIEntity{ + Type: it.EntityType, + Start: runes[it.Start], + End: runes[it.End], + Score: float32(it.Score), + }) + } + + // Sort ascending by start to give callers a stable order. + sort.SliceStable(out, func(i, j int) bool { return out[i].Start < out[j].Start }) + return out +} diff --git a/apps/gateway/internal/adapters/pii/presidio/adapter_test.go b/apps/gateway/internal/adapters/pii/presidio/adapter_test.go new file mode 100644 index 0000000..86d4f52 --- /dev/null +++ b/apps/gateway/internal/adapters/pii/presidio/adapter_test.go @@ -0,0 +1,236 @@ +package presidio + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/dappnode/dappnode-nexus-gateway/pkg/observability/logger" +) + +func newTestAdapter(t *testing.T, srv *httptest.Server) *Adapter { + t.Helper() + zap, err := logger.NewZapLogger("error") + if err != nil { + t.Fatalf("logger: %v", err) + } + return NewAdapter(Config{ + BaseURL: srv.URL, + ScoreThreshold: 0.4, + Timeout: 2 * time.Second, + Logger: zap, + }) +} + +func TestAdapter_Analyze_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/analyze" { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + } + var body analyzeRequest + _ = json.NewDecoder(r.Body).Decode(&body) + if body.Language != "en" || body.ScoreThreshold != 0.4 { + t.Errorf("unexpected body: %+v", body) + } + _ = json.NewEncoder(w).Encode([]analyzeResponseItem{ + {EntityType: "PERSON", Start: 11, End: 21, Score: 0.99}, // "John 
Smith" + {EntityType: "EMAIL_ADDRESS", Start: 32, End: 48, Score: 0.99}, // "john@example.com" + }) + })) + defer srv.Close() + + a := newTestAdapter(t, srv) + text := "My name is John Smith and email john@example.com" + got, err := a.Analyze(context.Background(), text, ports.PIIAnalyzeOptions{}) + if err != nil { + t.Fatalf("Analyze: %v", err) + } + if len(got) != 2 { + t.Fatalf("entities = %d, want 2", len(got)) + } + if text[got[0].Start:got[0].End] != "John Smith" { + t.Errorf("entity 0 = %q, want John Smith", text[got[0].Start:got[0].End]) + } + if text[got[1].Start:got[1].End] != "john@example.com" { + t.Errorf("entity 1 = %q, want john@example.com", text[got[1].Start:got[1].End]) + } +} + +func TestAdapter_Analyze_HandlesNonASCIIOffsets(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Presidio returns character offsets. For "héllo John" (10 runes, + // 11 bytes), Start=6 End=10 selects "John" by rune index. + _ = json.NewEncoder(w).Encode([]analyzeResponseItem{ + {EntityType: "PERSON", Start: 6, End: 10, Score: 0.9}, + }) + })) + defer srv.Close() + + a := newTestAdapter(t, srv) + text := "héllo John" + got, err := a.Analyze(context.Background(), text, ports.PIIAnalyzeOptions{Language: "en"}) + if err != nil { + t.Fatalf("Analyze: %v", err) + } + if len(got) != 1 { + t.Fatalf("entities = %d", len(got)) + } + if v := text[got[0].Start:got[0].End]; v != "John" { + t.Fatalf("byte-converted span = %q, want John", v) + } +} + +func TestAdapter_Analyze_HTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + + a := newTestAdapter(t, srv) + if _, err := a.Analyze(context.Background(), "x", ports.PIIAnalyzeOptions{Language: "en"}); err == nil { + t.Fatal("expected error for 500 response") + } +} + +func TestAdapter_Analyze_EmptyTextShortCircuits(t *testing.T) { 
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("server should not be called for empty text") + w.WriteHeader(500) + })) + defer srv.Close() + a := newTestAdapter(t, srv) + got, err := a.Analyze(context.Background(), "", ports.PIIAnalyzeOptions{Language: "en"}) + if err != nil || got != nil { + t.Fatalf("Analyze(\"\") = %v, %v", got, err) + } +} + +func TestAdapter_Analyze_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + time.Sleep(150 * time.Millisecond) + _, _ = w.Write([]byte("[]")) + })) + defer srv.Close() + zap, _ := logger.NewZapLogger("error") + a := NewAdapter(Config{BaseURL: srv.URL, Timeout: 20 * time.Millisecond, Logger: zap}) + if _, err := a.Analyze(context.Background(), "hello", ports.PIIAnalyzeOptions{Language: "en"}); err == nil { + t.Fatal("expected timeout error") + } +} + +func TestAdapter_Analyze_BadJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("{not json")) + })) + defer srv.Close() + a := newTestAdapter(t, srv) + if _, err := a.Analyze(context.Background(), "hi", ports.PIIAnalyzeOptions{Language: "en"}); err == nil { + t.Fatal("expected decode error") + } +} + +func TestAdapter_Analyze_SendsEntitiesForPrivacyMode(t *testing.T) { + tests := []struct { + name string + mode string + want []string + wantAbsent []string + wantNil bool + }{ + { + name: "low", + mode: domain.APIKeyPIIModeLow, + want: []string{"EMAIL_ADDRESS", "PHONE_NUMBER", "CREDIT_CARD", "US_SSN"}, + wantAbsent: []string{"PERSON", "LOCATION", "DATE_TIME", "URL"}, + }, + { + name: "balanced", + mode: domain.APIKeyPIIModeBalanced, + want: []string{"EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON"}, + wantAbsent: []string{"LOCATION", "DATE_TIME", "URL"}, + }, + { + name: "high", + mode: domain.APIKeyPIIModeHigh, + wantNil: true, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + var gotEntities []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body analyzeRequest + _ = json.NewDecoder(r.Body).Decode(&body) + gotEntities = body.Entities + _ = json.NewEncoder(w).Encode([]analyzeResponseItem{}) + })) + defer srv.Close() + + a := newTestAdapter(t, srv) + if _, err := a.Analyze(context.Background(), "hello", ports.PIIAnalyzeOptions{Mode: tt.mode}); err != nil { + t.Fatalf("Analyze: %v", err) + } + if tt.wantNil { + if gotEntities != nil { + t.Fatalf("entities = %#v, want nil/omitted for high", gotEntities) + } + return + } + for _, want := range tt.want { + if !contains(gotEntities, want) { + t.Fatalf("entities missing %s: %#v", want, gotEntities) + } + } + for _, absent := range tt.wantAbsent { + if contains(gotEntities, absent) { + t.Fatalf("entities unexpectedly include %s: %#v", absent, gotEntities) + } + } + }) + } +} + +func TestLowProfileEntities_AreStableIdentifiersOnly(t *testing.T) { + if len(lowProfileEntities) == 0 { + t.Fatal("low profile must include stable identifier entities") + } + + for i, entity := range lowProfileEntities { + if i > 0 && lowProfileEntities[i-1] >= entity { + t.Fatalf("low profile must be sorted and unique, got %q before %q", lowProfileEntities[i-1], entity) + } + } + + semanticEntities := []string{"PERSON", "LOCATION", "DATE_TIME", "NRP", "ADDRESS", "AGE", "URL"} + for _, entity := range semanticEntities { + if contains(lowProfileEntities, entity) { + t.Fatalf("low profile must not include semantic entity %q", entity) + } + } +} + +func TestNoopFilter(t *testing.T) { + n := NewNoopFilter() + if n.Enabled() { + t.Fatal("noop must report Enabled()=false") + } + got, err := n.Analyze(context.Background(), "anything", ports.PIIAnalyzeOptions{Language: "en"}) + if err != nil || got != nil { + t.Fatalf("noop.Analyze = %v, %v", got, err) + } +} + +func contains(items []string, want string) bool { + for _, item := range 
items { + if item == want { + return true + } + } + return false +} diff --git a/apps/gateway/internal/adapters/pii/presidio/noop.go b/apps/gateway/internal/adapters/pii/presidio/noop.go new file mode 100644 index 0000000..a0706ad --- /dev/null +++ b/apps/gateway/internal/adapters/pii/presidio/noop.go @@ -0,0 +1,24 @@ +package presidio + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// NoopFilter is a PIIFilter that never detects anything. It is used when the +// PII filter is disabled by configuration so that callers can keep the same +// code path without nil checks. +type NoopFilter struct{} + +// NewNoopFilter returns a no-op filter. +func NewNoopFilter() *NoopFilter { return &NoopFilter{} } + +// Enabled always returns false. +func (NoopFilter) Enabled() bool { return false } + +// Analyze always returns no entities and no error. +func (NoopFilter) Analyze(context.Context, string, ports.PIIAnalyzeOptions) ([]domain.PIIEntity, error) { + return nil, nil +} diff --git a/apps/gateway/internal/adapters/providers/anthropic/adapter.go b/apps/gateway/internal/adapters/providers/anthropic/adapter.go new file mode 100644 index 0000000..374e7f8 --- /dev/null +++ b/apps/gateway/internal/adapters/providers/anthropic/adapter.go @@ -0,0 +1,177 @@ +package anthropic + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/google/uuid" +) + +// Adapter is the Anthropic provider adapter. 
+type Adapter struct { + client *Client +} + +func NewAdapter(timeout time.Duration) *Adapter { + return &Adapter{client: NewClient(timeout)} +} + +func (a *Adapter) Generate(ctx context.Context, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + apiKey := resolveAPIKey(model.ProviderConfig.APIKeySecretRef) + if apiKey == "" { + return domain.GenerateResult{}, domain.ErrProviderUnavailable("anthropic: API key not configured") + } + + body := buildRequestBody(req, model) + delete(body, "stream") + + respBody, err := a.client.Do(ctx, model.ProviderConfig.BaseURL, apiKey, body) + if err != nil { + return domain.GenerateResult{}, mapProviderError(err) + } + + return parseResponse(respBody, req, model) +} + +func (a *Adapter) StreamGenerate(ctx context.Context, req domain.GenerateRequest, model domain.PublicModel) (ports.GenerationStream, error) { + apiKey := resolveAPIKey(model.ProviderConfig.APIKeySecretRef) + if apiKey == "" { + return nil, domain.ErrProviderUnavailable("anthropic: API key not configured") + } + + body := buildRequestBody(req, model) + + resp, err := a.client.DoStream(ctx, model.ProviderConfig.BaseURL, apiKey, body) + if err != nil { + return nil, mapProviderError(err) + } + + return NewStream(resp), nil +} + +func resolveAPIKey(secretRef string) string { + return os.Getenv(secretRef) +} + +func parseResponse(data json.RawMessage, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + var resp struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + } `json:"content"` + StopReason string `json:"stop_reason"` + Usage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationInputTokens int64 
`json:"cache_creation_input_tokens"` + CacheReadInputTokens int64 `json:"cache_read_input_tokens"` + } `json:"usage"` + } + + if err := json.Unmarshal(data, &resp); err != nil { + return domain.GenerateResult{}, fmt.Errorf("failed to parse provider response: %w", err) + } + + result := domain.GenerateResult{ + ID: resp.ID, + CreatedUnix: time.Now().Unix(), + PublicModelID: req.PublicModelID, + ProviderName: model.ProviderConfig.ProviderName, + ProviderModelID: model.ProviderModelID, + } + + if result.ID == "" { + result.ID = uuid.New().String() + } + + role := resp.Role + if role == "" { + role = "assistant" + } + + out := domain.OutputItem{ + Type: domain.OutputItemTypeMessage, + Role: &role, + } + + for _, block := range resp.Content { + switch block.Type { + case "text": + text := block.Text + out.Content = &text + case "tool_use": + argsJSON, _ := json.Marshal(block.Input) + out.ToolCalls = append(out.ToolCalls, domain.ToolCall{ + ID: block.ID, + Name: block.Name, + ArgumentsJSON: string(argsJSON), + }) + } + } + + result.Output = []domain.OutputItem{out} + + finishReason := mapStopReason(resp.StopReason) + result.FinishReason = &finishReason + + // Anthropic input_tokens = only uncached input. Normalize PromptTokens to total input. 
+ totalInput := resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens + result.Usage = &domain.Usage{ + PromptTokens: totalInput, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: totalInput + resp.Usage.OutputTokens, + CacheCreationTokens: resp.Usage.CacheCreationInputTokens, + CacheReadTokens: resp.Usage.CacheReadInputTokens, + } + + return result, nil +} + +func mapProviderError(err error) error { + if err == nil { + return nil + } + + var httpErr *ProviderHTTPError + if errors.As(err, &httpErr) { + var gwErr *domain.GatewayError + switch { + case httpErr.StatusCode == 401 || httpErr.StatusCode == 403: + gwErr = domain.ErrProviderError(502, fmt.Sprintf("provider auth/permission error: %s", httpErr.Body)) + case httpErr.StatusCode == 429: + gwErr = domain.ErrProviderError(429, "provider rate limited: "+httpErr.Body) + case httpErr.StatusCode == 503: + gwErr = domain.ErrProviderUnavailable("anthropic") + case httpErr.StatusCode >= 500: + gwErr = domain.ErrProviderError(502, fmt.Sprintf("provider server error: %s", httpErr.Body)) + default: + gwErr = domain.ErrProviderError(502, fmt.Sprintf("provider rejected request: %s", httpErr.Body)) + } + return gwErr.WithMeta( + "upstream_status", httpErr.StatusCode, + "upstream_error", httpErr.Body, + ) + } + + msg := err.Error() + if strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline") { + return domain.ErrProviderTimeout("anthropic").WithMeta("upstream_error", msg) + } + if strings.Contains(msg, "connection refused") || strings.Contains(msg, "no such host") { + return domain.ErrProviderUnavailable("anthropic").WithMeta("upstream_error", msg) + } + return domain.ErrProviderError(502, msg).WithMeta("upstream_error", msg) +} diff --git a/apps/gateway/internal/adapters/providers/anthropic/client.go b/apps/gateway/internal/adapters/providers/anthropic/client.go new file mode 100644 index 0000000..a71a36a --- /dev/null +++ 
b/apps/gateway/internal/adapters/providers/anthropic/client.go @@ -0,0 +1,135 @@ +package anthropic + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "time" +) + +// Client handles HTTP communication with the Anthropic API. +type Client struct { + httpClient *http.Client + responseTimeout time.Duration +} + +func NewClient(timeout time.Duration) *Client { + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: timeout, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + return &Client{ + httpClient: &http.Client{Transport: transport}, + responseTimeout: timeout, + } +} + +// jsonUnmarshalBytes is used by mapper.go +func jsonUnmarshalBytes(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (c *Client) Do(ctx context.Context, baseURL, apiKey string, body map[string]any) (json.RawMessage, error) { + ctx, cancel := context.WithTimeout(ctx, c.responseTimeout) + defer cancel() + + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint := baseURL + "/v1/messages" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("provider request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + if 
err != nil { + return nil, fmt.Errorf("failed to read provider response: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, parseProviderError(resp.StatusCode, respBody) + } + + return respBody, nil +} + +func (c *Client) DoStream(ctx context.Context, baseURL, apiKey string, body map[string]any) (*http.Response, error) { + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint := baseURL + "/v1/messages" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("provider stream request failed: %w", err) + } + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + resp.Body.Close() + return nil, parseProviderError(resp.StatusCode, body) + } + + return resp, nil +} + +// ProviderHTTPError carries the upstream HTTP status code. 
+type ProviderHTTPError struct { + StatusCode int + Body string +} + +func (e *ProviderHTTPError) Error() string { + return fmt.Sprintf("provider error (%d): %s", e.StatusCode, e.Body) +} + +func parseProviderError(statusCode int, body []byte) error { + var errResp struct { + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` + } + msg := string(body) + if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" { + msg = errResp.Error.Message + } + + return &ProviderHTTPError{StatusCode: statusCode, Body: msg} +} diff --git a/apps/gateway/internal/adapters/providers/anthropic/mapper.go b/apps/gateway/internal/adapters/providers/anthropic/mapper.go new file mode 100644 index 0000000..9a77615 --- /dev/null +++ b/apps/gateway/internal/adapters/providers/anthropic/mapper.go @@ -0,0 +1,152 @@ +package anthropic + +import ( + "encoding/json" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func buildRequestBody(req domain.GenerateRequest, model domain.PublicModel) map[string]any { + body := map[string]any{ + "model": model.UpstreamModelName, + } + + messages, system := buildMessages(req) + body["messages"] = messages + if system != "" { + body["system"] = system + } + + if req.MaxOutputTokens != nil { + v := *req.MaxOutputTokens + if model.MaxOutputTokens > 0 && v > model.MaxOutputTokens { + v = model.MaxOutputTokens + } + body["max_tokens"] = v + } else { + body["max_tokens"] = model.MaxOutputTokens + } + + if req.Temperature != nil { + body["temperature"] = *req.Temperature + } + if req.TopP != nil { + body["top_p"] = *req.TopP + } + if len(req.Stop) > 0 { + body["stop_sequences"] = req.Stop + } + if req.Stream { + body["stream"] = true + } + + if len(req.Tools) > 0 { + tools := make([]map[string]any, 0, len(req.Tools)) + for _, t := range req.Tools { + tool := map[string]any{ + "name": t.Name, + "description": t.Description, + "input_schema": t.Parameters, + } + tools = append(tools, tool) + } 
+ body["tools"] = tools + } + + if req.ToolChoice != nil { + switch req.ToolChoice.Mode { + case domain.ToolChoiceAuto: + body["tool_choice"] = map[string]any{"type": "auto"} + case domain.ToolChoiceRequired: + body["tool_choice"] = map[string]any{"type": "any"} + case domain.ToolChoiceFunction: + body["tool_choice"] = map[string]any{ + "type": "tool", + "name": *req.ToolChoice.FunctionName, + } + case domain.ToolChoiceNone: + delete(body, "tools") + } + } + + return body +} + +func buildMessages(req domain.GenerateRequest) ([]map[string]any, string) { + var messages []map[string]any + var system string + + if req.Instructions != nil && *req.Instructions != "" { + system = *req.Instructions + } + + for _, item := range req.Input { + role := "user" + if item.Role != nil { + role = *item.Role + } + + if role == "system" || role == "developer" { + if item.Content != nil { + if system != "" { + system += "\n" + } + system += *item.Content + } + continue + } + + msg := map[string]any{ + "role": role, + } + + if role == "tool" && item.ToolCallID != nil { + msg["role"] = "user" + msg["content"] = []map[string]any{ + { + "type": "tool_result", + "tool_use_id": *item.ToolCallID, + "content": safeContent(item.Content), + }, + } + } else if len(item.ToolCalls) > 0 { + content := make([]map[string]any, 0) + if item.Content != nil && *item.Content != "" { + content = append(content, map[string]any{ + "type": "text", + "text": *item.Content, + }) + } + for _, tc := range item.ToolCalls { + content = append(content, map[string]any{ + "type": "tool_use", + "id": tc.ID, + "name": tc.Name, + "input": parseJSONOrString(tc.ArgumentsJSON), + }) + } + msg["content"] = content + } else if item.Content != nil { + msg["content"] = *item.Content + } + + messages = append(messages, msg) + } + + return messages, system +} + +func safeContent(s *string) string { + if s == nil { + return "" + } + return *s +} + +func parseJSONOrString(s string) any { + var v any + if err := 
json.Unmarshal([]byte(s), &v); err != nil { + return s + } + return v +} diff --git a/apps/gateway/internal/adapters/providers/anthropic/stream.go b/apps/gateway/internal/adapters/providers/anthropic/stream.go new file mode 100644 index 0000000..30e6044 --- /dev/null +++ b/apps/gateway/internal/adapters/providers/anthropic/stream.go @@ -0,0 +1,218 @@ +package anthropic + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// Stream reads SSE events from an Anthropic streaming response. +type Stream struct { + resp *http.Response + scanner *bufio.Scanner + done bool + usage *domain.Usage + toolIdx int +} + +func NewStream(resp *http.Response) *Stream { + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + return &Stream{resp: resp, scanner: scanner} +} + +func (s *Stream) Recv() (domain.StreamEvent, error) { + if s.done { + return domain.StreamEvent{}, io.EOF + } + + var currentEvent string + + for s.scanner.Scan() { + line := s.scanner.Text() + + if line == "" { + continue + } + + if strings.HasPrefix(line, "event: ") { + currentEvent = strings.TrimPrefix(line, "event: ") + continue + } + + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + event := s.processEvent(currentEvent, []byte(data)) + if event != nil { + return *event, nil + } + } + + if err := s.scanner.Err(); err != nil { + return domain.StreamEvent{}, err + } + + s.done = true + return domain.StreamEvent{}, io.EOF +} + +func (s *Stream) Close() error { + s.done = true + return s.resp.Body.Close() +} + +func (s *Stream) processEvent(eventType string, data []byte) *domain.StreamEvent { + switch eventType { + case "content_block_delta": + var delta struct { + Index int `json:"index"` + Delta struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + PartialJSON string `json:"partial_json,omitempty"` + } 
`json:"delta"` + } + if json.Unmarshal(data, &delta) != nil { + return nil + } + + if delta.Delta.Type == "text_delta" { + return &domain.StreamEvent{ + Type: domain.StreamEventOutputTextDelta, + ContentDelta: &delta.Delta.Text, + } + } + if delta.Delta.Type == "input_json_delta" { + return &domain.StreamEvent{ + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: &domain.ToolCallDelta{ + Index: delta.Index, + ArgumentsDelta: &delta.Delta.PartialJSON, + }, + } + } + + case "content_block_start": + var block struct { + Index int `json:"index"` + ContentBlock struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + } `json:"content_block"` + } + if json.Unmarshal(data, &block) != nil { + return nil + } + if block.ContentBlock.Type == "tool_use" { + return &domain.StreamEvent{ + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: &domain.ToolCallDelta{ + Index: block.Index, + ID: &block.ContentBlock.ID, + Name: &block.ContentBlock.Name, + }, + } + } + + case "message_start": + var msg struct { + Message struct { + Usage struct { + InputTokens int64 `json:"input_tokens"` + CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"` + CacheReadInputTokens int64 `json:"cache_read_input_tokens"` + } `json:"usage"` + } `json:"message"` + } + if json.Unmarshal(data, &msg) == nil { + // Normalize PromptTokens to total input (uncached + cache_creation + cache_read) + totalInput := msg.Message.Usage.InputTokens + msg.Message.Usage.CacheCreationInputTokens + msg.Message.Usage.CacheReadInputTokens + s.usage = &domain.Usage{ + PromptTokens: totalInput, + CacheCreationTokens: msg.Message.Usage.CacheCreationInputTokens, + CacheReadTokens: msg.Message.Usage.CacheReadInputTokens, + } + } + role := "assistant" + return &domain.StreamEvent{ + Type: domain.StreamEventOutputMessageDelta, + Role: &role, + } + + case "message_delta": + var delta struct { + Delta struct { + StopReason string `json:"stop_reason"` + } 
`json:"delta"` + Usage struct { + OutputTokens int64 `json:"output_tokens"` + } `json:"usage"` + } + if json.Unmarshal(data, &delta) != nil { + return nil + } + + if s.usage != nil { + s.usage.CompletionTokens = delta.Usage.OutputTokens + s.usage.TotalTokens = s.usage.PromptTokens + s.usage.CompletionTokens + } + + finishReason := mapStopReason(delta.Delta.StopReason) + return &domain.StreamEvent{ + Type: domain.StreamEventCompleted, + FinishReason: &finishReason, + Usage: s.usage, + } + + case "message_stop": + s.done = true + return nil + + case "error": + var errData struct { + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` + } + msg := "unknown error" + if json.Unmarshal(data, &errData) == nil { + msg = errData.Error.Message + } + return &domain.StreamEvent{ + Type: domain.StreamEventError, + Error: &domain.GatewayError{ + HTTPStatus: 502, + Type: domain.ErrTypeProvider, + Code: domain.ErrCodeProviderUnavailable, + Message: msg, + }, + } + } + + return nil +} + +func mapStopReason(reason string) string { + switch reason { + case "end_turn": + return "stop" + case "max_tokens": + return "length" + case "tool_use": + return "tool_calls" + case "stop_sequence": + return "stop" + default: + return reason + } +} diff --git a/apps/gateway/internal/adapters/providers/openai/adapter.go b/apps/gateway/internal/adapters/providers/openai/adapter.go new file mode 100644 index 0000000..1d17d5e --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/adapter.go @@ -0,0 +1,470 @@ +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/google/uuid" +) + +// Adapter is the OpenAI-compatible provider adapter. 
+type Adapter struct { + client *Client + logger ports.Logger +} + +const ( + // Novita sometimes returns a generic 400 invalid_request_error with only a + // trace_id for otherwise valid tool/chat-history requests, especially on + // Kimi. Retrying the exact same body preserves proxy semantics and avoids + // guessing which OpenAI field caused the rejection. + maxNovitaInvalidTraceSameBodyRetries = 2 + maxNovitaServerOverloadRetries = 1 +) + +func NewAdapter(timeout time.Duration, logger ...ports.Logger) *Adapter { + var l ports.Logger + if len(logger) > 0 { + l = logger[0] + } + return &Adapter{client: NewClient(timeout), logger: l} +} + +func (a *Adapter) Generate(ctx context.Context, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + apiKey := resolveAPIKey(model.ProviderConfig.APIKeySecretRef) + if apiKey == "" { + return domain.GenerateResult{}, domain.ErrProviderUnavailable(model.ProviderConfig.ProviderName + ": API key not configured") + } + + built := buildProviderRequest(req, model) + built.Body["stream"] = false + delete(built.Body, "stream_options") + + activeBuilt := built + var rawResp []byte + var err error + var retryReason string + invalidTraceSameBodyRetries := 0 + serverOverloadRetries := 0 + downgradeRetried := false + for attempt := 1; ; attempt++ { + a.logProviderRequest(ctx, model, activeBuilt, attempt, retryReason) + rawResp, err = a.client.Do(ctx, model.ProviderConfig.BaseURL, apiKey, activeBuilt.Body) + if err == nil { + break + } + if retry := maybeBuildNovitaSameBodyRetry(model, err, activeBuilt.Body); retry.CanRetry { + if canSpendSameBodyRetry(retry.RetryReason, &invalidTraceSameBodyRetries, &serverOverloadRetries) { + retryReason = retry.RetryReason + if !sleepBeforeProviderRetry(ctx, retryReason) { + return domain.GenerateResult{}, withProviderPolicyMeta(mapProviderError(context.Cause(ctx), model.ProviderConfig.ProviderName), activeBuilt, attempt, retryReason) + } + continue + } + } + if 
!downgradeRetried { + if retry := maybeBuildNovitaDowngradeRetry(model, err, activeBuilt.Body); retry.CanRetry { + downgradeRetried = true + retryReason = retry.RetryReason + activeBuilt = builtProviderRequest{ + Body: retry.Body, + Policy: built.Policy, + Transforms: built.Transforms, + Omitted: retry.Omitted, + } + if !sleepBeforeProviderRetry(ctx, retryReason) { + return domain.GenerateResult{}, withProviderPolicyMeta(mapProviderError(context.Cause(ctx), model.ProviderConfig.ProviderName), activeBuilt, attempt, retryReason) + } + continue + } + } + return domain.GenerateResult{}, withProviderPolicyMeta(mapProviderErrorWithCompatibilityContext(err, model, activeBuilt.Body), activeBuilt, attempt, retryReason) + } + + result, err := parseResponse(rawResp, req, model) + if err != nil { + return domain.GenerateResult{}, err + } + return result, nil +} + +func (a *Adapter) StreamGenerate(ctx context.Context, req domain.GenerateRequest, model domain.PublicModel) (ports.GenerationStream, error) { + apiKey := resolveAPIKey(model.ProviderConfig.APIKeySecretRef) + if apiKey == "" { + return nil, domain.ErrProviderUnavailable(model.ProviderConfig.ProviderName + ": API key not configured") + } + + built := buildProviderRequest(req, model) + + activeBuilt := built + var streamResp *http.Response + var err error + var retryReason string + invalidTraceSameBodyRetries := 0 + serverOverloadRetries := 0 + downgradeRetried := false + for attempt := 1; ; attempt++ { + a.logProviderRequest(ctx, model, activeBuilt, attempt, retryReason) + streamResp, err = a.client.DoStream(ctx, model.ProviderConfig.BaseURL, apiKey, activeBuilt.Body) + if err == nil { + break + } + if retry := maybeBuildNovitaSameBodyRetry(model, err, activeBuilt.Body); retry.CanRetry { + if canSpendSameBodyRetry(retry.RetryReason, &invalidTraceSameBodyRetries, &serverOverloadRetries) { + retryReason = retry.RetryReason + if !sleepBeforeProviderRetry(ctx, retryReason) { + return nil, 
withProviderPolicyMeta(mapProviderError(context.Cause(ctx), model.ProviderConfig.ProviderName), activeBuilt, attempt, retryReason) + } + continue + } + } + if !downgradeRetried { + if retry := maybeBuildNovitaDowngradeRetry(model, err, activeBuilt.Body); retry.CanRetry { + downgradeRetried = true + retryReason = retry.RetryReason + activeBuilt = builtProviderRequest{ + Body: retry.Body, + Policy: built.Policy, + Transforms: built.Transforms, + Omitted: retry.Omitted, + } + if !sleepBeforeProviderRetry(ctx, retryReason) { + return nil, withProviderPolicyMeta(mapProviderError(context.Cause(ctx), model.ProviderConfig.ProviderName), activeBuilt, attempt, retryReason) + } + continue + } + } + return nil, withProviderPolicyMeta(mapProviderErrorWithCompatibilityContext(err, model, activeBuilt.Body), activeBuilt, attempt, retryReason) + } + + return NewStream(streamResp, model.ProviderConfig.ProviderName), nil +} + +func resolveAPIKey(secretRef string) string { + return os.Getenv(secretRef) +} + +func parseResponse(data json.RawMessage, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + var resp struct { + ID string `json:"id"` + Created int64 `json:"created"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content *string `json:"content"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls,omitempty"` + } `json:"message"` + FinishReason *string `json:"finish_reason"` + } `json:"choices"` + Usage *struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int64 `json:"cached_tokens"` + } `json:"prompt_tokens_details,omitempty"` + } 
`json:"usage"` + } + + if err := json.Unmarshal(data, &resp); err != nil { + return domain.GenerateResult{}, fmt.Errorf("failed to parse provider response: %w", err) + } + + result := domain.GenerateResult{ + ID: resp.ID, + CreatedUnix: resp.Created, + PublicModelID: req.PublicModelID, + ProviderName: model.ProviderConfig.ProviderName, + ProviderModelID: model.ProviderModelID, + } + + if result.ID == "" { + result.ID = uuid.New().String() + } + if result.CreatedUnix == 0 { + result.CreatedUnix = time.Now().Unix() + } + + for _, choice := range resp.Choices { + role := choice.Message.Role + out := domain.OutputItem{ + Type: domain.OutputItemTypeMessage, + Role: &role, + Content: choice.Message.Content, + } + if model.ProviderConfig.ProviderName == "deepseek" { + out.ReasoningContent = choice.Message.ReasoningContent + } + for _, tc := range choice.Message.ToolCalls { + out.ToolCalls = append(out.ToolCalls, domain.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + ArgumentsJSON: tc.Function.Arguments, + }) + } + result.Output = append(result.Output, out) + result.FinishReason = choice.FinishReason + } + + if resp.Usage != nil { + result.Usage = &domain.Usage{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + if resp.Usage.PromptTokensDetails != nil { + result.Usage.CacheReadTokens = resp.Usage.PromptTokensDetails.CachedTokens + } + } + + return result, nil +} + +func ParseResponse(data []byte, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + return parseResponse(data, req, model) +} + +func mapProviderError(err error, providerName string) error { + if err == nil { + return nil + } + if errors.Is(err, context.Canceled) { + return domain.ErrClientCanceled().WithMeta("upstream_error", err.Error()) + } + + // If the upstream provider returned an HTTP error, use its status code. 
+ var httpErr *ProviderHTTPError + if errors.As(err, &httpErr) { + var gwErr *domain.GatewayError + switch { + case httpErr.StatusCode == 401 || httpErr.StatusCode == 403: + // Auth / permission / billing failure at the upstream provider. + // Novita uses 403 for invalid API key, insufficient balance, and access denied. + gwErr = domain.ErrProviderError(502, fmt.Sprintf("provider auth/permission error: %s", httpErr.Body)) + case httpErr.StatusCode == 429: + // Rate limit or token limit exceeded — surface to client so it can back off. + gwErr = domain.ErrProviderError(429, "provider rate limited: "+httpErr.Body) + case httpErr.StatusCode == 503: + // Service unavailable — surface as 503 so clients know to retry. + gwErr = domain.ErrProviderUnavailable(providerName) + case httpErr.StatusCode >= 500: + gwErr = domain.ErrProviderError(502, fmt.Sprintf("provider server error: %s", httpErr.Body)) + default: + // Client errors from upstream (400, 404, 422, etc.) — the gateway + // forwarded a request the provider doesn't accept. + gwErr = domain.ErrProviderError(502, fmt.Sprintf("provider rejected request: %s", httpErr.Body)) + } + return gwErr.WithMeta( + "upstream_status", httpErr.StatusCode, + "upstream_error", httpErr.Body, + "upstream_type", httpErr.Type, + "upstream_code", httpErr.Code, + "upstream_reason", httpErr.Reason, + "upstream_trace_id", httpErr.TraceID, + ) + } + + // Network-level errors (no HTTP response received). 
+ msg := err.Error() + if strings.Contains(msg, "context canceled") { + return domain.ErrClientCanceled().WithMeta("upstream_error", msg) + } + if contains(msg, "timeout") || contains(msg, "deadline") { + return domain.ErrProviderTimeout(providerName).WithMeta("upstream_error", msg) + } + if contains(msg, "connection refused") || contains(msg, "no such host") { + return domain.ErrProviderUnavailable(providerName).WithMeta("upstream_error", msg) + } + return domain.ErrProviderError(502, msg).WithMeta("upstream_error", msg) +} + +func maybeBuildNovitaSameBodyRetry(model domain.PublicModel, err error, body map[string]any) retryBuildResult { + if model.ProviderConfig.ProviderName != "novita" { + return retryBuildResult{} + } + if isNovitaServerOverload(err) { + return retryBuildResult{ + Body: body, + RetryReason: "novita_server_overload_same_body_retry", + CanRetry: true, + } + } + if isGenericNovitaInvalidTraceError(err) && hasToolSurface(body) { + return retryBuildResult{ + Body: body, + RetryReason: "novita_invalid_request_same_body_retry", + CanRetry: true, + } + } + return retryBuildResult{} +} + +func maybeBuildNovitaDowngradeRetry(model domain.PublicModel, err error, body map[string]any) retryBuildResult { + if model.ProviderConfig.ProviderName != "novita" || !isInvalidRequestHTTPError(err) { + return retryBuildResult{} + } + return buildNovitaRetryRequest(body) +} + +func isInvalidRequestHTTPError(err error) bool { + var httpErr *ProviderHTTPError + if !errors.As(err, &httpErr) || httpErr.StatusCode != http.StatusBadRequest { + return false + } + text := strings.ToLower(strings.Join([]string{httpErr.Body, httpErr.Type, httpErr.Reason}, " ")) + return strings.Contains(text, "invalid") +} + +func isGenericNovitaInvalidTraceError(err error) bool { + var httpErr *ProviderHTTPError + if !errors.As(err, &httpErr) || httpErr.StatusCode != http.StatusBadRequest { + return false + } + text := strings.ToLower(httpErr.Body) + return strings.Contains(text, "invalid 
request error") && httpErr.TraceID != "" +} + +func isNovitaServerOverload(err error) bool { + var httpErr *ProviderHTTPError + if !errors.As(err, &httpErr) || httpErr.StatusCode != http.StatusTooManyRequests { + return false + } + text := strings.ToLower(strings.Join([]string{httpErr.Body, httpErr.Type, httpErr.Reason}, " ")) + return strings.Contains(text, "server_overload") || strings.Contains(text, "server overload") +} + +func canSpendSameBodyRetry(retryReason string, invalidTraceRetries, serverOverloadRetries *int) bool { + switch retryReason { + case "novita_invalid_request_same_body_retry": + if *invalidTraceRetries >= maxNovitaInvalidTraceSameBodyRetries { + return false + } + *invalidTraceRetries++ + return true + case "novita_server_overload_same_body_retry": + if *serverOverloadRetries >= maxNovitaServerOverloadRetries { + return false + } + *serverOverloadRetries++ + return true + default: + return false + } +} + +func hasToolSurface(body map[string]any) bool { + if tools, ok := body["tools"]; ok && tools != nil { + return true + } + messages, ok := body["messages"].([]map[string]any) + if !ok { + return false + } + for _, msg := range messages { + if role, _ := msg["role"].(string); role == "tool" { + return true + } + if _, ok := msg["tool_calls"]; ok { + return true + } + if _, ok := msg["tool_call_id"]; ok { + return true + } + } + return false +} + +func sleepBeforeProviderRetry(ctx context.Context, retryReason string) bool { + delay := 150 * time.Millisecond + if strings.Contains(retryReason, "server_overload") { + delay = 500 * time.Millisecond + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +func mapProviderErrorWithCompatibilityContext(err error, model domain.PublicModel, body map[string]any) error { + if model.ProviderConfig.ProviderName == "novita" && isInvalidRequestHTTPError(err) && hasNonDroppableToolChoice(body) { + var httpErr 
*ProviderHTTPError + if errors.As(err, &httpErr) { + return domain.ErrProviderError(502, "provider rejected request; Novita may not support the requested tool_choice semantics for this model: "+httpErr.Body).WithMeta( + "upstream_status", httpErr.StatusCode, + "upstream_error", httpErr.Body, + "upstream_type", httpErr.Type, + "upstream_code", httpErr.Code, + "upstream_reason", httpErr.Reason, + "upstream_trace_id", httpErr.TraceID, + "compatibility_note", "tool_choice_required_or_named_not_downgraded", + ) + } + } + return mapProviderError(err, model.ProviderConfig.ProviderName) +} + +func MapProviderErrorWithCompatibilityContext(err error, model domain.PublicModel, body map[string]any) error { + return mapProviderErrorWithCompatibilityContext(err, model, body) +} + +func hasNonDroppableToolChoice(body map[string]any) bool { + toolChoice, ok := body["tool_choice"] + if !ok { + return false + } + if mode, ok := toolChoice.(string); ok { + return mode == domain.ToolChoiceRequired + } + _, named := toolChoice.(map[string]any) + return named +} + +func withProviderPolicyMeta(err error, built builtProviderRequest, attempt int, retryReason string) error { + var gwErr *domain.GatewayError + if !errors.As(err, &gwErr) { + return err + } + fields := []any{ + "provider_policy", built.Policy, + "attempt", attempt, + "provider_params", summarizeProviderBody(built.Body), + } + if retryReason != "" { + fields = append(fields, "retry_reason", retryReason) + if built.Policy == "novita" && retryReason == "novita_invalid_request_same_body_retry" && len(built.Omitted) == 0 { + fields = append(fields, "retry_outcome", "same_body_failed_no_safe_downgrade") + } + } + if len(built.Transforms) > 0 { + fields = append(fields, "transforms", built.Transforms) + } + if len(built.Omitted) > 0 { + fields = append(fields, "omitted_fields", built.Omitted) + } + return gwErr.WithMeta(fields...) 
+} + +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} diff --git a/apps/gateway/internal/adapters/providers/openai/client.go b/apps/gateway/internal/adapters/providers/openai/client.go new file mode 100644 index 0000000..369874c --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/client.go @@ -0,0 +1,203 @@ +package openai + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "regexp" + "strings" + "time" +) + +// Client handles HTTP communication with OpenAI-compatible APIs. +type Client struct { + httpClient *http.Client + responseTimeout time.Duration +} + +func NewClient(timeout time.Duration) *Client { + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: timeout, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + return &Client{ + httpClient: &http.Client{Transport: transport}, + responseTimeout: timeout, + } +} + +func (c *Client) Do(ctx context.Context, baseURL, apiKey string, body map[string]any) ([]byte, error) { + ctx, cancel := context.WithTimeout(ctx, c.responseTimeout) + defer cancel() + + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("provider request failed: %w", err) + } + defer 
resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + if err != nil { + return nil, fmt.Errorf("failed to read provider response: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, parseProviderError(resp.StatusCode, respBody) + } + + return respBody, nil +} + +func (c *Client) DoStream(ctx context.Context, baseURL, apiKey string, body map[string]any) (*http.Response, error) { + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("provider stream request failed: %w", err) + } + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + resp.Body.Close() + return nil, parseProviderError(resp.StatusCode, body) + } + + return resp, nil +} + +// ProviderHTTPError carries the upstream HTTP status code so the gateway can +// propagate a meaningful status instead of always returning 502. 
+type ProviderHTTPError struct { + StatusCode int + Body string + RawBody string + Type string + Code string + Reason string + TraceID string +} + +func (e *ProviderHTTPError) Error() string { + return fmt.Sprintf("provider error (%d): %s", e.StatusCode, e.Body) +} + +func parseProviderError(statusCode int, body []byte) error { + // Try OpenAI standard format: {"error":{"message":"..."}} + var errResp struct { + Error struct { + Message string `json:"message"` + Type string `json:"type"` + Code any `json:"code"` + } `json:"error"` + } + msg := string(body) + errType := "" + errCode := "" + reason := "" + if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" { + msg = errResp.Error.Message + errType = errResp.Error.Type + errCode = stringifyProviderCode(errResp.Error.Code) + } else { + // Try Novita/flat format: {"message":"...","type":"..."} + var flatErr struct { + Message string `json:"message"` + Type string `json:"type"` + Code any `json:"code"` + Reason string `json:"reason"` + Metadata map[string]any `json:"metadata"` + } + if json.Unmarshal(body, &flatErr) == nil && flatErr.Message != "" { + msg = flatErr.Message + errType = flatErr.Type + errCode = stringifyProviderCode(flatErr.Code) + reason = flatErr.Reason + if traceID, ok := flatErr.Metadata["trace_id"].(string); ok { + return &ProviderHTTPError{ + StatusCode: statusCode, + Body: msg, + RawBody: string(body), + Type: errType, + Code: errCode, + Reason: reason, + TraceID: traceID, + } + } + } + } + + return &ProviderHTTPError{ + StatusCode: statusCode, + Body: msg, + RawBody: string(body), + Type: errType, + Code: errCode, + Reason: reason, + TraceID: extractTraceID(msg), + } +} + +func ParseProviderError(statusCode int, body []byte) error { + return parseProviderError(statusCode, body) +} + +func stringifyProviderCode(v any) string { + if v == nil { + return "" + } + switch typed := v.(type) { + case string: + return typed + case float64: + return fmt.Sprintf("%.0f", typed) + default: 
+ return fmt.Sprint(typed) + } +} + +var traceIDPattern = regexp.MustCompile(`(?i)trace[_ -]?id[:= ]+([a-z0-9_-]+)`) + +func extractTraceID(msg string) string { + match := traceIDPattern.FindStringSubmatch(msg) + if len(match) < 2 { + return "" + } + return strings.TrimSpace(match[1]) +} diff --git a/apps/gateway/internal/adapters/providers/openai/logging.go b/apps/gateway/internal/adapters/providers/openai/logging.go new file mode 100644 index 0000000..e98193a --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/logging.go @@ -0,0 +1,179 @@ +package openai + +import ( + "context" + "sort" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/middleware" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func (a *Adapter) logProviderRequest(ctx context.Context, model domain.PublicModel, built builtProviderRequest, attempt int, retryReason string) { + if a.logger == nil { + return + } + fields := []any{ + "request_id", middleware.GetRequestID(ctx), + "provider", model.ProviderConfig.ProviderName, + "provider_model", model.UpstreamModelName, + "model", model.PublicModelID, + "provider_policy", built.Policy, + "attempt", attempt, + "params", summarizeProviderBody(built.Body), + } + if retryReason != "" { + fields = append(fields, "retry_reason", retryReason) + } + if len(built.Transforms) > 0 { + fields = append(fields, "transforms", built.Transforms) + } + if len(built.Omitted) > 0 { + fields = append(fields, "omitted_fields", built.Omitted) + } + a.logger.Info("provider request", fields...) 
+} + +func summarizeProviderBody(body map[string]any) map[string]any { + summary := map[string]any{ + "fields": sortedKeys(body), + } + copyScalar(summary, body, "model") + copyScalar(summary, body, "stream") + copyScalar(summary, body, "max_tokens") + copyScalar(summary, body, "max_completion_tokens") + copyScalar(summary, body, "temperature") + copyScalar(summary, body, "top_p") + copyScalar(summary, body, "stop") + copyScalar(summary, body, "presence_penalty") + copyScalar(summary, body, "frequency_penalty") + copyScalar(summary, body, "seed") + copyScalar(summary, body, "logprobs") + copyScalar(summary, body, "top_logprobs") + copyScalar(summary, body, "parallel_tool_calls") + copyScalar(summary, body, "store") + copyScalar(summary, body, "service_tier") + if _, ok := body["user"]; ok { + summary["user"] = "[redacted]" + } + if messages, ok := body["messages"].([]map[string]any); ok { + summary["message_count"] = len(messages) + summary["messages"] = summarizeMessages(messages) + summary["total_content_chars"] = totalMessageContentChars(messages) + } + if tools, ok := body["tools"].([]map[string]any); ok { + summary["tool_count"] = len(tools) + summary["tool_names"] = summarizeToolNames(tools) + } + if toolChoice, ok := body["tool_choice"]; ok { + summary["tool_choice"] = summarizeToolChoice(toolChoice) + } + if responseFormat, ok := body["response_format"].(map[string]any); ok { + if formatType, ok := responseFormat["type"].(string); ok { + summary["response_format"] = formatType + } + } + if streamOptions, ok := body["stream_options"].(map[string]any); ok { + summary["stream_options"] = streamOptions + } + return summary +} + +func copyScalar(summary, body map[string]any, key string) { + if v, ok := body[key]; ok { + summary[key] = v + } +} + +func sortedKeys(body map[string]any) []string { + keys := make([]string, 0, len(body)) + for k := range body { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +func summarizeMessages(messages 
[]map[string]any) []map[string]any { + const maxEdge = 3 + total := len(messages) + sampled := make([]map[string]any, 0, min(total, maxEdge*2+1)) + for i, msg := range messages { + if total > maxEdge*2 && i >= maxEdge && i < total-maxEdge { + if i == maxEdge { + sampled = append(sampled, map[string]any{"_omitted": total - maxEdge*2}) + } + continue + } + item := map[string]any{} + if role, ok := msg["role"].(string); ok { + item["role"] = role + } + if content, ok := msg["content"].(string); ok { + item["content_chars"] = len(content) + } else if _, ok := msg["content"]; ok { + item["content_null"] = true + } + if toolCalls, ok := msg["tool_calls"].([]map[string]any); ok { + item["tool_call_count"] = len(toolCalls) + item["tool_call_names"] = summarizeToolCallNames(toolCalls) + } + if reasoningContent, ok := msg["reasoning_content"].(string); ok { + item["reasoning_content_chars"] = len(reasoningContent) + } else if _, ok := msg["reasoning_content"]; ok { + item["has_reasoning_content"] = true + } + if _, ok := msg["tool_call_id"]; ok { + item["has_tool_call_id"] = true + } + sampled = append(sampled, item) + } + return sampled +} + +func totalMessageContentChars(messages []map[string]any) int { + total := 0 + for _, msg := range messages { + if content, ok := msg["content"].(string); ok { + total += len(content) + } + } + return total +} + +func summarizeToolNames(tools []map[string]any) []string { + names := make([]string, 0, len(tools)) + for _, tool := range tools { + fn, _ := tool["function"].(map[string]any) + if name, ok := fn["name"].(string); ok { + names = append(names, name) + } + } + return names +} + +func summarizeToolCallNames(toolCalls []map[string]any) []string { + names := make([]string, 0, len(toolCalls)) + for _, toolCall := range toolCalls { + fn, _ := toolCall["function"].(map[string]any) + if name, ok := fn["name"].(string); ok { + names = append(names, name) + } + } + return names +} + +func summarizeToolChoice(toolChoice any) any { + switch 
v := toolChoice.(type) { + case string: + return v + case map[string]any: + fn, _ := v["function"].(map[string]any) + name, _ := fn["name"].(string) + return map[string]any{ + "type": v["type"], + "function_name": name, + } + default: + return "[present]" + } +} diff --git a/apps/gateway/internal/adapters/providers/openai/mapper.go b/apps/gateway/internal/adapters/providers/openai/mapper.go new file mode 100644 index 0000000..ab2c41c --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/mapper.go @@ -0,0 +1,309 @@ +package openai + +import ( + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +type providerPolicy struct { + name string + useMaxTokens bool + explicitNullAssistantToolContent bool + forwardAssistantReasoningContent bool + requireToolReasoningContent bool + forwardDeveloperRole bool +} + +type builtProviderRequest struct { + Body map[string]any + Policy string + Transforms []string + Omitted []string +} + +type retryBuildResult struct { + Body map[string]any + Omitted []string + RetryReason string + CanRetry bool +} + +func buildRequestBody(req domain.GenerateRequest, model domain.PublicModel) map[string]any { + return buildProviderRequest(req, model).Body +} + +func BuildRequestBody(req domain.GenerateRequest, model domain.PublicModel) map[string]any { + return buildRequestBody(req, model) +} + +func buildProviderRequest(req domain.GenerateRequest, model domain.PublicModel) builtProviderRequest { + policy := policyForProvider(model.ProviderConfig.ProviderName) + body := map[string]any{ + "model": model.UpstreamModelName, + } + transforms := make([]string, 0, 2) + + messages, messageTransforms := buildMessages(req, policy) + body["messages"] = messages + transforms = append(transforms, messageTransforms...) 
+ + if req.MaxOutputTokens != nil { + v := *req.MaxOutputTokens + if model.MaxOutputTokens > 0 && v > model.MaxOutputTokens { + v = model.MaxOutputTokens + } + if policy.useMaxTokens { + body["max_tokens"] = v + transforms = append(transforms, "token_limit_field=max_tokens") + } else { + body["max_completion_tokens"] = v + } + } + if req.Temperature != nil { + body["temperature"] = *req.Temperature + } + if req.ReasoningEffort != nil && *req.ReasoningEffort != "" { + body["reasoning_effort"] = *req.ReasoningEffort + } + if req.TopP != nil { + body["top_p"] = *req.TopP + } + if len(req.Stop) > 0 { + if len(req.Stop) == 1 { + body["stop"] = req.Stop[0] + } else { + body["stop"] = req.Stop + } + } + if req.Stream { + body["stream"] = true + body["stream_options"] = map[string]any{"include_usage": true} + } + if req.User != nil { + body["user"] = *req.User + } + + if len(req.Tools) > 0 { + tools := make([]map[string]any, 0, len(req.Tools)) + for _, t := range req.Tools { + fn := map[string]any{ + "name": t.Name, + "description": t.Description, + "parameters": t.Parameters, + } + // Only include "strict" when explicitly true; many providers + // reject this OpenAI-specific field. 
+ if t.Strict { + fn["strict"] = true + } + tool := map[string]any{ + "type": "function", + "function": fn, + } + tools = append(tools, tool) + } + body["tools"] = tools + } + + if req.ToolChoice != nil { + switch req.ToolChoice.Mode { + case domain.ToolChoiceNone, domain.ToolChoiceAuto, domain.ToolChoiceRequired: + body["tool_choice"] = req.ToolChoice.Mode + case domain.ToolChoiceFunction: + body["tool_choice"] = map[string]any{ + "type": "function", + "function": map[string]any{ + "name": *req.ToolChoice.FunctionName, + }, + } + } + } + + if req.ParallelToolCalls != nil { + body["parallel_tool_calls"] = *req.ParallelToolCalls + } + + // Pass-through parameters + if req.PresencePenalty != nil { + body["presence_penalty"] = *req.PresencePenalty + } + if req.FrequencyPenalty != nil { + body["frequency_penalty"] = *req.FrequencyPenalty + } + if len(req.LogitBias) > 0 { + body["logit_bias"] = req.LogitBias + } + if req.Seed != nil { + body["seed"] = *req.Seed + } + if req.Logprobs != nil { + body["logprobs"] = *req.Logprobs + } + if req.TopLogprobs != nil { + body["top_logprobs"] = *req.TopLogprobs + } + if req.Store != nil { + body["store"] = *req.Store + } + if req.ServiceTier != nil { + body["service_tier"] = *req.ServiceTier + } + + if req.TextConfig != nil && req.TextConfig.FormatType != nil { + switch *req.TextConfig.FormatType { + case "json_object": + body["response_format"] = map[string]any{"type": "json_object"} + case "json_schema": + rf := map[string]any{"type": "json_schema"} + if req.TextConfig.JSONSchema != nil { + rf["json_schema"] = req.TextConfig.JSONSchema + } + body["response_format"] = rf + } + } + + return builtProviderRequest{ + Body: body, + Policy: policy.name, + Transforms: transforms, + } +} + +func policyForProvider(providerName string) providerPolicy { + switch providerName { + case "deepseek": + // DeepSeek exposes an OpenAI-compatible chat API at /chat/completions, + // but currently documents the legacy `max_tokens` field and nullable + 
// assistant tool-call content. + return providerPolicy{ + name: "deepseek", + useMaxTokens: true, + explicitNullAssistantToolContent: true, + forwardAssistantReasoningContent: true, + requireToolReasoningContent: true, + } + case "novita": + // Novita exposes an OpenAI-compatible API, but its public Chat + // Completions docs currently document `max_tokens` rather than + // OpenAI's newer `max_completion_tokens`, and require a `content` + // field that may be null for assistant tool-call messages. Keep these + // as Novita-only wire-shape translations; do not leak them into the + // public OpenAI-compatible gateway API. + return providerPolicy{ + name: "novita", + useMaxTokens: true, + explicitNullAssistantToolContent: true, + } + default: + return providerPolicy{ + name: "openai-compatible", + forwardDeveloperRole: providerName == "openai", + } + } +} + +func buildMessages(req domain.GenerateRequest, policy providerPolicy) ([]map[string]any, []string) { + var messages []map[string]any + var transforms []string + + if req.Instructions != nil && *req.Instructions != "" { + messages = append(messages, map[string]any{ + "role": "system", + "content": *req.Instructions, + }) + } + + for _, item := range req.Input { + msg := map[string]any{} + + role := "user" + if item.Role != nil { + role = *item.Role + } + if role == "developer" && !policy.forwardDeveloperRole { + role = "system" + transforms = append(transforms, "developer_role=system") + } + msg["role"] = role + + if item.Content != nil { + msg["content"] = *item.Content + } else if policy.explicitNullAssistantToolContent && role == "assistant" && len(item.ToolCalls) > 0 { + msg["content"] = nil + transforms = append(transforms, "assistant_tool_content=null") + } + + if item.ToolCallID != nil { + msg["tool_call_id"] = *item.ToolCallID + } + + if role == "assistant" && policy.forwardAssistantReasoningContent { + if item.ReasoningContent != nil { + msg["reasoning_content"] = *item.ReasoningContent + } else if 
policy.requireToolReasoningContent && len(item.ToolCalls) > 0 { + msg["reasoning_content"] = "" + transforms = append(transforms, "assistant_tool_reasoning_content=empty") + } + } + + if len(item.ToolCalls) > 0 { + tcs := make([]map[string]any, 0, len(item.ToolCalls)) + for _, tc := range item.ToolCalls { + tcs = append(tcs, map[string]any{ + "id": tc.ID, + "type": "function", + "function": map[string]any{ + "name": tc.Name, + "arguments": tc.ArgumentsJSON, + }, + }) + } + msg["tool_calls"] = tcs + } + + messages = append(messages, msg) + } + + return messages, transforms +} + +func buildNovitaRetryRequest(body map[string]any) retryBuildResult { + retryBody := cloneBody(body) + omitted := make([]string, 0, 6) + + for _, field := range []string{"parallel_tool_calls", "store", "service_tier", "user"} { + if _, ok := retryBody[field]; ok { + delete(retryBody, field) + omitted = append(omitted, field) + } + } + + if toolChoice, ok := retryBody["tool_choice"]; ok { + switch v := toolChoice.(type) { + case string: + switch v { + case domain.ToolChoiceAuto: + delete(retryBody, "tool_choice") + omitted = append(omitted, "tool_choice=auto") + case domain.ToolChoiceNone: + delete(retryBody, "tool_choice") + delete(retryBody, "tools") + omitted = append(omitted, "tool_choice=none", "tools") + } + } + } + + return retryBuildResult{ + Body: retryBody, + Omitted: omitted, + RetryReason: "novita_invalid_request_guarded_downgrade", + CanRetry: len(omitted) > 0, + } +} + +func cloneBody(body map[string]any) map[string]any { + clone := make(map[string]any, len(body)) + for k, v := range body { + clone[k] = v + } + return clone +} diff --git a/apps/gateway/internal/adapters/providers/openai/policy_test.go b/apps/gateway/internal/adapters/providers/openai/policy_test.go new file mode 100644 index 0000000..f4bdaeb --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/policy_test.go @@ -0,0 +1,633 @@ +package openai + +import ( + "context" + "encoding/json" + "errors" + 
"net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func TestBuildProviderRequest_DefaultPolicyTokenLimit(t *testing.T) { + maxTokens := 32 + req := domain.GenerateRequest{ + PublicModelID: "openai/test", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + } + model := domain.PublicModel{ + UpstreamModelName: "test-model", + ProviderConfig: domain.ProviderConfig{ProviderName: "openai"}, + MaxOutputTokens: 128, + } + + built := buildProviderRequest(req, model) + if built.Policy != "openai-compatible" { + t.Fatalf("policy = %q, want openai-compatible", built.Policy) + } + if got := built.Body["max_completion_tokens"]; got != 32 { + t.Fatalf("max_completion_tokens = %v, want 32", got) + } + if _, ok := built.Body["max_tokens"]; ok { + t.Fatal("default policy must not send max_tokens") + } +} + +func TestBuildProviderRequest_ForwardsReasoningEffort(t *testing.T) { + effort := "low" + maxTokens := 1024 + req := domain.GenerateRequest{ + PublicModelID: "phala/gpt-oss-20b", + Input: []domain.InputItem{message("user", "hi")}, + ReasoningEffort: &effort, + MaxOutputTokens: &maxTokens, + } + model := domain.PublicModel{ + UpstreamModelName: "phala/gpt-oss-20b", + ProviderConfig: domain.ProviderConfig{ProviderName: "phala"}, + MaxOutputTokens: 1024, + } + + built := buildProviderRequest(req, model) + if got := built.Body["reasoning_effort"]; got != "low" { + t.Fatalf("reasoning_effort = %v, want low", got) + } +} + +func TestBuildProviderRequest_DeveloperRolePolicy(t *testing.T) { + req := domain.GenerateRequest{ + PublicModelID: "test/developer-role", + Input: []domain.InputItem{message("developer", "follow instructions")}, + } + + compatible := buildProviderRequest(req, domain.PublicModel{ + UpstreamModelName: "compatible-model", + ProviderConfig: domain.ProviderConfig{ProviderName: "phala"}, + }) + messages := compatible.Body["messages"].([]map[string]any) + if 
messages[0]["role"] != "system" { + t.Fatalf("compatible role = %v, want system", messages[0]["role"]) + } + if !containsString(compatible.Transforms, "developer_role=system") { + t.Fatalf("transforms = %v, want developer_role=system", compatible.Transforms) + } + + openaiBuilt := buildProviderRequest(req, domain.PublicModel{ + UpstreamModelName: "gpt-5", + ProviderConfig: domain.ProviderConfig{ProviderName: "openai"}, + }) + messages = openaiBuilt.Body["messages"].([]map[string]any) + if messages[0]["role"] != "developer" { + t.Fatalf("openai role = %v, want developer", messages[0]["role"]) + } +} + +func TestBuildProviderRequest_OmittedTokenLimitStaysOmitted(t *testing.T) { + req := domain.GenerateRequest{ + PublicModelID: "openai/test", + Input: []domain.InputItem{message("user", "hi")}, + } + model := domain.PublicModel{ + UpstreamModelName: "test-model", + ProviderConfig: domain.ProviderConfig{ProviderName: "openai"}, + MaxOutputTokens: 128, + } + + built := buildProviderRequest(req, model) + if _, ok := built.Body["max_completion_tokens"]; ok { + t.Fatal("max_completion_tokens must stay omitted when the client omits a token limit") + } + if _, ok := built.Body["max_tokens"]; ok { + t.Fatal("max_tokens must stay omitted when the client omits a token limit") + } +} + +func TestBuildProviderRequest_NovitaPolicy(t *testing.T) { + maxTokens := 32 + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + } + model := novitaModel("http://example.test") + + built := buildProviderRequest(req, model) + if built.Policy != "novita" { + t.Fatalf("policy = %q, want novita", built.Policy) + } + if got := built.Body["max_tokens"]; got != 32 { + t.Fatalf("max_tokens = %v, want 32", got) + } + if _, ok := built.Body["max_completion_tokens"]; ok { + t.Fatal("Novita policy must not send max_completion_tokens") + } + if !containsString(built.Transforms, 
"token_limit_field=max_tokens") { + t.Fatalf("transforms = %v, want token_limit_field=max_tokens", built.Transforms) + } +} + +func TestBuildProviderRequest_DeepSeekPolicy(t *testing.T) { + maxTokens := 32 + req := domain.GenerateRequest{ + PublicModelID: "deepseek/deepseek-v4-pro", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + } + model := deepseekModel("http://example.test") + + built := buildProviderRequest(req, model) + if built.Policy != "deepseek" { + t.Fatalf("policy = %q, want deepseek", built.Policy) + } + if got := built.Body["max_tokens"]; got != 32 { + t.Fatalf("max_tokens = %v, want 32", got) + } + if _, ok := built.Body["max_completion_tokens"]; ok { + t.Fatal("DeepSeek policy must not send max_completion_tokens") + } + if !containsString(built.Transforms, "token_limit_field=max_tokens") { + t.Fatalf("transforms = %v, want token_limit_field=max_tokens", built.Transforms) + } +} + +func TestBuildProviderRequest_DeepSeekAssistantToolContentNull(t *testing.T) { + req := domain.GenerateRequest{ + PublicModelID: "deepseek/deepseek-v4-pro", + Input: []domain.InputItem{ + { + Type: domain.InputItemTypeMessage, + Role: stringPtr("assistant"), + ToolCalls: []domain.ToolCall{{ID: "call_1", Name: "noop", ArgumentsJSON: "{}"}}, + }, + }, + } + + built := buildProviderRequest(req, deepseekModel("http://example.test")) + messages := built.Body["messages"].([]map[string]any) + if _, ok := messages[0]["content"]; !ok { + t.Fatal("DeepSeek assistant tool-call message must include content key") + } + if messages[0]["content"] != nil { + t.Fatalf("content = %v, want nil", messages[0]["content"]) + } + if !containsString(built.Transforms, "assistant_tool_content=null") { + t.Fatalf("transforms = %v, want assistant_tool_content=null", built.Transforms) + } + if got, ok := messages[0]["reasoning_content"].(string); !ok || got != "" { + t.Fatalf("reasoning_content = %#v, want empty string", messages[0]["reasoning_content"]) + } + if 
!containsString(built.Transforms, "assistant_tool_reasoning_content=empty") { + t.Fatalf("transforms = %v, want assistant_tool_reasoning_content=empty", built.Transforms) + } +} + +func TestBuildProviderRequest_DeepSeekPreservesAssistantReasoningContent(t *testing.T) { + reasoning := "used the search result" + req := domain.GenerateRequest{ + PublicModelID: "deepseek/deepseek-v4-pro", + Input: []domain.InputItem{ + { + Type: domain.InputItemTypeMessage, + Role: stringPtr("assistant"), + ReasoningContent: &reasoning, + ToolCalls: []domain.ToolCall{{ID: "call_1", Name: "noop", ArgumentsJSON: "{}"}}, + }, + }, + } + + built := buildProviderRequest(req, deepseekModel("http://example.test")) + messages := built.Body["messages"].([]map[string]any) + if got := messages[0]["reasoning_content"]; got != reasoning { + t.Fatalf("reasoning_content = %#v, want %q", got, reasoning) + } + if containsString(built.Transforms, "assistant_tool_reasoning_content=empty") { + t.Fatalf("transforms = %v, must not synthesize empty reasoning when client supplied it", built.Transforms) + } +} + +func TestBuildProviderRequest_NovitaAssistantToolContentNull(t *testing.T) { + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{ + { + Type: domain.InputItemTypeMessage, + Role: stringPtr("assistant"), + ToolCalls: []domain.ToolCall{{ID: "call_1", Name: "noop", ArgumentsJSON: "{}"}}, + }, + }, + } + + built := buildProviderRequest(req, novitaModel("http://example.test")) + messages := built.Body["messages"].([]map[string]any) + if _, ok := messages[0]["content"]; !ok { + t.Fatal("Novita assistant tool-call message must include content key") + } + if messages[0]["content"] != nil { + t.Fatalf("content = %v, want nil", messages[0]["content"]) + } + if !containsString(built.Transforms, "assistant_tool_content=null") { + t.Fatalf("transforms = %v, want assistant_tool_content=null", built.Transforms) + } + if _, ok := messages[0]["reasoning_content"]; ok { 
+ t.Fatal("Novita assistant tool-call message must not include DeepSeek reasoning_content compatibility field") + } +} + +func TestBuildNovitaRetryRequest_GuardedDowngrade(t *testing.T) { + body := map[string]any{ + "model": "moonshotai/kimi-k2.6", + "parallel_tool_calls": false, + "store": true, + "service_tier": "auto", + "user": "user-1", + "tool_choice": "auto", + } + + retry := buildNovitaRetryRequest(body) + if !retry.CanRetry { + t.Fatal("expected retry to be allowed") + } + for _, field := range []string{"parallel_tool_calls", "store", "service_tier", "user", "tool_choice"} { + if _, ok := retry.Body[field]; ok { + t.Fatalf("retry body still contains %s", field) + } + } + for _, omitted := range []string{"parallel_tool_calls", "store", "service_tier", "user", "tool_choice=auto"} { + if !containsString(retry.Omitted, omitted) { + t.Fatalf("omitted = %v, missing %s", retry.Omitted, omitted) + } + } +} + +func TestBuildNovitaRetryRequest_ToolChoiceNoneRemovesTools(t *testing.T) { + body := map[string]any{ + "model": "moonshotai/kimi-k2.6", + "tool_choice": "none", + "tools": []map[string]any{{"type": "function"}}, + } + + retry := buildNovitaRetryRequest(body) + if !retry.CanRetry { + t.Fatal("expected retry to be allowed") + } + if _, ok := retry.Body["tool_choice"]; ok { + t.Fatal("retry body still contains tool_choice") + } + if _, ok := retry.Body["tools"]; ok { + t.Fatal("retry body still contains tools") + } +} + +func TestBuildNovitaRetryRequest_NamedToolChoiceIsNotDowngraded(t *testing.T) { + body := map[string]any{ + "model": "moonshotai/kimi-k2.6", + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{"name": "noop"}, + }, + } + + retry := buildNovitaRetryRequest(body) + if retry.CanRetry { + t.Fatalf("named tool_choice must not be downgraded, omitted = %v", retry.Omitted) + } +} + +func TestAdapterGenerate_NovitaRetriesSafeDowngrade(t *testing.T) { + t.Setenv("NOVITA_TEST_KEY", "test-key") + var attempts int + var 
bodies []map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + defer r.Body.Close() + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request: %v", err) + } + bodies = append(bodies, body) + if attempts <= 3 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"message":"invalid request error trace_id: testtrace","type":"invalid_request_error"}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"id":"cmpl-1","created":123,"choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)) + })) + defer server.Close() + + maxTokens := 8 + parallel := false + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + ParallelToolCalls: ¶llel, + Store: boolPtr(true), + ServiceTier: stringPtr("auto"), + User: stringPtr("user-1"), + ToolChoice: &domain.ToolChoice{Mode: domain.ToolChoiceAuto}, + Tools: []domain.ToolDefinition{{Name: "noop", Parameters: map[string]any{"type": "object"}}}, + } + + adapter := &Adapter{ + client: &Client{httpClient: server.Client(), responseTimeout: time.Second}, + } + _, err := adapter.Generate(context.Background(), req, novitaModel(server.URL)) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if attempts != 4 { + t.Fatalf("attempts = %d, want 4", attempts) + } + for _, field := range []string{"parallel_tool_calls", "store", "service_tier", "user", "tool_choice"} { + if _, ok := bodies[1][field]; !ok { + t.Fatalf("same-body retry should still contain %s", field) + } + } + for _, field := range []string{"parallel_tool_calls", "store", "service_tier", "user", "tool_choice"} { + if _, ok := bodies[2][field]; !ok { + t.Fatalf("second same-body retry should still 
contain %s", field) + } + } + for _, field := range []string{"parallel_tool_calls", "store", "service_tier", "user", "tool_choice"} { + if _, ok := bodies[3][field]; ok { + t.Fatalf("downgrade retry body still contains %s", field) + } + } + if _, ok := bodies[3]["tools"]; !ok { + t.Fatal("tool_choice=auto retry should keep tools") + } +} + +func TestAdapterGenerate_NovitaNamedToolChoiceFailsClearlyAfterSameBodyRetry(t *testing.T) { + t.Setenv("NOVITA_TEST_KEY", "test-key") + var attempts int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"message":"invalid request error trace_id: namedtrace","type":"invalid_request_error"}`)) + })) + defer server.Close() + + maxTokens := 8 + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + ToolChoice: &domain.ToolChoice{ + Mode: domain.ToolChoiceFunction, + FunctionName: stringPtr("noop"), + }, + Tools: []domain.ToolDefinition{{Name: "noop", Parameters: map[string]any{"type": "object"}}}, + } + + adapter := &Adapter{ + client: &Client{httpClient: server.Client(), responseTimeout: time.Second}, + } + _, err := adapter.Generate(context.Background(), req, novitaModel(server.URL)) + if err == nil { + t.Fatal("expected error") + } + if attempts != 3 { + t.Fatalf("attempts = %d, want 3", attempts) + } + if !strings.Contains(err.Error(), "tool_choice") { + t.Fatalf("error = %q, want clear tool_choice context", err.Error()) + } +} + +func TestAdapterGenerate_NovitaSafeRetryStillReportsNamedToolChoiceIncompatibility(t *testing.T) { + t.Setenv("NOVITA_TEST_KEY", "test-key") + var attempts int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"message":"invalid request error trace_id: 
namedtrace","type":"invalid_request_error"}`)) + })) + defer server.Close() + + maxTokens := 8 + parallel := false + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + ParallelToolCalls: ¶llel, + ToolChoice: &domain.ToolChoice{ + Mode: domain.ToolChoiceFunction, + FunctionName: stringPtr("noop"), + }, + Tools: []domain.ToolDefinition{{Name: "noop", Parameters: map[string]any{"type": "object"}}}, + } + + adapter := &Adapter{ + client: &Client{httpClient: server.Client(), responseTimeout: time.Second}, + } + _, err := adapter.Generate(context.Background(), req, novitaModel(server.URL)) + if err == nil { + t.Fatal("expected error") + } + if attempts != 4 { + t.Fatalf("attempts = %d, want 4", attempts) + } + if !strings.Contains(err.Error(), "tool_choice") { + t.Fatalf("error = %q, want clear tool_choice context", err.Error()) + } +} + +func TestAdapterGenerate_NovitaFailedSameBodyRetryIncludesProviderParams(t *testing.T) { + t.Setenv("NOVITA_TEST_KEY", "test-key") + var attempts int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"message":"invalid request error trace_id: tooltrace","type":"invalid_request_error"}`)) + })) + defer server.Close() + + maxTokens := 8 + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + Tools: []domain.ToolDefinition{{Name: "noop", Parameters: map[string]any{"type": "object"}}}, + } + + adapter := &Adapter{ + client: &Client{httpClient: server.Client(), responseTimeout: time.Second}, + } + _, err := adapter.Generate(context.Background(), req, novitaModel(server.URL)) + if err == nil { + t.Fatal("expected error") + } + if attempts != 3 { + t.Fatalf("attempts = %d, want 3", attempts) + } + + var gwErr 
*domain.GatewayError + if !errors.As(err, &gwErr) { + t.Fatalf("error = %T, want *domain.GatewayError", err) + } + if got := gwErr.Metadata["retry_outcome"]; got != "same_body_failed_no_safe_downgrade" { + t.Fatalf("retry_outcome = %v, want same_body_failed_no_safe_downgrade", got) + } + params, ok := gwErr.Metadata["provider_params"].(map[string]any) + if !ok { + t.Fatalf("provider_params = %T, want map[string]any", gwErr.Metadata["provider_params"]) + } + if got := params["tool_count"]; got != 1 { + t.Fatalf("provider_params.tool_count = %v, want 1", got) + } + if _, ok := params["messages"]; !ok { + t.Fatalf("provider_params = %v, want redacted message summary", params) + } +} + +func TestAdapterGenerate_NovitaServerOverloadRetriesSameBody(t *testing.T) { + t.Setenv("NOVITA_TEST_KEY", "test-key") + var attempts int + var firstBody map[string]any + var secondBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + defer r.Body.Close() + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request: %v", err) + } + if attempts == 1 { + firstBody = body + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"message":"server overload, please try again later trace_id: overloadtrace","type":"server_overload"}`)) + return + } + secondBody = body + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"id":"cmpl-1","created":123,"choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)) + })) + defer server.Close() + + maxTokens := 8 + parallel := false + req := domain.GenerateRequest{ + PublicModelID: "moonshotai/kimi-k2.6", + Input: []domain.InputItem{message("user", "hi")}, + MaxOutputTokens: &maxTokens, + ParallelToolCalls: ¶llel, + Tools: []domain.ToolDefinition{{Name: "noop", Parameters: map[string]any{"type": 
"object"}}}, + } + + adapter := &Adapter{ + client: &Client{httpClient: server.Client(), responseTimeout: time.Second}, + } + _, err := adapter.Generate(context.Background(), req, novitaModel(server.URL)) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if attempts != 2 { + t.Fatalf("attempts = %d, want 2", attempts) + } + if got, want := secondBody["parallel_tool_calls"], firstBody["parallel_tool_calls"]; got != want { + t.Fatalf("same-body retry changed parallel_tool_calls: got %v, want %v", got, want) + } + if _, ok := secondBody["tools"]; !ok { + t.Fatal("same-body retry should keep tools") + } +} + +func TestSummarizeProviderBodyRedactsPromptTextUserAndToolSchema(t *testing.T) { + body := map[string]any{ + "model": "moonshotai/kimi-k2.6", + "user": "raw-user-id", + "messages": []map[string]any{ + {"role": "user", "content": "secret prompt text"}, + }, + "tools": []map[string]any{ + { + "type": "function", + "function": map[string]any{ + "name": "lookup", + "description": "secret tool description", + "parameters": map[string]any{"description": "secret schema"}, + }, + }, + }, + } + + summary := summarizeProviderBody(body) + raw, err := json.Marshal(summary) + if err != nil { + t.Fatalf("marshal summary: %v", err) + } + got := string(raw) + for _, secret := range []string{"secret prompt text", "raw-user-id", "secret tool description", "secret schema"} { + if strings.Contains(got, secret) { + t.Fatalf("summary leaked %q: %s", secret, got) + } + } + if !strings.Contains(got, `"user":"[redacted]"`) { + t.Fatalf("summary = %s, want redacted user marker", got) + } + if !strings.Contains(got, `"content_chars":18`) { + t.Fatalf("summary = %s, want content length", got) + } + if !strings.Contains(got, `"tool_names":["lookup"]`) { + t.Fatalf("summary = %s, want tool name only", got) + } +} + +func novitaModel(baseURL string) domain.PublicModel { + return domain.PublicModel{ + PublicModelID: "moonshotai/kimi-k2.6", + UpstreamModelName: 
"moonshotai/kimi-k2.6", + MaxOutputTokens: 262144, + ProviderConfig: domain.ProviderConfig{ + ProviderName: "novita", + BaseURL: baseURL, + APIKeySecretRef: "NOVITA_TEST_KEY", + }, + } +} + +func deepseekModel(baseURL string) domain.PublicModel { + return domain.PublicModel{ + PublicModelID: "deepseek/deepseek-v4-pro", + UpstreamModelName: "deepseek-v4-pro", + MaxOutputTokens: 8192, + ProviderConfig: domain.ProviderConfig{ + ProviderName: "deepseek", + BaseURL: baseURL, + APIKeySecretRef: "DEEPSEEK_TEST_KEY", + }, + } +} + +func message(role, content string) domain.InputItem { + return domain.InputItem{ + Type: domain.InputItemTypeMessage, + Role: stringPtr(role), + Content: stringPtr(content), + } +} + +func stringPtr(v string) *string { + return &v +} + +func boolPtr(v bool) *bool { + return &v +} + +func containsString(values []string, want string) bool { + for _, v := range values { + if v == want { + return true + } + } + return false +} diff --git a/apps/gateway/internal/adapters/providers/openai/stream.go b/apps/gateway/internal/adapters/providers/openai/stream.go new file mode 100644 index 0000000..a6ca071 --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/stream.go @@ -0,0 +1,297 @@ +package openai + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// Stream reads SSE events from an OpenAI-compatible streaming response. 
+type Stream struct { + resp *http.Response + scanner *bufio.Scanner + done bool + includeReasoningContent bool + deferredCompleted *domain.StreamEvent // stashed when tool-call delta + finish_reason arrive in one chunk +} + +func NewStream(resp *http.Response, providerName ...string) *Stream { + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + includeReasoningContent := len(providerName) > 0 && providerName[0] == "deepseek" + return &Stream{ + resp: resp, + scanner: scanner, + includeReasoningContent: includeReasoningContent, + } +} + +func (s *Stream) Recv() (domain.StreamEvent, error) { + if s.done { + return domain.StreamEvent{}, io.EOF + } + + // Return a stashed completion event from a previous chunk that carried + // both a tool-call delta and a finish_reason. + if s.deferredCompleted != nil { + event := *s.deferredCompleted + s.deferredCompleted = nil + return event, nil + } + + for s.scanner.Scan() { + line := s.scanner.Text() + + if line == "" { + continue + } + + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + if data == "[DONE]" { + s.done = true + return domain.StreamEvent{}, io.EOF + } + + var chunk chatCompletionChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + events := mapChunkToStreamEvents(chunk, s.includeReasoningContent) + if len(events) == 0 { + continue + } + for i := range events { + events[i].ProviderResponseID = chunk.ID + } + // If the mapper produced two events (tool-call delta + completed), + // return the first now and stash the second for the next Recv(). 
+ if len(events) > 1 { + s.deferredCompleted = &events[1] + } + return events[0], nil + } + + if err := s.scanner.Err(); err != nil { + return domain.StreamEvent{}, err + } + + s.done = true + return domain.StreamEvent{}, io.EOF +} + +func (s *Stream) Close() error { + s.done = true + return s.resp.Body.Close() +} + +type chatCompletionChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Delta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []struct { + Index int `json:"index"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + } `json:"function"` + } `json:"tool_calls,omitempty"` + } `json:"delta"` + FinishReason *string `json:"finish_reason"` + } `json:"choices"` + Usage *struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int64 `json:"cached_tokens"` + } `json:"prompt_tokens_details,omitempty"` + } `json:"usage,omitempty"` +} + +func mapChunkToStreamEvents(chunk chatCompletionChunk, includeReasoningContent ...bool) []domain.StreamEvent { + keepReasoning := len(includeReasoningContent) > 0 && includeReasoningContent[0] + // Handle usage-only chunk (often last chunk with stream_options.include_usage) + if len(chunk.Choices) == 0 && chunk.Usage != nil { + return []domain.StreamEvent{{ + Type: domain.StreamEventCompleted, + Usage: chunkUsageToDomain(chunk.Usage), + }} + } + + if len(chunk.Choices) == 0 { + return nil + } + + choice := chunk.Choices[0] + + // Some providers (e.g. 
Novita/MiniMax) send reasoning_content chunks + // with finish_reason set before the actual content chunk. If we honour + // that finish_reason the stream closes before real content arrives. + // Neutralise it so the subsequent content chunk carries the real signal. + // The stream still terminates via upstream "data: [DONE]" / EOF even if + // the later chunk happens to lack finish_reason. + if choice.FinishReason != nil && choice.Delta.ReasoningContent != nil { + hasContent := choice.Delta.Content != nil && *choice.Delta.Content != "" + if !hasContent && len(choice.Delta.ToolCalls) == 0 { + choice.FinishReason = nil + } + } + + // When a chunk carries both a tool-call delta AND a finish_reason (some + // providers, e.g. MiniMax, pack the final argument fragment and the + // finish signal into one chunk), we must emit the tool-call delta + // FIRST so the client receives the complete JSON arguments before the + // stream is marked done. + if choice.FinishReason != nil && len(choice.Delta.ToolCalls) > 0 { + tc := choice.Delta.ToolCalls[0] + tcd := &domain.ToolCallDelta{ + Index: tc.Index, + } + if tc.ID != "" { + tcd.ID = &tc.ID + } + if tc.Function.Name != "" { + tcd.Name = &tc.Function.Name + } + if tc.Function.Arguments != "" { + tcd.ArgumentsDelta = &tc.Function.Arguments + } + completedEvent := domain.StreamEvent{ + Type: domain.StreamEventCompleted, + FinishReason: choice.FinishReason, + } + if choice.Delta.Content != nil && *choice.Delta.Content != "" { + completedEvent.ContentDelta = choice.Delta.Content + } + if keepReasoning && choice.Delta.ReasoningContent != nil && *choice.Delta.ReasoningContent != "" { + completedEvent.ReasoningDelta = choice.Delta.ReasoningContent + } + if chunk.Usage != nil { + completedEvent.Usage = chunkUsageToDomain(chunk.Usage) + } + return []domain.StreamEvent{ + { + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: tcd, + }, + completedEvent, + } + } + + // Check for finish reason -> completed event + if 
choice.FinishReason != nil { + event := domain.StreamEvent{ + Type: domain.StreamEventCompleted, + FinishReason: choice.FinishReason, + } + if choice.Delta.Content != nil && *choice.Delta.Content != "" { + event.ContentDelta = choice.Delta.Content + } + if keepReasoning && choice.Delta.ReasoningContent != nil && *choice.Delta.ReasoningContent != "" { + event.ReasoningDelta = choice.Delta.ReasoningContent + } + if chunk.Usage != nil { + event.Usage = chunkUsageToDomain(chunk.Usage) + } + return []domain.StreamEvent{event} + } + + // Tool call delta + if len(choice.Delta.ToolCalls) > 0 { + tc := choice.Delta.ToolCalls[0] + tcd := &domain.ToolCallDelta{ + Index: tc.Index, + } + if tc.ID != "" { + tcd.ID = &tc.ID + } + if tc.Function.Name != "" { + tcd.Name = &tc.Function.Name + } + if tc.Function.Arguments != "" { + tcd.ArgumentsDelta = &tc.Function.Arguments + } + return []domain.StreamEvent{{ + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: tcd, + }} + } + + // Text content delta + if choice.Delta.Content != nil { + event := domain.StreamEvent{ + Type: domain.StreamEventOutputTextDelta, + ContentDelta: choice.Delta.Content, + } + if keepReasoning { + event.ReasoningDelta = choice.Delta.ReasoningContent + } + if choice.Delta.Role != "" { + event.Role = &choice.Delta.Role + } + return []domain.StreamEvent{event} + } + + // Reasoning content delta + if keepReasoning && choice.Delta.ReasoningContent != nil { + event := domain.StreamEvent{ + Type: domain.StreamEventOutputTextDelta, + ReasoningDelta: choice.Delta.ReasoningContent, + } + if choice.Delta.Role != "" { + event.Role = &choice.Delta.Role + } + return []domain.StreamEvent{event} + } + + // Role-only delta (first chunk often) + if choice.Delta.Role != "" { + role := choice.Delta.Role + return []domain.StreamEvent{{ + Type: domain.StreamEventOutputMessageDelta, + Role: &role, + }} + } + + return nil +} + +func chunkUsageToDomain(u *struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens 
int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int64 `json:"cached_tokens"` + } `json:"prompt_tokens_details,omitempty"` +}) *domain.Usage { + if u == nil { + return nil + } + usage := &domain.Usage{ + PromptTokens: u.PromptTokens, + CompletionTokens: u.CompletionTokens, + TotalTokens: u.TotalTokens, + } + if u.PromptTokensDetails != nil { + usage.CacheReadTokens = u.PromptTokensDetails.CachedTokens + } + return usage +} diff --git a/apps/gateway/internal/adapters/providers/openai/stream_test.go b/apps/gateway/internal/adapters/providers/openai/stream_test.go new file mode 100644 index 0000000..84c691b --- /dev/null +++ b/apps/gateway/internal/adapters/providers/openai/stream_test.go @@ -0,0 +1,175 @@ +package openai + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +type fakeBody struct { + *strings.Reader +} + +func (fakeBody) Close() error { return nil } + +func newTestStream(sseData string) *Stream { + body := fakeBody{strings.NewReader(sseData)} + resp := &http.Response{Body: body} + return NewStream(resp) +} + +func TestStream_ToolCallDeltaAndFinishReasonInSameChunk(t *testing.T) { + sseData := "data: {\"id\":\"c1\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n" + + "data: {\"id\":\"c1\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"write_file\",\"arguments\":\"\"}}]},\"finish_reason\":null}]}\n\n" + + "data: {\"id\":\"c1\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"path\\\":\\\"f.txt\\\"\"}}]},\"finish_reason\":null}]}\n\n" + + "data: {\"id\":\"c1\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"}\"}}]},\"finish_reason\":\"tool_calls\"}]}\n\n" + + "data: [DONE]\n" + + body := 
fakeBody{strings.NewReader(sseData)} + stream := NewStream(&http.Response{Body: body}, "deepseek") + + ev1, err := stream.Recv() + if err != nil { + t.Fatalf("ev1: %v", err) + } + if ev1.Type != domain.StreamEventOutputMessageDelta { + t.Fatalf("ev1 type = %v, want OutputMessageDelta", ev1.Type) + } + + ev2, err := stream.Recv() + if err != nil { + t.Fatalf("ev2: %v", err) + } + if ev2.Type != domain.StreamEventToolCallDelta { + t.Fatalf("ev2 type = %v, want ToolCallDelta", ev2.Type) + } + + ev3, err := stream.Recv() + if err != nil { + t.Fatalf("ev3: %v", err) + } + if ev3.Type != domain.StreamEventToolCallDelta { + t.Fatalf("ev3 type = %v, want ToolCallDelta", ev3.Type) + } + + // The combined chunk with tool_calls + finish_reason must emit + // the tool call delta FIRST, then the completed event. + ev4, err := stream.Recv() + if err != nil { + t.Fatalf("ev4: %v", err) + } + if ev4.Type != domain.StreamEventToolCallDelta { + t.Fatalf("ev4 type = %v, want ToolCallDelta (deferred finish)", ev4.Type) + } + if ev4.ToolCallDelta == nil || ev4.ToolCallDelta.ArgumentsDelta == nil { + t.Fatal("ev4: missing arguments delta") + } + if *ev4.ToolCallDelta.ArgumentsDelta != "}" { + t.Fatalf("ev4 args = %q, want %q", *ev4.ToolCallDelta.ArgumentsDelta, "}") + } + + ev5, err := stream.Recv() + if err != nil { + t.Fatalf("ev5: %v", err) + } + if ev5.Type != domain.StreamEventCompleted { + t.Fatalf("ev5 type = %v, want Completed", ev5.Type) + } + if ev5.FinishReason == nil || *ev5.FinishReason != "tool_calls" { + t.Fatalf("ev5 finish_reason = %v, want tool_calls", ev5.FinishReason) + } + + _, err = stream.Recv() + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } +} + +func TestStream_FinishReasonWithoutToolCalls(t *testing.T) { + sseData := "data: {\"id\":\"c2\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Hi\"},\"finish_reason\":null}]}\n\n" + + "data: 
{\"id\":\"c2\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n" + + "data: [DONE]\n" + + body := fakeBody{strings.NewReader(sseData)} + stream := NewStream(&http.Response{Body: body}, "deepseek") + + ev1, err := stream.Recv() + if err != nil { + t.Fatalf("ev1: %v", err) + } + if ev1.Type != domain.StreamEventOutputTextDelta { + t.Fatalf("ev1 type = %v, want OutputTextDelta", ev1.Type) + } + + ev2, err := stream.Recv() + if err != nil { + t.Fatalf("ev2: %v", err) + } + if ev2.Type != domain.StreamEventCompleted { + t.Fatalf("ev2 type = %v, want Completed", ev2.Type) + } + if ev2.FinishReason == nil || *ev2.FinishReason != "stop" { + t.Fatalf("ev2 finish_reason = %v, want stop", ev2.FinishReason) + } + + _, err = stream.Recv() + if err != io.EOF { + t.Fatalf("expected EOF, got %v", err) + } +} + +func TestStream_ReasoningContentDelta(t *testing.T) { + sseData := "data: {\"id\":\"c3\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"reasoning_content\":\"thinking\"},\"finish_reason\":null}]}\n\n" + + "data: {\"id\":\"c3\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"answer\"},\"finish_reason\":null}]}\n\n" + + "data: {\"id\":\"c3\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n" + + "data: [DONE]\n" + + body := fakeBody{strings.NewReader(sseData)} + stream := NewStream(&http.Response{Body: body}, "deepseek") + + ev1, err := stream.Recv() + if err != nil { + t.Fatalf("ev1: %v", err) + } + if ev1.Type != domain.StreamEventOutputTextDelta { + t.Fatalf("ev1 type = %v, want OutputTextDelta", ev1.Type) + } + if ev1.ReasoningDelta == nil || *ev1.ReasoningDelta != "thinking" { + t.Fatalf("ev1 reasoning = %v, want thinking", ev1.ReasoningDelta) + } + + ev2, err := stream.Recv() + if err != nil { + t.Fatalf("ev2: %v", err) + } + if ev2.ContentDelta == nil || *ev2.ContentDelta != "answer" { + t.Fatalf("ev2 content = %v, want answer", ev2.ContentDelta) + } +} + +func 
TestStream_ReasoningContentIgnoredForNonDeepSeekProviders(t *testing.T) { + sseData := "data: {\"id\":\"c4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"reasoning_content\":\"thinking\"},\"finish_reason\":null}]}\n\n" + + "data: {\"id\":\"c4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"answer\"},\"finish_reason\":null}]}\n\n" + + "data: [DONE]\n" + + stream := newTestStream(sseData) + + ev1, err := stream.Recv() + if err != nil { + t.Fatalf("ev1: %v", err) + } + if ev1.ReasoningDelta != nil { + t.Fatalf("ev1 reasoning = %v, want nil for non-DeepSeek provider", *ev1.ReasoningDelta) + } + + ev2, err := stream.Recv() + if err != nil { + t.Fatalf("ev2: %v", err) + } + if ev2.ContentDelta == nil || *ev2.ContentDelta != "answer" { + t.Fatalf("ev2 content = %v, want answer", ev2.ContentDelta) + } +} diff --git a/apps/gateway/internal/adapters/providers/registry/registry.go b/apps/gateway/internal/adapters/providers/registry/registry.go new file mode 100644 index 0000000..69ef0fe --- /dev/null +++ b/apps/gateway/internal/adapters/providers/registry/registry.go @@ -0,0 +1,37 @@ +package registry + +import ( + "fmt" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" +) + +// Registry selects provider adapters by name. +type Registry struct { + providers map[string]ports.GenerationProvider + defaultProvider ports.GenerationProvider +} + +func NewRegistry() *Registry { + return &Registry{providers: make(map[string]ports.GenerationProvider)} +} + +func (r *Registry) Register(name string, provider ports.GenerationProvider) { + r.providers[name] = provider +} + +// SetDefault sets a fallback adapter returned when no provider is explicitly +// registered under the requested name. 
+func (r *Registry) SetDefault(provider ports.GenerationProvider) { + r.defaultProvider = provider +} + +func (r *Registry) GetProvider(providerName string) (ports.GenerationProvider, error) { + if p, ok := r.providers[providerName]; ok { + return p, nil + } + if r.defaultProvider != nil { + return r.defaultProvider, nil + } + return nil, fmt.Errorf("unknown provider: %s", providerName) +} diff --git a/apps/gateway/internal/adapters/providers/tinfoil/adapter.go b/apps/gateway/internal/adapters/providers/tinfoil/adapter.go new file mode 100644 index 0000000..08aa325 --- /dev/null +++ b/apps/gateway/internal/adapters/providers/tinfoil/adapter.go @@ -0,0 +1,335 @@ +package tinfoil + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "runtime/debug" + "strings" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/providers/openai" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + tinfoilsdk "github.com/tinfoilsh/tinfoil-go" + verifierclient "github.com/tinfoilsh/tinfoil-go/verifier/client" +) + +const providerName = "tinfoil" + +// VerifiedClient is the narrow SDK surface the adapter needs after attestation +// and encrypted transport setup have succeeded. +type VerifiedClient interface { + HTTPClient() *http.Client + Enclave() string + Repo() string + TransportMode() string + GroundTruth() *verifierclient.GroundTruth +} + +// VerifiedClientFactory creates a client that has already failed closed if +// attestation, key binding, or transport setup cannot be completed. +type VerifiedClientFactory interface { + NewVerifiedClient(ctx context.Context, model domain.PublicModel) (VerifiedClient, error) +} + +// Adapter is the Tinfoil provider adapter. 
It reuses the OpenAI-compatible +// request/response mapper, but all HTTP traffic goes through Tinfoil's +// attested EHBP client instead of the generic OpenAI adapter. +type Adapter struct { + timeout time.Duration + factory VerifiedClientFactory + logger ports.Logger +} + +func NewAdapter(timeout time.Duration, logger ...ports.Logger) *Adapter { + return NewAdapterWithFactory(timeout, SDKClientFactory{}, logger...) +} + +func NewAdapterWithFactory(timeout time.Duration, factory VerifiedClientFactory, logger ...ports.Logger) *Adapter { + var l ports.Logger + if len(logger) > 0 { + l = logger[0] + } + if factory == nil { + factory = SDKClientFactory{} + } + return &Adapter{timeout: timeout, factory: factory, logger: l} +} + +func (a *Adapter) Generate(ctx context.Context, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + apiKey := os.Getenv(model.ProviderConfig.APIKeySecretRef) + if apiKey == "" { + return domain.GenerateResult{}, domain.ErrProviderUnavailable(providerName + ": API key not configured") + } + + body := openai.BuildRequestBody(req, model) + body["stream"] = false + delete(body, "stream_options") + + verified, proof, err := a.newVerifiedClient(ctx, model) + if err != nil { + return domain.GenerateResult{}, err + } + + respBody, err := a.do(ctx, verified, apiKey, body) + if err != nil { + return domain.GenerateResult{}, openai.MapProviderErrorWithCompatibilityContext(err, model, body) + } + + result, err := openai.ParseResponse(respBody, req, model) + if err != nil { + return domain.GenerateResult{}, err + } + proof.ProviderResponseID = result.ID + result.TinfoilProof = proof + return result, nil +} + +func (a *Adapter) StreamGenerate(ctx context.Context, req domain.GenerateRequest, model domain.PublicModel) (ports.GenerationStream, error) { + apiKey := os.Getenv(model.ProviderConfig.APIKeySecretRef) + if apiKey == "" { + return nil, domain.ErrProviderUnavailable(providerName + ": API key not configured") + } + + 
body := openai.BuildRequestBody(req, model) + + verified, proof, err := a.newVerifiedClient(ctx, model) + if err != nil { + return nil, err + } + + resp, err := a.doStream(ctx, verified, apiKey, body) + if err != nil { + return nil, openai.MapProviderErrorWithCompatibilityContext(err, model, body) + } + + return &Stream{ + inner: openai.NewStream(resp, providerName), + proof: proof, + }, nil +} + +func (a *Adapter) newVerifiedClient(ctx context.Context, model domain.PublicModel) (VerifiedClient, *domain.TinfoilTransportProof, error) { + verified, err := a.factory.NewVerifiedClient(ctx, model) + if err != nil { + return nil, nil, domain.ErrProviderUnavailable(providerName).WithMeta("verification_error", err.Error()) + } + if verified == nil || verified.HTTPClient() == nil || verified.GroundTruth() == nil { + return nil, nil, domain.ErrProviderUnavailable(providerName).WithMeta("verification_error", "verified client returned no ground truth") + } + return verified, proofFromVerifiedClient(verified), nil +} + +func (a *Adapter) do(ctx context.Context, verified VerifiedClient, apiKey string, body map[string]any) ([]byte, error) { + if a.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, a.timeout) + defer cancel() + } + + req, err := newProviderRequest(ctx, verified, apiKey, body) + if err != nil { + return nil, err + } + resp, err := verified.HTTPClient().Do(req) + if err != nil { + return nil, fmt.Errorf("tinfoil request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) + if err != nil { + return nil, fmt.Errorf("failed to read Tinfoil response: %w", err) + } + if resp.StatusCode >= 400 { + return nil, openai.ParseProviderError(resp.StatusCode, respBody) + } + return respBody, nil +} + +func (a *Adapter) doStream(ctx context.Context, verified VerifiedClient, apiKey string, body map[string]any) (*http.Response, error) { + req, err := newProviderRequest(ctx, 
// configuredProxyBaseURL returns the value of the TINFOIL_PROXY_BASE_URL
// environment variable with surrounding whitespace and any trailing slashes
// removed. It returns "" when the variable is unset or blank.
func configuredProxyBaseURL() string {
	raw := os.Getenv("TINFOIL_PROXY_BASE_URL")
	trimmed := strings.TrimSpace(raw)
	return strings.TrimRight(trimmed, "/")
}
// tinfoilSDKModulePath is the module path under which the Tinfoil Go SDK
// appears in the binary's embedded dependency list.
const tinfoilSDKModulePath = "github.com/tinfoilsh/tinfoil-go"

// resolvedTinfoilSDKVersion holds the SDK version string. Build information
// is fixed once the binary is linked, so it is resolved a single time at
// package initialisation instead of on every proof that is built.
var resolvedTinfoilSDKVersion = lookupTinfoilSDKVersion()

// tinfoilSDKVersion returns "<module path> <version>" for the linked Tinfoil
// SDK, or just the module path when the version (or the build information
// itself) is unavailable.
func tinfoilSDKVersion() string {
	return resolvedTinfoilSDKVersion
}

// lookupTinfoilSDKVersion scans the embedded build info for the SDK module.
func lookupTinfoilSDKVersion() string {
	info, ok := debug.ReadBuildInfo()
	if !ok {
		return tinfoilSDKModulePath
	}
	for _, dep := range info.Deps {
		if dep.Path != tinfoilSDKModulePath {
			continue
		}
		if dep.Version != "" {
			return dep.Path + " " + dep.Version
		}
		return dep.Path
	}
	return tinfoilSDKModulePath
}
+type SDKClientFactory struct{} + +func (SDKClientFactory) NewVerifiedClient(_ context.Context, model domain.PublicModel) (VerifiedClient, error) { + opts := []tinfoilsdk.ClientOption{ + tinfoilsdk.WithTransport(tinfoilsdk.TransportEHBP), + } + if repo := strings.TrimSpace(os.Getenv("TINFOIL_CONFIG_REPO")); repo != "" { + opts = append(opts, tinfoilsdk.WithRepo(repo)) + } + if enclave := strings.TrimSpace(os.Getenv("TINFOIL_ENCLAVE_HOST")); enclave != "" { + opts = append(opts, tinfoilsdk.WithEnclave(enclave)) + } + // ProviderConfig.BaseURL is the OpenAI-compatible catalog URL. Do not treat + // it as the EHBP request proxy by default; the SDK-selected verified enclave + // is the native Tinfoil transport path. Operators can opt into an explicit + // compatible EHBP proxy with TINFOIL_PROXY_BASE_URL. + if proxyBaseURL := configuredProxyBaseURL(); proxyBaseURL != "" { + opts = append(opts, tinfoilsdk.WithBaseURL(proxyBaseURL)) + } + if bundleURL := strings.TrimSpace(os.Getenv("TINFOIL_ATTESTATION_BUNDLE_URL")); bundleURL != "" { + opts = append(opts, tinfoilsdk.WithAttestationBundleURL(strings.TrimRight(bundleURL, "/"))) + } + + client, err := tinfoilsdk.NewClientWithOptions(opts...) 
+ if err != nil { + return nil, err + } + groundTruth, err := client.Verify() + if err != nil { + return nil, err + } + return &sdkVerifiedClient{client: client, groundTruth: groundTruth}, nil +} + +type sdkVerifiedClient struct { + client *tinfoilsdk.Client + groundTruth *verifierclient.GroundTruth +} + +func (c *sdkVerifiedClient) HTTPClient() *http.Client { + return c.client.HTTPClient() +} + +func (c *sdkVerifiedClient) Enclave() string { + return c.client.Enclave() +} + +func (c *sdkVerifiedClient) Repo() string { + return c.client.Repo() +} + +func (c *sdkVerifiedClient) TransportMode() string { + return string(c.client.Transport()) +} + +func (c *sdkVerifiedClient) GroundTruth() *verifierclient.GroundTruth { + return c.groundTruth +} diff --git a/apps/gateway/internal/adapters/providers/tinfoil/adapter_test.go b/apps/gateway/internal/adapters/providers/tinfoil/adapter_test.go new file mode 100644 index 0000000..cbf1c5b --- /dev/null +++ b/apps/gateway/internal/adapters/providers/tinfoil/adapter_test.go @@ -0,0 +1,286 @@ +package tinfoil + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + verifierclient "github.com/tinfoilsh/tinfoil-go/verifier/client" +) + +func TestGenerateUsesVerifiedClientAndStoresProofEvidence(t *testing.T) { + t.Setenv("TINFOIL_API_KEY", "test-key") + + var seenRequest bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenRequest = true + if r.URL.Path != "/chat/completions" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Fatalf("unexpected authorization header: %s", got) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request: %v", err) + } + if body["model"] != "kimi-k2-6" || body["stream"] != false { + 
t.Fatalf("unexpected request body: %#v", body) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{ + "id":"cmpl-tinfoil-1", + "created":1710000000, + "choices":[{"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":3,"completion_tokens":4,"total_tokens":7} + }`) + })) + defer server.Close() + t.Setenv("TINFOIL_PROXY_BASE_URL", server.URL) + + factory := &fakeFactory{client: fakeClient(server.Client())} + result, err := NewAdapterWithFactory(time.Second, factory).Generate(context.Background(), simpleRequest(false), simpleModel(server.URL)) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if !seenRequest { + t.Fatalf("provider request was not sent") + } + if result.ID != "cmpl-tinfoil-1" || result.TinfoilProof == nil { + t.Fatalf("unexpected result/proof: id=%q proof=%#v", result.ID, result.TinfoilProof) + } + if result.TinfoilProof.ProviderResponseID != "cmpl-tinfoil-1" { + t.Fatalf("proof response id was not filled: %#v", result.TinfoilProof.ProviderResponseID) + } + if result.TinfoilProof.EnclaveHost == nil || *result.TinfoilProof.EnclaveHost != "inference.tinfoil.sh" { + t.Fatalf("unexpected enclave host in proof: %#v", result.TinfoilProof.EnclaveHost) + } + if result.TinfoilProof.TransportMode == nil || *result.TinfoilProof.TransportMode != "ehbp" { + t.Fatalf("unexpected transport mode in proof: %#v", result.TinfoilProof.TransportMode) + } + if len(result.TinfoilProof.VerificationEvidenceJSON) == 0 { + t.Fatalf("expected verification evidence JSON") + } +} + +func TestStreamGenerateReturnsVerifiedTransportProof(t *testing.T) { + t.Setenv("TINFOIL_API_KEY", "test-key") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprint(w, "data: {\"id\":\"cmpl-stream-1\",\"choices\":[{\"delta\":{\"role\":\"assistant\",\"content\":\"hi\"}}]}\n\n") + fmt.Fprint(w, "data: 
{\"id\":\"cmpl-stream-1\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1,\"total_tokens\":2}}\n\n") + fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer server.Close() + t.Setenv("TINFOIL_PROXY_BASE_URL", server.URL) + + stream, err := NewAdapterWithFactory(time.Second, &fakeFactory{client: fakeClient(server.Client())}).StreamGenerate( + context.Background(), + simpleRequest(true), + simpleModel(server.URL), + ) + if err != nil { + t.Fatalf("StreamGenerate returned error: %v", err) + } + defer stream.Close() + + var text string + var completed bool + for { + event, recvErr := stream.Recv() + if recvErr == io.EOF { + break + } + if recvErr != nil { + t.Fatalf("Recv returned error: %v", recvErr) + } + if event.ContentDelta != nil { + text += *event.ContentDelta + } + if event.Type == domain.StreamEventCompleted { + completed = true + } + } + if text != "hi" || !completed { + t.Fatalf("unexpected stream events: text=%q completed=%v", text, completed) + } + proofProvider, ok := stream.(interface { + VerifiedTransportProof() *domain.TinfoilTransportProof + }) + if !ok { + t.Fatalf("stream does not expose proof evidence") + } + proof := proofProvider.VerifiedTransportProof() + if proof == nil || proof.TransportMode == nil || *proof.TransportMode != "ehbp" { + t.Fatalf("unexpected stream proof: %#v", proof) + } +} + +func TestGenerateFailsClosedWhenAttestationFails(t *testing.T) { + t.Setenv("TINFOIL_API_KEY", "test-key") + + var called bool + factory := &fakeFactory{err: errors.New("attestation failed"), onCall: func() { called = true }} + _, err := NewAdapterWithFactory(time.Second, factory).Generate(context.Background(), simpleRequest(false), simpleModel("https://example.invalid/v1")) + if err == nil { + t.Fatalf("expected error") + } + if !called { + t.Fatalf("expected factory to be called") + } + var gwErr *domain.GatewayError + if !errors.As(err, &gwErr) || gwErr.Code != domain.ErrCodeProviderUnavailable { 
+ t.Fatalf("expected provider unavailable GatewayError, got %T %[1]v", err) + } +} + +func TestGenerateDoesNotVerifyWithoutAPIKey(t *testing.T) { + t.Setenv("TINFOIL_API_KEY", "") + + var called bool + factory := &fakeFactory{client: fakeClient(http.DefaultClient), onCall: func() { called = true }} + _, err := NewAdapterWithFactory(time.Second, factory).Generate(context.Background(), simpleRequest(false), simpleModel("https://example.invalid/v1")) + if err == nil { + t.Fatalf("expected error") + } + if called { + t.Fatalf("factory should not be called without an API key") + } +} + +func TestGenerateMapsProviderErrors(t *testing.T) { + t.Setenv("TINFOIL_API_KEY", "test-key") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"error":{"message":"rate limited","type":"rate_limit_error"}}`) + })) + defer server.Close() + t.Setenv("TINFOIL_PROXY_BASE_URL", server.URL) + + _, err := NewAdapterWithFactory(time.Second, &fakeFactory{client: fakeClient(server.Client())}).Generate( + context.Background(), + simpleRequest(false), + simpleModel(server.URL), + ) + var gwErr *domain.GatewayError + if !errors.As(err, &gwErr) || gwErr.HTTPStatus != http.StatusTooManyRequests { + t.Fatalf("expected mapped 429 GatewayError, got %T %[1]v", err) + } +} + +func TestRequestURLDefaultsToVerifiedEnclave(t *testing.T) { + t.Setenv("TINFOIL_PROXY_BASE_URL", "") + + req, err := newProviderRequest(context.Background(), fakeClient(http.DefaultClient), "test-key", map[string]any{"model": "deepseek-v4-pro"}) + if err != nil { + t.Fatalf("newProviderRequest returned error: %v", err) + } + if got, want := req.URL.String(), "https://inference.tinfoil.sh/v1/chat/completions"; got != want { + t.Fatalf("request URL = %q, want %q", got, want) + } +} + +func TestRequestURLUsesExplicitProxyBaseURL(t *testing.T) { + t.Setenv("TINFOIL_PROXY_BASE_URL", "https://proxy.example.test/v1/") + + req, err := 
newProviderRequest(context.Background(), fakeClient(http.DefaultClient), "test-key", map[string]any{"model": "deepseek-v4-pro"}) + if err != nil { + t.Fatalf("newProviderRequest returned error: %v", err) + } + if got, want := req.URL.String(), "https://proxy.example.test/v1/chat/completions"; got != want { + t.Fatalf("request URL = %q, want %q", got, want) + } +} + +func simpleRequest(stream bool) domain.GenerateRequest { + role := "user" + content := "hello" + return domain.GenerateRequest{ + PublicModelID: "tinfoil/kimi-k2-6", + Stream: stream, + Input: []domain.InputItem{{ + Type: domain.InputItemTypeMessage, + Role: &role, + Content: &content, + }}, + } +} + +func simpleModel(baseURL string) domain.PublicModel { + return domain.PublicModel{ + PublicModelID: "tinfoil/kimi-k2-6", + ProviderModelID: "tinfoil-kimi-k2-6", + UpstreamModelName: "kimi-k2-6", + ProofMode: domain.ProofModeTinfoilAttestedTransport, + MaxOutputTokens: 8192, + ProviderConfig: domain.ProviderConfig{ + ProviderName: "tinfoil", + BaseURL: baseURL, + APIKeySecretRef: "TINFOIL_API_KEY", + }, + } +} + +type fakeFactory struct { + client VerifiedClient + err error + onCall func() +} + +func (f *fakeFactory) NewVerifiedClient(context.Context, domain.PublicModel) (VerifiedClient, error) { + if f.onCall != nil { + f.onCall() + } + if f.err != nil { + return nil, f.err + } + return f.client, nil +} + +type fakeVerifiedClient struct { + httpClient *http.Client + groundTruth *verifierclient.GroundTruth +} + +func fakeClient(httpClient *http.Client) *fakeVerifiedClient { + return &fakeVerifiedClient{ + httpClient: httpClient, + groundTruth: &verifierclient.GroundTruth{ + EnclaveHost: "inference.tinfoil.sh", + TLSPublicKey: "tls-key", + HPKEPublicKey: "hpke-key", + Digest: "sha256:abc", + CodeFingerprint: "code-fp", + EnclaveFingerprint: "enclave-fp", + }, + } +} + +func (c *fakeVerifiedClient) HTTPClient() *http.Client { + return c.httpClient +} + +func (c *fakeVerifiedClient) Enclave() string { + return 
"inference.tinfoil.sh" +} + +func (c *fakeVerifiedClient) Repo() string { + return "tinfoilsh/confidential-model-router" +} + +func (c *fakeVerifiedClient) TransportMode() string { + return "ehbp" +} + +func (c *fakeVerifiedClient) GroundTruth() *verifierclient.GroundTruth { + return c.groundTruth +} diff --git a/apps/gateway/internal/adapters/router/http_client.go b/apps/gateway/internal/adapters/router/http_client.go new file mode 100644 index 0000000..e191440 --- /dev/null +++ b/apps/gateway/internal/adapters/router/http_client.go @@ -0,0 +1,108 @@ +package router + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// Client calls the standalone router service over HTTP. +type Client struct { + baseURL string + httpClient *http.Client +} + +func NewClient(baseURL string, timeout time.Duration) *Client { + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + httpClient: &http.Client{ + Timeout: timeout, + }, + } +} + +type routeRequest struct { + RouterID string `json:"router_id"` + Request domain.GenerateRequest `json:"request"` +} + +type routeResponse struct { + PublicModelID string `json:"public_model_id"` + Category *string `json:"category,omitempty"` + Score *float32 `json:"score,omitempty"` + CategoryScores []domain.RoutingCategoryScore `json:"category_scores,omitempty"` + FallbackUsed bool `json:"fallback_used"` + Reason string `json:"reason"` +} + +type errorResponse struct { + Error struct { + Type string `json:"type"` + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + +func (c *Client) Route(ctx context.Context, req domain.RouteRequest) (domain.RouteDecision, error) { + body, err := json.Marshal(routeRequest{ + RouterID: req.RouterID, + Request: req.Request, + }) + if err != nil { + return domain.RouteDecision{}, domain.ErrInternal("failed to encode router request") + } + + httpReq, err := 
http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/route", bytes.NewReader(body)) + if err != nil { + return domain.RouteDecision{}, domain.ErrInternal("failed to build router request") + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return domain.RouteDecision{}, domain.ErrProviderUnavailable("router").WithMeta("upstream_error", err.Error()) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var errResp errorResponse + if decodeErr := json.NewDecoder(resp.Body).Decode(&errResp); decodeErr == nil && errResp.Error.Code != "" { + return domain.RouteDecision{}, mapRouterError(req.RouterID, resp.StatusCode, errResp) + } + return domain.RouteDecision{}, domain.ErrProviderError(resp.StatusCode, fmt.Sprintf("router returned status %d", resp.StatusCode)) + } + + var out routeResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return domain.RouteDecision{}, domain.ErrProviderUnavailable("router").WithMeta("upstream_error", err.Error()) + } + + return domain.RouteDecision{ + PublicModelID: out.PublicModelID, + Category: out.Category, + Score: out.Score, + CategoryScores: out.CategoryScores, + FallbackUsed: out.FallbackUsed, + Reason: out.Reason, + }, nil +} + +func mapRouterError(routerID string, status int, errResp errorResponse) error { + switch errResp.Error.Code { + case domain.ErrCodeNotFound, domain.ErrCodeUnsupportedModel: + return domain.ErrUnsupportedModel(routerID) + case domain.ErrCodeInvalidField: + return domain.ErrInvalidField(errResp.Error.Message) + default: + if status >= 500 { + return domain.ErrProviderUnavailable("router").WithMeta("upstream_error", errResp.Error.Message) + } + return domain.ErrProviderError(status, errResp.Error.Message) + } +} diff --git a/apps/gateway/internal/adapters/storage/postgres/api_key_repo.go b/apps/gateway/internal/adapters/storage/postgres/api_key_repo.go new file mode 100644 
index 0000000..1c659f1 --- /dev/null +++ b/apps/gateway/internal/adapters/storage/postgres/api_key_repo.go @@ -0,0 +1,62 @@ +package postgres + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/jackc/pgx/v5/pgxpool" +) + +// APIKeyRepo handles API key lookups. +type APIKeyRepo struct { + pool *pgxpool.Pool +} + +func NewAPIKeyRepo(pool *pgxpool.Pool) *APIKeyRepo { + return &APIKeyRepo{pool: pool} +} + +// FindByHash looks up an API key by its SHA-256 hash and returns the key and account. +func (r *APIKeyRepo) FindByHash(ctx context.Context, rawKey string) (domain.APIKey, domain.Account, error) { + hash := hashAPIKey(rawKey) + + var ( + key domain.APIKey + account domain.Account + extCustID *string + ) + + err := r.pool.QueryRow(ctx, ` + SELECT + k.id, k.account_id, k.name, k.key_prefix, k.active, k.pii_mode, k.last_used_at, + a.id, a.external_customer_id, a.status + FROM api_keys k + JOIN accounts a ON a.id = k.account_id + WHERE k.key_hash = $1 + `, hash).Scan( + &key.ID, &key.AccountID, &key.Name, &key.KeyPrefix, &key.Active, &key.PIIMode, &key.LastUsedAt, + &account.ID, &extCustID, &account.Status, + ) + if err != nil { + return domain.APIKey{}, domain.Account{}, fmt.Errorf("API key not found: %w", err) + } + + account.ExternalCustomerID = extCustID + + go func() { + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + r.pool.Exec(bgCtx, `UPDATE api_keys SET last_used_at = NOW() WHERE id = $1`, key.ID) + }() + + return key, account, nil +} + +func hashAPIKey(key string) string { + h := sha256.Sum256([]byte(key)) + return hex.EncodeToString(h[:]) +} diff --git a/apps/gateway/internal/adapters/storage/postgres/model_catalog_repo.go b/apps/gateway/internal/adapters/storage/postgres/model_catalog_repo.go new file mode 100644 index 0000000..47dc650 --- /dev/null +++ 
// modelQuery is the shared SELECT used to load public models.
//
// Each public model is joined to its provider model and provider config, and
// to exactly one pricing row chosen by the LATERAL subquery: the active
// model_pricing row for the provider model whose effective window contains
// NOW(), preferring the most recent effective_from. Because the lateral join
// is an inner join, a model with no currently effective pricing row is not
// returned at all.
//
// Only rows whose public model, provider model and provider config are all
// active are selected. The statement deliberately ends inside its WHERE
// clause so callers can append further "AND ..." predicates or an ORDER BY.
// The column order must stay in sync with the Scan destinations in
// scanModelRow.
const modelQuery = `
SELECT
	m.id, m.public_model_id, m.display_name, m.description,
	m.provider_model_id, pm.provider_model_name,
	pc.id, pc.provider_name, pc.base_url, pc.api_key_secret_ref, pc.organization_ref, pc.active,
	m.supports_chat_completions,
	m.supports_chat_completions_stream,
	m.supports_tools, m.supports_parallel_tool_calls, m.supports_structured_output, m.supports_reasoning,
	m.proof_mode,
	m.max_context_window, m.max_output_tokens,
	m.active,
	mp.currency, mp.input_price_per_1m_tokens_microcents, mp.output_price_per_1m_tokens_microcents,
	mp.cache_read_price_per_1m_tokens_microcents, mp.cache_write_price_per_1m_tokens_microcents
FROM public_models m
JOIN provider_models pm ON pm.id = m.provider_model_id
JOIN provider_configs pc ON pc.id = m.provider_config_id
JOIN LATERAL (
	SELECT mp.*
	FROM model_pricing mp
	WHERE mp.provider_model_id = pm.id
	  AND mp.active = true
	  AND mp.effective_from <= NOW()
	  AND (mp.effective_to IS NULL OR mp.effective_to > NOW())
	ORDER BY mp.effective_from DESC
	LIMIT 1
) mp ON true
WHERE m.active = true
  AND pm.active = true
  AND pc.active = true
`
rows.Next() { + m, err := scanModel(rows) + if err != nil { + return nil, err + } + models = append(models, m) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating models: %w", err) + } + + return models, nil +} + +func (r *ModelCatalogRepo) GetPublicModel(ctx context.Context, publicModelID string) (domain.PublicModel, error) { + row := r.pool.QueryRow(ctx, modelQuery+` + AND m.public_model_id = $1 + `, publicModelID) + + m, err := scanModelRow(row) + if err != nil { + return domain.PublicModel{}, domain.ErrUnsupportedModel(publicModelID) + } + + return m, nil +} + +func (r *ModelCatalogRepo) ListRouters(ctx context.Context) ([]domain.RouterEntry, error) { + rows, err := r.pool.Query(ctx, ` + SELECT r.id::text, r.router_id, r.display_name, r.description, + fallback.public_model_id, r.active, r.created_at, r.updated_at + FROM routers r + JOIN public_models fallback ON fallback.id = r.fallback_public_model_id + JOIN provider_models pm ON pm.id = fallback.provider_model_id + JOIN provider_configs pc ON pc.id = fallback.provider_config_id + WHERE r.active = true + AND fallback.active = true + AND pm.active = true + AND pc.active = true + AND EXISTS ( + SELECT 1 + FROM model_pricing mp + WHERE mp.provider_model_id = pm.id + AND mp.active = true + AND mp.effective_from <= NOW() + AND (mp.effective_to IS NULL OR mp.effective_to > NOW()) + ) + ORDER BY r.router_id + `) + if err != nil { + return nil, fmt.Errorf("failed to query routers: %w", err) + } + defer rows.Close() + + var routers []domain.RouterEntry + for rows.Next() { + var router domain.RouterEntry + if err := rows.Scan( + &router.ID, &router.RouterID, &router.DisplayName, &router.Description, + &router.FallbackPublicModelID, &router.Active, &router.CreatedAt, &router.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("failed to scan router: %w", err) + } + routers = append(routers, router) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating 
routers: %w", err) + } + return routers, nil +} + +func (r *ModelCatalogRepo) GetRouter(ctx context.Context, routerID string) (domain.RouterEntry, error) { + var router domain.RouterEntry + err := r.pool.QueryRow(ctx, ` + SELECT r.id::text, r.router_id, r.display_name, r.description, + fallback.public_model_id, r.active, r.created_at, r.updated_at + FROM routers r + JOIN public_models fallback ON fallback.id = r.fallback_public_model_id + JOIN provider_models pm ON pm.id = fallback.provider_model_id + JOIN provider_configs pc ON pc.id = fallback.provider_config_id + WHERE r.router_id = $1 AND r.active = true + AND fallback.active = true + AND pm.active = true + AND pc.active = true + AND EXISTS ( + SELECT 1 + FROM model_pricing mp + WHERE mp.provider_model_id = pm.id + AND mp.active = true + AND mp.effective_from <= NOW() + AND (mp.effective_to IS NULL OR mp.effective_to > NOW()) + ) + `, routerID).Scan( + &router.ID, &router.RouterID, &router.DisplayName, &router.Description, + &router.FallbackPublicModelID, &router.Active, &router.CreatedAt, &router.UpdatedAt, + ) + if err != nil { + if err == pgx.ErrNoRows { + return domain.RouterEntry{}, domain.ErrNotFound("router", routerID) + } + return domain.RouterEntry{}, fmt.Errorf("failed to get router: %w", err) + } + return router, nil +} + +type scannable interface { + Scan(dest ...any) error +} + +func scanModelRow(row scannable) (domain.PublicModel, error) { + var m domain.PublicModel + var inputPriceMicrocents, outputPriceMicrocents int64 + var cacheReadPriceMicrocents, cacheWritePriceMicrocents *int64 + + err := row.Scan( + &m.ID, &m.PublicModelID, &m.DisplayName, &m.Description, + &m.ProviderModelID, &m.UpstreamModelName, + &m.ProviderConfig.ID, &m.ProviderConfig.ProviderName, &m.ProviderConfig.BaseURL, + &m.ProviderConfig.APIKeySecretRef, &m.ProviderConfig.OrganizationRef, &m.ProviderConfig.Active, + &m.SupportsChatCompletions, + &m.SupportsChatCompletionsStream, + &m.SupportsTools, &m.SupportsParallelToolCalls, 
&m.SupportsStructuredOutput, &m.SupportsReasoning, + &m.ProofMode, + &m.MaxContextWindow, &m.MaxOutputTokens, + &m.Active, + &m.Currency, &inputPriceMicrocents, &outputPriceMicrocents, + &cacheReadPriceMicrocents, &cacheWritePriceMicrocents, + ) + if err != nil { + return domain.PublicModel{}, err + } + m.ProofMode = m.EffectiveProofMode() + + m.InputPricePerMillion = decimal.NewFromInt(inputPriceMicrocents).Div(decimal.NewFromInt(100_000_000)) + m.OutputPricePerMillion = decimal.NewFromInt(outputPriceMicrocents).Div(decimal.NewFromInt(100_000_000)) + if cacheReadPriceMicrocents != nil { + v := decimal.NewFromInt(*cacheReadPriceMicrocents).Div(decimal.NewFromInt(100_000_000)) + m.CacheReadPricePerMillion = &v + } + if cacheWritePriceMicrocents != nil { + v := decimal.NewFromInt(*cacheWritePriceMicrocents).Div(decimal.NewFromInt(100_000_000)) + m.CacheWritePricePerMillion = &v + } + + return m, nil +} + +func scanModel(rows interface{ Scan(dest ...any) error }) (domain.PublicModel, error) { + return scanModelRow(rows) +} diff --git a/apps/gateway/internal/adapters/storage/postgres/tinfoil_proof_repo.go b/apps/gateway/internal/adapters/storage/postgres/tinfoil_proof_repo.go new file mode 100644 index 0000000..1f7474d --- /dev/null +++ b/apps/gateway/internal/adapters/storage/postgres/tinfoil_proof_repo.go @@ -0,0 +1,103 @@ +package postgres + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// TinfoilProofRepo persists safe Tinfoil attested-transport proof evidence. 
+type TinfoilProofRepo struct { + pool *pgxpool.Pool +} + +func NewTinfoilProofRepo(pool *pgxpool.Pool) *TinfoilProofRepo { + return &TinfoilProofRepo{pool: pool} +} + +func (r *TinfoilProofRepo) UpsertTinfoilTransportProof(ctx context.Context, proof domain.TinfoilTransportProof) error { + if proof.CreatedAt.IsZero() { + proof.CreatedAt = time.Now().UTC() + } + var evidence any + if len(proof.VerificationEvidenceJSON) > 0 { + evidence = proof.VerificationEvidenceJSON + } + _, err := r.pool.Exec(ctx, ` + INSERT INTO tinfoil_transport_proofs ( + account_id, api_key_id, provider_name, public_model_id, upstream_model_id, + provider_response_id, enclave_host, config_repo, digest, code_fingerprint, + enclave_fingerprint, tls_public_key, hpke_public_key, transport_mode, sdk_version, + status, failure_reason, verification_evidence_json, created_at, verified_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, $9, $10, + $11, $12, $13, $14, $15, + $16, $17, $18, $19, $20 + ) + ON CONFLICT (provider_name, provider_response_id) DO UPDATE SET + account_id = EXCLUDED.account_id, + api_key_id = EXCLUDED.api_key_id, + public_model_id = EXCLUDED.public_model_id, + upstream_model_id = EXCLUDED.upstream_model_id, + enclave_host = EXCLUDED.enclave_host, + config_repo = EXCLUDED.config_repo, + digest = EXCLUDED.digest, + code_fingerprint = EXCLUDED.code_fingerprint, + enclave_fingerprint = EXCLUDED.enclave_fingerprint, + tls_public_key = EXCLUDED.tls_public_key, + hpke_public_key = EXCLUDED.hpke_public_key, + transport_mode = EXCLUDED.transport_mode, + sdk_version = EXCLUDED.sdk_version, + status = EXCLUDED.status, + failure_reason = EXCLUDED.failure_reason, + verification_evidence_json = EXCLUDED.verification_evidence_json, + verified_at = EXCLUDED.verified_at + `, proof.AccountID, proof.APIKeyID, proof.Provider, proof.PublicModelID, proof.UpstreamModelID, + proof.ProviderResponseID, proof.EnclaveHost, proof.ConfigRepo, proof.Digest, proof.CodeFingerprint, + proof.EnclaveFingerprint, 
proof.TLSPublicKey, proof.HPKEPublicKey, proof.TransportMode, proof.SDKVersion, + proof.Status, proof.FailureReason, evidence, proof.CreatedAt, proof.VerifiedAt) + if err != nil { + return fmt.Errorf("failed to upsert Tinfoil transport proof: %w", err) + } + return nil +} + +func (r *TinfoilProofRepo) GetTinfoilTransportProof(ctx context.Context, accountID, providerResponseID string) (domain.TinfoilTransportProof, error) { + row := r.pool.QueryRow(ctx, ` + SELECT id::text, account_id, api_key_id, provider_name, public_model_id, upstream_model_id, + provider_response_id, enclave_host, config_repo, digest, code_fingerprint, + enclave_fingerprint, tls_public_key, hpke_public_key, transport_mode, sdk_version, + status, failure_reason, verification_evidence_json, created_at, verified_at + FROM tinfoil_transport_proofs + WHERE account_id = $1 AND provider_response_id = $2 + `, accountID, providerResponseID) + proof, err := scanTinfoilTransportProof(row) + if err != nil { + if err == pgx.ErrNoRows { + return domain.TinfoilTransportProof{}, domain.ErrNotFound("tinfoil proof", providerResponseID) + } + return domain.TinfoilTransportProof{}, fmt.Errorf("failed to get Tinfoil transport proof: %w", err) + } + return proof, nil +} + +func scanTinfoilTransportProof(row interface{ Scan(dest ...any) error }) (domain.TinfoilTransportProof, error) { + var proof domain.TinfoilTransportProof + var evidence []byte + err := row.Scan( + &proof.ID, &proof.AccountID, &proof.APIKeyID, &proof.Provider, &proof.PublicModelID, &proof.UpstreamModelID, + &proof.ProviderResponseID, &proof.EnclaveHost, &proof.ConfigRepo, &proof.Digest, &proof.CodeFingerprint, + &proof.EnclaveFingerprint, &proof.TLSPublicKey, &proof.HPKEPublicKey, &proof.TransportMode, &proof.SDKVersion, + &proof.Status, &proof.FailureReason, &evidence, &proof.CreatedAt, &proof.VerifiedAt, + ) + if len(evidence) > 0 { + proof.VerificationEvidenceJSON = json.RawMessage(append([]byte(nil), evidence...)) + } + return proof, err +} diff 
--git a/apps/gateway/internal/adapters/storage/postgres/usage_repo.go b/apps/gateway/internal/adapters/storage/postgres/usage_repo.go new file mode 100644 index 0000000..b80c100 --- /dev/null +++ b/apps/gateway/internal/adapters/storage/postgres/usage_repo.go @@ -0,0 +1,370 @@ +package postgres + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/shopspring/decimal" +) + +// UsageRepo persists usage events to Postgres. +type UsageRepo struct { + pool *pgxpool.Pool +} + +func NewUsageRepo(pool *pgxpool.Pool) *UsageRepo { + return &UsageRepo{pool: pool} +} + +func (r *UsageRepo) GetSpendableBalance(ctx context.Context, accountID string) (int64, error) { + var balance domain.AccountBalance + err := r.pool.QueryRow(ctx, ` + SELECT account_id, currency, monthly_credit_total_microcents, monthly_credit_used_microcents, + prepaid_credit_microcents, billing_period_start, billing_period_end, updated_at + FROM account_balances + WHERE account_id = $1 + `, accountID).Scan( + &balance.AccountID, &balance.Currency, &balance.MonthlyCreditTotalMicrocents, &balance.MonthlyCreditUsedMicrocents, + &balance.PrepaidCreditMicrocents, &balance.BillingPeriodStart, &balance.BillingPeriodEnd, &balance.UpdatedAt, + ) + if err != nil { + return 0, err + } + + return balance.SpendableCreditMicrocentsAt(time.Now()), nil +} + +func (r *UsageRepo) RecordSuccess(ctx context.Context, auth domain.AuthContext, endpoint string, req domain.GenerateRequest, resp domain.GenerateResult, model domain.PublicModel, latencyMs int64) error { + var inputTokens, outputTokens *int64 + var cacheCreationTokens, cacheReadTokens *int64 + var inputPricePer1MTokenMicrocents, outputPricePer1MTokenMicrocents *int64 + var cacheReadPricePer1MTokenMicrocents, cacheWritePricePer1MTokenMicrocents *int64 + var costMicrocents, chargedMonthlyMicrocents, 
chargedPrepaidMicrocents *int64 + now := time.Now().UTC() + usageEventID := uuid.New().String() + requestID := uuid.New().String() + var charge domain.UsageCharge + + if resp.Usage != nil { + inputTokens = &resp.Usage.PromptTokens + outputTokens = &resp.Usage.CompletionTokens + if resp.Usage.CacheCreationTokens > 0 { + cacheCreationTokens = &resp.Usage.CacheCreationTokens + } + if resp.Usage.CacheReadTokens > 0 { + cacheReadTokens = &resp.Usage.CacheReadTokens + } + + million := decimal.NewFromInt(1_000_000) + toMicrocents := decimal.NewFromInt(100_000_000) + + // Determine prices for cache tokens (fallback to input price if not configured) + cacheReadPrice := model.InputPricePerMillion + if model.CacheReadPricePerMillion != nil { + cacheReadPrice = *model.CacheReadPricePerMillion + } + cacheWritePrice := model.InputPricePerMillion + if model.CacheWritePricePerMillion != nil { + cacheWritePrice = *model.CacheWritePricePerMillion + } + + // Split input tokens: total = uncached + cache_read + cache_creation + uncachedInput := resp.Usage.PromptTokens - resp.Usage.CacheReadTokens - resp.Usage.CacheCreationTokens + if uncachedInput < 0 { + uncachedInput = 0 + } + + // Cost = uncached×inputPrice + cacheRead×cacheReadPrice + cacheCreation×cacheWritePrice + output×outputPrice + uncachedCost := model.InputPricePerMillion.Mul(decimal.NewFromInt(uncachedInput)).Div(million) + cacheReadCost := cacheReadPrice.Mul(decimal.NewFromInt(resp.Usage.CacheReadTokens)).Div(million) + cacheWriteCost := cacheWritePrice.Mul(decimal.NewFromInt(resp.Usage.CacheCreationTokens)).Div(million) + outputCost := model.OutputPricePerMillion.Mul(decimal.NewFromInt(resp.Usage.CompletionTokens)).Div(million) + totalCost := uncachedCost.Add(cacheReadCost).Add(cacheWriteCost).Add(outputCost) + + inputMicrocents := model.InputPricePerMillion.Mul(toMicrocents).Round(0).IntPart() + outputMicrocents := model.OutputPricePerMillion.Mul(toMicrocents).Round(0).IntPart() + totalMicrocents := 
totalCost.Mul(toMicrocents).Ceil().IntPart() + if totalMicrocents < 1 { + totalMicrocents = 1 + } + + inputPricePer1MTokenMicrocents = &inputMicrocents + outputPricePer1MTokenMicrocents = &outputMicrocents + costMicrocents = &totalMicrocents + + // Snapshot cache prices for the usage event + if model.CacheReadPricePerMillion != nil { + v := cacheReadPrice.Mul(toMicrocents).Round(0).IntPart() + cacheReadPricePer1MTokenMicrocents = &v + } + if model.CacheWritePricePerMillion != nil { + v := cacheWritePrice.Mul(toMicrocents).Round(0).IntPart() + cacheWritePricePer1MTokenMicrocents = &v + } + } + + providerRequestID := stringPtrOrNil(resp.ID) + requestedModel := requestedModelID(req) + routingCategoryScoresJSON, err := marshalRoutingCategoryScores(req.RoutingCategoryScores) + if err != nil { + return err + } + + tx, err := r.pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + if costMicrocents != nil { + balance, err := r.getBalanceForUpdate(ctx, tx, auth.Account.ID) + if err != nil { + return err + } + + updatedBalance, usageCharge := balance.ApplyUsageCharge(*costMicrocents, now) + charge = usageCharge + chargedMonthlyMicrocents = int64Ptr(charge.ChargedMonthlyMicrocents) + chargedPrepaidMicrocents = int64Ptr(charge.ChargedPrepaidMicrocents) + + if err := r.updateBalance(ctx, tx, updatedBalance); err != nil { + return err + } + } + + if _, err := tx.Exec(ctx, ` + INSERT INTO usage_events ( + id, request_id, account_id, api_key_id, public_model_id, + public_model_name, endpoint, + requested_public_model_id, router_id, routed_public_model_id, + matched_category, decision_reason, fallback_used, routing_score, routing_category_scores, + provider_name, provider_model_id, provider_request_id, + stream, success, finish_reason, + input_tokens, output_tokens, + cache_creation_tokens, cache_read_tokens, + input_price_per_1m_tokens_microcents, output_price_per_1m_tokens_microcents, + cache_read_price_per_1m_tokens_microcents, 
cache_write_price_per_1m_tokens_microcents, + cost_microcents, charged_monthly_microcents, charged_prepaid_microcents, + latency_ms, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, + $11, $12, $13, $14, $15::jsonb, + $16, $17, $18, + $19, $20, $21, + $22, $23, + $24, $25, + $26, $27, + $28, $29, + $30, $31, $32, + $33, + $34 + ) + `, + usageEventID, requestID, auth.Account.ID, auth.APIKey.ID, model.ID, + model.DisplayName, endpoint, + requestedModel, req.RouterID, req.RoutedPublicModelID, + req.MatchedCategory, req.DecisionReason, req.FallbackUsed, req.RoutingScore, routingCategoryScoresJSON, + resp.ProviderName, resp.ProviderModelID, providerRequestID, + req.Stream, true, resp.FinishReason, + inputTokens, outputTokens, + cacheCreationTokens, cacheReadTokens, + inputPricePer1MTokenMicrocents, outputPricePer1MTokenMicrocents, + cacheReadPricePer1MTokenMicrocents, cacheWritePricePer1MTokenMicrocents, + costMicrocents, chargedMonthlyMicrocents, chargedPrepaidMicrocents, + latencyMs, + now, + ); err != nil { + return err + } + + if err := insertUsageLedgerEntries(ctx, tx, auth.Account.ID, usageEventID, charge); err != nil { + return err + } + + return tx.Commit(ctx) +} + +func (r *UsageRepo) RecordFailure(ctx context.Context, auth *domain.AuthContext, endpoint string, req *domain.GenerateRequest, model *domain.PublicModel, err error, partialUsage *domain.Usage, latencyMs int64) error { + var accountID, apiKeyID string + if auth != nil { + accountID = auth.Account.ID + apiKeyID = auth.APIKey.ID + } + + var modelID *string + var providerName, providerModelID *string + if model != nil { + modelID = &model.ID + providerName = &model.ProviderConfig.ProviderName + providerModelID = &model.ProviderModelID + } + + var requestedModel string + var stream bool + var routerID, routedPublicModelID *string + var matchedCategory, decisionReason *string + var routingScore *float32 + var fallbackUsed *bool + var routingCategoryScores 
[]domain.RoutingCategoryScore + if req != nil { + requestedModel = requestedModelID(*req) + stream = req.Stream + routerID = req.RouterID + routedPublicModelID = req.RoutedPublicModelID + matchedCategory = req.MatchedCategory + routingScore = req.RoutingScore + decisionReason = req.DecisionReason + fallbackUsed = req.FallbackUsed + routingCategoryScores = req.RoutingCategoryScores + } + routingCategoryScoresJSON, marshalErr := marshalRoutingCategoryScores(routingCategoryScores) + if marshalErr != nil { + return marshalErr + } + + var errType, errCode string + var gwErr *domain.GatewayError + if errors.As(err, &gwErr) { + errType = gwErr.Type + errCode = gwErr.Code + } else if err != nil { + errType = domain.ErrTypeInternal + errCode = domain.ErrCodeInternalError + } + + var inputTokens, outputTokens *int64 + if partialUsage != nil { + inputTokens = &partialUsage.PromptTokens + outputTokens = &partialUsage.CompletionTokens + } + + _, execErr := r.pool.Exec(ctx, ` + INSERT INTO usage_events ( + id, request_id, account_id, api_key_id, public_model_id, public_model_name, endpoint, + requested_public_model_id, router_id, routed_public_model_id, + matched_category, decision_reason, fallback_used, routing_score, routing_category_scores, + provider_name, provider_model_id, + stream, success, + input_tokens, output_tokens, + error_type, error_code, + latency_ms, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, + $11, $12, $13, $14, $15::jsonb, + $16, $17, + $18, $19, + $20, $21, + $22, $23, + $24, + $25 + ) + `, + uuid.New().String(), uuid.New().String(), accountID, apiKeyID, modelID, requestedModelPtr(model), endpoint, + requestedModel, routerID, routedPublicModelID, + matchedCategory, decisionReason, fallbackUsed, routingScore, routingCategoryScoresJSON, + providerName, providerModelID, + stream, false, + inputTokens, outputTokens, + errType, errCode, + latencyMs, + time.Now(), + ) + return execErr +} + +func marshalRoutingCategoryScores(scores 
// stringPtrOrNil maps the empty string to nil and any other value to a
// pointer to a copy of it, so empty values can be stored as SQL NULL.
func stringPtrOrNil(v string) *string {
	if len(v) > 0 {
		return &v
	}
	return nil
}
// int64Ptr maps zero to nil and any other value to a pointer to a copy of
// it, so that "nothing charged" can be stored as SQL NULL rather than 0.
func int64Ptr(v int64) *int64 {
	if v != 0 {
		return &v
	}
	return nil
}
+type CostReserver interface { + // TryReserve atomically checks that existing reservations plus maxCostMicrocents + // do not exceed availableBalanceMicrocents, and if so creates a reservation. + // Returns a reservation ID on success, or an error if the balance is insufficient. + TryReserve(accountID string, availableBalanceMicrocents int64, maxCostMicrocents int64) (string, error) + // Release removes a previously created reservation, freeing the held amount. + Release(reservationID string) +} diff --git a/apps/gateway/internal/application/ports/logger.go b/apps/gateway/internal/application/ports/logger.go new file mode 100644 index 0000000..e3215e8 --- /dev/null +++ b/apps/gateway/internal/application/ports/logger.go @@ -0,0 +1,9 @@ +package ports + +// Logger is the canonical logging port. +type Logger interface { + Debug(msg string, fields ...any) + Info(msg string, fields ...any) + Warn(msg string, fields ...any) + Error(msg string, fields ...any) +} diff --git a/apps/gateway/internal/application/ports/model_catalog.go b/apps/gateway/internal/application/ports/model_catalog.go new file mode 100644 index 0000000..34df591 --- /dev/null +++ b/apps/gateway/internal/application/ports/model_catalog.go @@ -0,0 +1,15 @@ +package ports + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// ModelCatalog loads public models from the database. 
// PIIFilter detects personally-identifiable information in arbitrary text so
// the gateway can mask it before forwarding requests to upstream providers and
// un-mask matching tokens on the way back. Implementations MUST be safe for
// concurrent use.
type PIIFilter interface {
	// Analyze returns the PII spans found in `text`. Implementations should
	// return entities with byte offsets into `text` (not character offsets)
	// so that callers can substitute them with strings.Builder safely.
	// opts carries the per-request detection language and masking mode; see
	// PIIAnalyzeOptions for how empty values are interpreted.
	Analyze(ctx context.Context, text string, opts PIIAnalyzeOptions) ([]domain.PIIEntity, error)

	// Enabled reports whether the filter is operational. When false, the
	// gateway will skip both detection and the un-masking pass.
	Enabled() bool
}
+type ProviderRegistry interface { + GetProvider(providerName string) (GenerationProvider, error) +} diff --git a/apps/gateway/internal/application/ports/router_client.go b/apps/gateway/internal/application/ports/router_client.go new file mode 100644 index 0000000..f4a2c47 --- /dev/null +++ b/apps/gateway/internal/application/ports/router_client.go @@ -0,0 +1,12 @@ +package ports + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// RouterClient resolves explicit router model IDs to concrete public models. +type RouterClient interface { + Route(ctx context.Context, req domain.RouteRequest) (domain.RouteDecision, error) +} diff --git a/apps/gateway/internal/application/ports/tinfoil_proof.go b/apps/gateway/internal/application/ports/tinfoil_proof.go new file mode 100644 index 0000000..952a739 --- /dev/null +++ b/apps/gateway/internal/application/ports/tinfoil_proof.go @@ -0,0 +1,14 @@ +package ports + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// TinfoilTransportProofRepository persists and retrieves safe per-response +// Tinfoil attested-transport proof evidence. +type TinfoilTransportProofRepository interface { + UpsertTinfoilTransportProof(ctx context.Context, proof domain.TinfoilTransportProof) error + GetTinfoilTransportProof(ctx context.Context, accountID, providerResponseID string) (domain.TinfoilTransportProof, error) +} diff --git a/apps/gateway/internal/application/ports/usage.go b/apps/gateway/internal/application/ports/usage.go new file mode 100644 index 0000000..d92f9e4 --- /dev/null +++ b/apps/gateway/internal/application/ports/usage.go @@ -0,0 +1,13 @@ +package ports + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// UsageRecorder persists usage events. 
+type UsageRecorder interface { + RecordSuccess(ctx context.Context, auth domain.AuthContext, endpoint string, req domain.GenerateRequest, resp domain.GenerateResult, model domain.PublicModel, latencyMs int64) error + RecordFailure(ctx context.Context, auth *domain.AuthContext, endpoint string, req *domain.GenerateRequest, model *domain.PublicModel, err error, partialUsage *domain.Usage, latencyMs int64) error +} diff --git a/apps/gateway/internal/application/services/chat_completions_service.go b/apps/gateway/internal/application/services/chat_completions_service.go new file mode 100644 index 0000000..7e5a9a6 --- /dev/null +++ b/apps/gateway/internal/application/services/chat_completions_service.go @@ -0,0 +1,30 @@ +package services + +import ( + "context" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// ChatCompletionsService handles /v1/chat/completions compatibility mapping. +type ChatCompletionsService struct { + generate *GenerateService + logger ports.Logger +} + +func NewChatCompletionsService(generate *GenerateService, logger ports.Logger) *ChatCompletionsService { + return &ChatCompletionsService{generate: generate, logger: logger} +} + +// Execute processes a non-streaming /v1/chat/completions request. +func (s *ChatCompletionsService) Execute(ctx context.Context, req domain.GenerateRequest, bearerToken string) (domain.GenerateResult, error) { + result, _, err := s.generate.Execute(ctx, domain.EndpointChatCompletions, req, bearerToken) + return result, err +} + +// ExecuteStream processes a streaming /v1/chat/completions request. 
+func (s *ChatCompletionsService) ExecuteStream(ctx context.Context, req domain.GenerateRequest, bearerToken string) (ports.GenerationStream, *domain.PublicModel, error) { + stream, _, model, err := s.generate.ExecuteStream(ctx, domain.EndpointChatCompletions, req, bearerToken) + return stream, model, err +} diff --git a/apps/gateway/internal/application/services/generate_service.go b/apps/gateway/internal/application/services/generate_service.go new file mode 100644 index 0000000..0fb35bc --- /dev/null +++ b/apps/gateway/internal/application/services/generate_service.go @@ -0,0 +1,526 @@ +package services + +import ( + "context" + "errors" + "io" + "strings" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/adapters/http/middleware" + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// GenerateService is the single canonical execution path for all generation endpoints. +type GenerateService struct { + auth ports.AuthService + balances ports.BalanceChecker + catalog ports.ModelCatalog + router ports.RouterClient + registry ports.ProviderRegistry + usage ports.UsageRecorder + reserver ports.CostReserver + pii ports.PIIFilter + piiLang string + piiFailOpen bool + tinfoilProofs ports.TinfoilTransportProofRepository + logger ports.Logger +} + +func NewGenerateService( + auth ports.AuthService, + balances ports.BalanceChecker, + catalog ports.ModelCatalog, + router ports.RouterClient, + registry ports.ProviderRegistry, + usage ports.UsageRecorder, + reserver ports.CostReserver, + pii ports.PIIFilter, + logger ports.Logger, +) *GenerateService { + return &GenerateService{ + auth: auth, + balances: balances, + catalog: catalog, + router: router, + registry: registry, + usage: usage, + reserver: reserver, + pii: pii, + piiLang: "en", + logger: logger, + } +} + +// SetPIIOptions configures the PII filter language and failure mode. 
+// failOpen=true allows requests through when the filter is unreachable. +func (s *GenerateService) SetPIIOptions(language string, failOpen bool) { + if language != "" { + s.piiLang = language + } + s.piiFailOpen = failOpen +} + +func (s *GenerateService) SetTinfoilProofRepository(proofs ports.TinfoilTransportProofRepository) { + s.tinfoilProofs = proofs +} + +// Execute runs a non-streaming generation request through the canonical flow. +func (s *GenerateService) Execute(ctx context.Context, endpoint string, req domain.GenerateRequest, bearerToken string) (domain.GenerateResult, domain.AuthContext, error) { + start := time.Now() + requestID := middleware.GetRequestID(ctx) + + authCtx, err := s.auth.AuthenticateAPIKey(ctx, bearerToken) + if err != nil { + return domain.GenerateResult{}, domain.AuthContext{}, err + } + + model, execReq, err := s.resolveModel(ctx, req) + if err != nil { + return domain.GenerateResult{}, authCtx, err + } + + if err := s.validateRequest(endpoint, execReq, model); err != nil { + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return domain.GenerateResult{}, authCtx, err + } + + piiMapping, err := s.maskRequest(ctx, &execReq, authCtx.APIKey.PIIMode) + if err != nil { + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return domain.GenerateResult{}, authCtx, err + } + + reservationID, err := s.reserveCost(ctx, authCtx.Account.ID, model, execReq.MaxOutputTokens) + if err != nil { + return domain.GenerateResult{}, authCtx, err + } + defer s.reserver.Release(reservationID) + + provider, err := s.registry.GetProvider(model.ProviderConfig.ProviderName) + if err != nil { + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return domain.GenerateResult{}, authCtx, err + } + + result, err := provider.Generate(ctx, execReq, model) + if err != nil { + err = 
sanitizeErrorWithPIIMapping(err, piiMapping) + fields := s.buildErrorLogFields(ctx, requestID, &authCtx, endpoint, execReq.PublicModelID, model, err, time.Since(start).Milliseconds()) + s.logger.Error("provider generation failed", fields...) + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return domain.GenerateResult{}, authCtx, err + } + + s.storeTinfoilProof(ctx, authCtx, model, result) + + latencyMs := time.Since(start).Milliseconds() + s.logger.Info("generation completed", + "request_id", requestID, + "account_id", authCtx.Account.ID, + "endpoint", endpoint, + "model", execReq.PublicModelID, + "provider", model.ProviderConfig.ProviderName, + "latency_ms", latencyMs, + "finish_reason", result.FinishReason, + "gateway_status", 200, + "upstream_status", 200, + ) + + if err := s.usage.RecordSuccess(ctx, authCtx, endpoint, execReq, result, model, latencyMs); err != nil { + s.logger.Error("failed to record usage", + "request_id", requestID, + "account_id", authCtx.Account.ID, + "error", err, + ) + } + + unmaskResult(&result, piiMapping, s.logger, + "request_id", requestID, + "endpoint", endpoint, + "model", execReq.PublicModelID, + "provider", model.ProviderConfig.ProviderName, + "pii_mode", authCtx.APIKey.PIIMode, + ) + return result, authCtx, nil +} + +// ExecuteStream runs a streaming generation request through the canonical flow. 
+func (s *GenerateService) ExecuteStream(ctx context.Context, endpoint string, req domain.GenerateRequest, bearerToken string) (ports.GenerationStream, domain.AuthContext, *domain.PublicModel, error) { + start := time.Now() + requestID := middleware.GetRequestID(ctx) + + authCtx, err := s.auth.AuthenticateAPIKey(ctx, bearerToken) + if err != nil { + return nil, domain.AuthContext{}, nil, err + } + + model, execReq, err := s.resolveModel(ctx, req) + if err != nil { + return nil, authCtx, nil, err + } + + if err := s.validateRequest(endpoint, execReq, model); err != nil { + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return nil, authCtx, &model, err + } + + piiMapping, err := s.maskRequest(ctx, &execReq, authCtx.APIKey.PIIMode) + if err != nil { + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return nil, authCtx, &model, err + } + + reservationID, err := s.reserveCost(ctx, authCtx.Account.ID, model, execReq.MaxOutputTokens) + if err != nil { + return nil, authCtx, &model, err + } + + provider, err := s.registry.GetProvider(model.ProviderConfig.ProviderName) + if err != nil { + s.reserver.Release(reservationID) + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return nil, authCtx, &model, err + } + + stream, err := provider.StreamGenerate(ctx, execReq, model) + if err != nil { + err = sanitizeErrorWithPIIMapping(err, piiMapping) + fields := s.buildErrorLogFields(ctx, requestID, &authCtx, endpoint, execReq.PublicModelID, model, err, time.Since(start).Milliseconds()) + s.logger.Error("provider stream generation failed", fields...) 
+ s.reserver.Release(reservationID) + s.usage.RecordFailure(ctx, &authCtx, endpoint, &execReq, &model, err, nil, time.Since(start).Milliseconds()) + return nil, authCtx, &model, err + } + + wrapped := &usageTrackingStream{ + inner: stream, + service: s, + ctx: ctx, + authCtx: authCtx, + endpoint: endpoint, + req: execReq, + model: model, + requestID: requestID, + reservationID: reservationID, + start: start, + piiMapping: piiMapping, + } + + if piiMapping != nil && piiMapping.Len() > 0 { + return newPIIUnmaskingStreamWithLogger(wrapped, piiMapping, s.logger, + "request_id", requestID, + "endpoint", endpoint, + "model", execReq.PublicModelID, + "provider", model.ProviderConfig.ProviderName, + "pii_mode", authCtx.APIKey.PIIMode, + ), authCtx, &model, nil + } + return wrapped, authCtx, &model, nil +} + +func (s *GenerateService) resolveModel(ctx context.Context, req domain.GenerateRequest) (domain.PublicModel, domain.GenerateRequest, error) { + requestedModelID := req.PublicModelID + if req.RequestedModelID == "" { + req.RequestedModelID = requestedModelID + } + + model, err := s.catalog.GetPublicModel(ctx, requestedModelID) + if err == nil { + return model, req, nil + } + if !isUnsupportedModelError(err) { + return domain.PublicModel{}, req, err + } + + routerEntry, routerErr := s.catalog.GetRouter(ctx, requestedModelID) + if routerErr != nil { + if isNotFoundError(routerErr) { + return domain.PublicModel{}, req, domain.ErrUnsupportedModel(requestedModelID) + } + return domain.PublicModel{}, req, routerErr + } + if s.router == nil { + return domain.PublicModel{}, req, domain.ErrProviderUnavailable("router") + } + + decision, err := s.router.Route(ctx, domain.RouteRequest{ + RouterID: routerEntry.RouterID, + Request: req, + }) + if err != nil { + return domain.PublicModel{}, req, err + } + if decision.PublicModelID == "" { + return domain.PublicModel{}, req, domain.ErrProviderUnavailable("router") + } + + model, err = s.catalog.GetPublicModel(ctx, 
decision.PublicModelID) + if err != nil { + if isUnsupportedModelError(err) { + return domain.PublicModel{}, req, domain.ErrUnsupportedModel(decision.PublicModelID) + } + return domain.PublicModel{}, req, err + } + + routerID := routerEntry.RouterID + routedPublicModelID := decision.PublicModelID + req.PublicModelID = decision.PublicModelID + req.RouterID = &routerID + req.RoutedPublicModelID = &routedPublicModelID + if decision.Category != nil { + req.MatchedCategory = decision.Category + } + req.RoutingScore = decision.Score + if len(decision.CategoryScores) > 0 { + req.RoutingCategoryScores = append([]domain.RoutingCategoryScore(nil), decision.CategoryScores...) + } + if decision.Reason != "" { + reason := decision.Reason + req.DecisionReason = &reason + } + fallback := decision.FallbackUsed + req.FallbackUsed = &fallback + return model, req, nil +} + +func isUnsupportedModelError(err error) bool { + var gwErr *domain.GatewayError + return errors.As(err, &gwErr) && gwErr.Code == domain.ErrCodeUnsupportedModel +} + +func isNotFoundError(err error) bool { + var gwErr *domain.GatewayError + return errors.As(err, &gwErr) && gwErr.Code == domain.ErrCodeNotFound +} + +func (s *GenerateService) validateRequest(endpoint string, req domain.GenerateRequest, model domain.PublicModel) error { + if !model.SupportsEndpoint(endpoint) { + return domain.ErrUnsupportedEndpoint(model.PublicModelID, endpoint) + } + + if req.Stream && !model.SupportsStreamForEndpoint(endpoint) { + return domain.ErrUnsupportedFeature("streaming for " + endpoint) + } + + if len(req.Tools) > 0 && !model.SupportsTools { + return domain.ErrUnsupportedFeature("tools") + } + + if req.ParallelToolCalls != nil && *req.ParallelToolCalls && !model.SupportsParallelToolCalls { + return domain.ErrUnsupportedFeature("parallel_tool_calls") + } + + if req.TextConfig != nil && !model.SupportsStructuredOutput { + return domain.ErrUnsupportedFeature("structured_output") + } + + if model.EffectiveProofMode() == 
domain.ProofModeTinfoilAttestedTransport && + !strings.EqualFold(model.ProviderConfig.ProviderName, "tinfoil") { + return domain.ErrUnsupportedFeature("Tinfoil verified transport for non-Tinfoil provider") + } + + return nil +} + +// reserveCost calculates the worst-case cost for a request and reserves it against the account balance. +func (s *GenerateService) reserveCost(ctx context.Context, accountID string, model domain.PublicModel, maxOutputTokensOverride *int) (string, error) { + maxCost := model.MaxCostMicrocents(maxOutputTokensOverride) + + availableBalance, err := s.balances.GetSpendableBalance(ctx, accountID) + if err != nil { + return "", domain.ErrInternal("failed to load account balance") + } + + reservationID, err := s.reserver.TryReserve(accountID, availableBalance, maxCost) + if err != nil { + return "", err + } + return reservationID, nil +} + +// usageTrackingStream wraps a GenerationStream to record usage on completion. +type usageTrackingStream struct { + inner ports.GenerationStream + service *GenerateService + ctx context.Context + authCtx domain.AuthContext + endpoint string + req domain.GenerateRequest + model domain.PublicModel + requestID string + reservationID string + start time.Time + piiMapping *domain.PIIMapping + lastUsage *domain.Usage + finishReason *string + providerResponseID string + finished bool +} + +func (s *usageTrackingStream) Recv() (domain.StreamEvent, error) { + event, err := s.inner.Recv() + if err != nil { + if err == io.EOF { + s.recordCompletion(nil) + return event, err + } + // If the model already sent a finish reason, treat post-completion + // errors (e.g. context canceled after client disconnect) as success. 
+ if s.finishReason != nil { + s.recordCompletion(nil) + return event, io.EOF + } + s.recordCompletion(err) + return event, err + } + + if event.Usage != nil { + s.lastUsage = event.Usage + } + if event.ProviderResponseID != "" { + s.providerResponseID = event.ProviderResponseID + } + + if event.Type == domain.StreamEventCompleted { + if event.FinishReason != nil { + s.finishReason = event.FinishReason + } + } + + if event.Type == domain.StreamEventError { + var gwErr error + if event.Error != nil { + gwErr = event.Error + } + s.recordCompletion(gwErr) + } + + return event, nil +} + +func (s *usageTrackingStream) Close() error { + if !s.finished { + s.recordCompletion(context.Canceled) + } + return s.inner.Close() +} + +func (s *usageTrackingStream) recordCompletion(err error) { + if s.finished { + return + } + s.finished = true + ctx := context.WithoutCancel(s.ctx) + s.service.reserver.Release(s.reservationID) + latencyMs := time.Since(s.start).Milliseconds() + if err != nil { + err = sanitizeErrorWithPIIMapping(err, s.piiMapping) + fields := s.service.buildErrorLogFields(ctx, s.requestID, &s.authCtx, s.endpoint, s.req.PublicModelID, s.model, err, latencyMs) + s.service.logger.Error("stream error", fields...) + s.service.usage.RecordFailure(ctx, &s.authCtx, s.endpoint, &s.req, &s.model, err, s.lastUsage, latencyMs) + return + } + // Stream ended normally (io.EOF) — record success with accumulated usage. 
+ result := domain.GenerateResult{ + ID: s.providerResponseID, + PublicModelID: s.model.PublicModelID, + ProviderName: s.model.ProviderConfig.ProviderName, + ProviderModelID: s.model.ProviderModelID, + FinishReason: s.finishReason, + Usage: s.lastUsage, + } + if verified, ok := s.inner.(ports.VerifiedTransportProofProvider); ok { + result.TinfoilProof = verified.VerifiedTransportProof() + } + s.service.storeTinfoilProof(ctx, s.authCtx, s.model, result) + s.service.logger.Info("generation completed", + "request_id", s.requestID, + "account_id", s.authCtx.Account.ID, + "endpoint", s.endpoint, + "model", s.req.PublicModelID, + "provider", s.model.ProviderConfig.ProviderName, + "latency_ms", latencyMs, + "finish_reason", s.finishReason, + "gateway_status", 200, + "upstream_status", 200, + ) + if recErr := s.service.usage.RecordSuccess(ctx, s.authCtx, s.endpoint, s.req, result, s.model, latencyMs); recErr != nil { + s.service.logger.Error("failed to record stream usage", + "request_id", s.requestID, + "account_id", s.authCtx.Account.ID, + "error", recErr, + ) + } +} + +// buildErrorLogFields builds a structured log field slice for error conditions, +// including request context, provider details, and any upstream error metadata. +func (s *GenerateService) buildErrorLogFields(ctx context.Context, requestID string, authCtx *domain.AuthContext, endpoint, publicModelID string, model domain.PublicModel, err error, latencyMs int64) []any { + fields := []any{ + "request_id", requestID, + "endpoint", endpoint, + "model", publicModelID, + "provider", model.ProviderConfig.ProviderName, + "provider_model", model.UpstreamModelName, + "latency_ms", latencyMs, + "error", err.Error(), + } + if authCtx != nil { + fields = append(fields, "account_id", authCtx.Account.ID) + } + // Merge structured upstream metadata from GatewayError if present. 
+ var gwErr *domain.GatewayError + if errors.As(err, &gwErr) { + fields = append(fields, "error_code", gwErr.Code) + fields = append(fields, "gateway_status", gwErr.HTTPStatus) + fields = append(fields, gwErr.LogFields()...) + } + return fields +} + +func (s *GenerateService) storeTinfoilProof(ctx context.Context, auth domain.AuthContext, model domain.PublicModel, result domain.GenerateResult) { + if s == nil || s.tinfoilProofs == nil || model.EffectiveProofMode() != domain.ProofModeTinfoilAttestedTransport { + return + } + if result.TinfoilProof == nil { + if s.logger != nil { + s.logger.Warn("Tinfoil proof mode enabled but provider returned no proof evidence", + "provider", model.ProviderConfig.ProviderName, + "model", model.PublicModelID, + "provider_response_id", result.ID, + ) + } + return + } + proof := *result.TinfoilProof + proof.AccountID = auth.Account.ID + proof.APIKeyID = auth.APIKey.ID + proof.Provider = model.ProviderConfig.ProviderName + proof.PublicModelID = model.PublicModelID + proof.UpstreamModelID = model.UpstreamModelName + if proof.ProviderResponseID == "" { + proof.ProviderResponseID = result.ID + } + if proof.ProviderResponseID == "" { + if s.logger != nil { + s.logger.Error("failed to store Tinfoil proof: missing provider response id", + "provider", model.ProviderConfig.ProviderName, + "model", model.PublicModelID, + ) + } + return + } + if proof.CreatedAt.IsZero() { + proof.CreatedAt = time.Now().UTC() + } + if err := s.tinfoilProofs.UpsertTinfoilTransportProof(ctx, proof); err != nil && s.logger != nil { + s.logger.Error("failed to store Tinfoil transport proof", + "provider", model.ProviderConfig.ProviderName, + "model", model.PublicModelID, + "provider_response_id", proof.ProviderResponseID, + "error", err, + ) + } +} diff --git a/apps/gateway/internal/application/services/generate_service_test.go b/apps/gateway/internal/application/services/generate_service_test.go new file mode 100644 index 0000000..eb5522c --- /dev/null +++ 
b/apps/gateway/internal/application/services/generate_service_test.go @@ -0,0 +1,351 @@ +package services + +import ( + "context" + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/shopspring/decimal" +) + +type stubAuthService struct { + authCtx domain.AuthContext + err error +} + +func (s *stubAuthService) AuthenticateAPIKey(_ context.Context, _ string) (domain.AuthContext, error) { + return s.authCtx, s.err +} + +type stubBalanceChecker struct { + balance int64 + err error +} + +func (s *stubBalanceChecker) GetSpendableBalance(_ context.Context, _ string) (int64, error) { + return s.balance, s.err +} + +type stubModelCatalog struct { + models map[string]domain.PublicModel + routers map[string]domain.RouterEntry + model domain.PublicModel + err error +} + +func (s *stubModelCatalog) ListPublicModels(_ context.Context) ([]domain.PublicModel, error) { + return nil, nil +} + +func (s *stubModelCatalog) GetPublicModel(_ context.Context, publicModelID string) (domain.PublicModel, error) { + if s.models != nil { + model, ok := s.models[publicModelID] + if !ok { + return domain.PublicModel{}, domain.ErrUnsupportedModel(publicModelID) + } + return model, nil + } + return s.model, s.err +} + +func (s *stubModelCatalog) ListRouters(_ context.Context) ([]domain.RouterEntry, error) { + return nil, nil +} + +func (s *stubModelCatalog) GetRouter(_ context.Context, routerID string) (domain.RouterEntry, error) { + router, ok := s.routers[routerID] + if !ok { + return domain.RouterEntry{}, domain.ErrNotFound("router", routerID) + } + return router, nil +} + +type stubRouterClient struct { + decision domain.RouteDecision + err error + calls int +} + +func (s *stubRouterClient) Route(_ context.Context, _ domain.RouteRequest) (domain.RouteDecision, error) { + s.calls++ + if s.err != nil { + return domain.RouteDecision{}, s.err + } + return s.decision, nil +} + 
+type stubUsageRecorder struct { + failureCalls int + lastSuccessReq domain.GenerateRequest +} + +func (s *stubUsageRecorder) RecordSuccess(_ context.Context, _ domain.AuthContext, _ string, req domain.GenerateRequest, _ domain.GenerateResult, _ domain.PublicModel, _ int64) error { + s.lastSuccessReq = req + return nil +} + +func (s *stubUsageRecorder) RecordFailure(_ context.Context, _ *domain.AuthContext, _ string, _ *domain.GenerateRequest, _ *domain.PublicModel, _ error, _ *domain.Usage, _ int64) error { + s.failureCalls++ + return nil +} + +type stubCostReserver struct { + reserveErr error + maxCost int64 +} + +func (s *stubCostReserver) TryReserve(_ string, _ int64, maxCost int64) (string, error) { + s.maxCost = maxCost + if s.reserveErr != nil { + return "", s.reserveErr + } + return "reservation-1", nil +} + +func (s *stubCostReserver) Release(_ string) {} + +type stubProviderRegistry struct { + provider *stubProvider + err error +} + +func (s *stubProviderRegistry) GetProvider(_ string) (ports.GenerationProvider, error) { + if s.err != nil { + return nil, s.err + } + return s.provider, nil +} + +type stubProvider struct{} + +func (s *stubProvider) Generate(_ context.Context, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + return domain.GenerateResult{ + ID: "provider-response", + PublicModelID: req.PublicModelID, + ProviderName: model.ProviderConfig.ProviderName, + ProviderModelID: model.ProviderModelID, + }, nil +} + +func (s *stubProvider) StreamGenerate(_ context.Context, _ domain.GenerateRequest, _ domain.PublicModel) (ports.GenerationStream, error) { + return nil, nil +} + +type stubLogger struct{} + +func (stubLogger) Debug(string, ...any) {} +func (stubLogger) Info(string, ...any) {} +func (stubLogger) Warn(string, ...any) {} +func (stubLogger) Error(string, ...any) {} + +func TestGenerateServiceExecute_RejectsWhenBalanceIsEmpty(t *testing.T) { + usage := &stubUsageRecorder{} + svc := NewGenerateService( + 
&stubAuthService{ + authCtx: domain.AuthContext{ + Account: domain.Account{ID: "acc1", Status: domain.AccountStatusActive}, + APIKey: domain.APIKey{ID: "key1", Active: true}, + }, + }, + &stubBalanceChecker{balance: 0}, + &stubModelCatalog{ + model: domain.PublicModel{ + PublicModelID: "openai/gpt-4.1-mini", + SupportsChatCompletions: true, + MaxContextWindow: 100000, + MaxOutputTokens: 16384, + InputPricePerMillion: decimal.NewFromFloat(0.75), + OutputPricePerMillion: decimal.NewFromFloat(4.50), + }, + }, + nil, + nil, + usage, + &stubCostReserver{reserveErr: domain.ErrInsufficientBalance()}, + nil, // pii filter + stubLogger{}, + ) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + }, "sk-test") + if err == nil { + t.Fatal("expected insufficient balance error") + } + gwErr, ok := err.(*domain.GatewayError) + if !ok { + t.Fatalf("expected gateway error, got %T", err) + } + if gwErr.Code != domain.ErrCodeInsufficientBalance { + t.Fatalf("error code = %s, want %s", gwErr.Code, domain.ErrCodeInsufficientBalance) + } + if usage.failureCalls != 0 { + t.Fatalf("failure calls = %d, want 0", usage.failureCalls) + } +} + +func TestGenerateServiceExecute_DoesNotRouteUnknownModel(t *testing.T) { + router := &stubRouterClient{decision: domain.RouteDecision{PublicModelID: "minimax/minimax-m2.7"}} + svc := NewGenerateService( + &stubAuthService{authCtx: domain.AuthContext{ + Account: domain.Account{ID: "acc1", Status: domain.AccountStatusActive}, + APIKey: domain.APIKey{ID: "key1", Active: true}, + }}, + &stubBalanceChecker{balance: 1_000_000_000}, + &stubModelCatalog{models: map[string]domain.PublicModel{}}, + router, + nil, + &stubUsageRecorder{}, + &stubCostReserver{}, + nil, // pii filter + stubLogger{}, + ) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "unknown/model", + }, "sk-test") + if err == nil 
{ + t.Fatal("expected unsupported model error") + } + gwErr, ok := err.(*domain.GatewayError) + if !ok || gwErr.Code != domain.ErrCodeUnsupportedModel { + t.Fatalf("error = %v, want unsupported_model", err) + } + if router.calls != 0 { + t.Fatalf("router calls = %d, want 0", router.calls) + } +} + +func TestGenerateServiceExecute_RoutesExplicitRouterToSelectedModel(t *testing.T) { + usage := &stubUsageRecorder{} + reserver := &stubCostReserver{} + category := "long-context" + reason := "embedding_matched" + score := float32(0.456) + categoryScores := []domain.RoutingCategoryScore{ + {Category: "long-context", Score: score, Threshold: 0.30, PassedThreshold: true, Selected: true}, + {Category: "fast", Score: 0.111, Threshold: 0.30, PassedThreshold: false, Selected: false}, + } + router := &stubRouterClient{decision: domain.RouteDecision{ + PublicModelID: "minimax/minimax-m2.7", + Category: &category, + Score: &score, + CategoryScores: categoryScores, + FallbackUsed: false, + Reason: reason, + }} + selected := domain.PublicModel{ + PublicModelID: "minimax/minimax-m2.7", + DisplayName: "MiniMax M2.7", + ProviderModelID: "novita-minimax-m2.7", + ProviderConfig: domain.ProviderConfig{ProviderName: "novita"}, + SupportsChatCompletions: true, + SupportsChatCompletionsStream: true, + MaxContextWindow: 1000, + MaxOutputTokens: 100, + InputPricePerMillion: decimal.NewFromFloat(0.03), + OutputPricePerMillion: decimal.NewFromFloat(0.12), + } + svc := NewGenerateService( + &stubAuthService{authCtx: domain.AuthContext{ + Account: domain.Account{ID: "acc1", Status: domain.AccountStatusActive}, + APIKey: domain.APIKey{ID: "key1", Active: true}, + }}, + &stubBalanceChecker{balance: 1_000_000_000}, + &stubModelCatalog{ + models: map[string]domain.PublicModel{ + "minimax/minimax-m2.7": selected, + }, + routers: map[string]domain.RouterEntry{ + "dappnode/router": {RouterID: "dappnode/router", Active: true}, + }, + }, + router, + &stubProviderRegistry{provider: &stubProvider{}}, + 
usage, + reserver, + nil, // pii filter + stubLogger{}, + ) + + result, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "dappnode/router", + }, "sk-test") + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + if router.calls != 1 { + t.Fatalf("router calls = %d, want 1", router.calls) + } + if result.PublicModelID != "minimax/minimax-m2.7" { + t.Fatalf("result model = %q, want routed model", result.PublicModelID) + } + if usage.lastSuccessReq.RequestedModelID != "dappnode/router" { + t.Fatalf("requested model = %q, want router id", usage.lastSuccessReq.RequestedModelID) + } + if usage.lastSuccessReq.RouterID == nil || *usage.lastSuccessReq.RouterID != "dappnode/router" { + t.Fatalf("router id = %#v, want dappnode/router", usage.lastSuccessReq.RouterID) + } + if usage.lastSuccessReq.RoutedPublicModelID == nil || *usage.lastSuccessReq.RoutedPublicModelID != "minimax/minimax-m2.7" { + t.Fatalf("routed public model = %#v, want minimax/minimax-m2.7", usage.lastSuccessReq.RoutedPublicModelID) + } + if usage.lastSuccessReq.MatchedCategory == nil || *usage.lastSuccessReq.MatchedCategory != category { + t.Fatalf("matched category = %#v, want %s", usage.lastSuccessReq.MatchedCategory, category) + } + if usage.lastSuccessReq.RoutingScore == nil || *usage.lastSuccessReq.RoutingScore != score { + t.Fatalf("routing score = %#v, want %v", usage.lastSuccessReq.RoutingScore, score) + } + if len(usage.lastSuccessReq.RoutingCategoryScores) != len(categoryScores) { + t.Fatalf("category scores len = %d, want %d", len(usage.lastSuccessReq.RoutingCategoryScores), len(categoryScores)) + } + if usage.lastSuccessReq.RoutingCategoryScores[0] != categoryScores[0] || usage.lastSuccessReq.RoutingCategoryScores[1] != categoryScores[1] { + t.Fatalf("category scores = %#v, want %#v", usage.lastSuccessReq.RoutingCategoryScores, categoryScores) + } + if usage.lastSuccessReq.DecisionReason == nil || 
*usage.lastSuccessReq.DecisionReason != reason { + t.Fatalf("decision reason = %#v, want %s", usage.lastSuccessReq.DecisionReason, reason) + } + if usage.lastSuccessReq.FallbackUsed == nil || *usage.lastSuccessReq.FallbackUsed { + t.Fatalf("fallback used = %#v, want false", usage.lastSuccessReq.FallbackUsed) + } + if reserver.maxCost != selected.MaxCostMicrocents(nil) { + t.Fatalf("reserved max cost = %d, want %d", reserver.maxCost, selected.MaxCostMicrocents(nil)) + } +} + +func TestGenerateServiceExecute_ReturnsRouterErrorForExplicitRouterOutage(t *testing.T) { + router := &stubRouterClient{err: domain.ErrProviderUnavailable("router")} + svc := NewGenerateService( + &stubAuthService{authCtx: domain.AuthContext{ + Account: domain.Account{ID: "acc1", Status: domain.AccountStatusActive}, + APIKey: domain.APIKey{ID: "key1", Active: true}, + }}, + &stubBalanceChecker{balance: 1_000_000_000}, + &stubModelCatalog{ + models: map[string]domain.PublicModel{}, + routers: map[string]domain.RouterEntry{ + "dappnode/router": {RouterID: "dappnode/router", Active: true}, + }, + }, + router, + nil, + &stubUsageRecorder{}, + &stubCostReserver{}, + nil, // pii filter + stubLogger{}, + ) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "dappnode/router", + }, "sk-test") + if err == nil { + t.Fatal("expected router outage error") + } + gwErr, ok := err.(*domain.GatewayError) + if !ok || gwErr.Code != domain.ErrCodeProviderUnavailable { + t.Fatalf("error = %v, want provider_unavailable", err) + } + if router.calls != 1 { + t.Fatalf("router calls = %d, want 1", router.calls) + } +} diff --git a/apps/gateway/internal/application/services/list_models_service.go b/apps/gateway/internal/application/services/list_models_service.go new file mode 100644 index 0000000..a63615a --- /dev/null +++ b/apps/gateway/internal/application/services/list_models_service.go @@ -0,0 +1,139 @@ +package services + +import ( + "context" 
+ "sync" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "golang.org/x/sync/singleflight" +) + +// CatalogCacheTTL is how long /v1/models results are cached in memory. +// The catalog changes only when admins edit models/routers, so a short TTL +// is invisible to users and collapses request floods into one DB hit per window. +const CatalogCacheTTL = 60 * time.Second + +const defaultEURToUSDRate = 1.08 + +type EURToUSDRateProvider interface { + EURToUSD(ctx context.Context) (float64, error) +} + +type fixedEURToUSDRate float64 + +func (r fixedEURToUSDRate) EURToUSD(context.Context) (float64, error) { + return float64(r), nil +} + +// ListModelsService handles the GET /v1/models flow. +// +// The endpoint is intentionally public (no auth) so the landing page and +// other unauthenticated clients can render the live model catalog. To +// protect the database from request floods, results are cached in memory +// for CatalogCacheTTL and concurrent misses are coalesced via singleflight. 
+type ListModelsService struct { + catalog ports.ModelCatalog + logger ports.Logger + eurToUSD EURToUSDRateProvider + fallback float64 + ttl time.Duration + now func() time.Time + + mu sync.RWMutex + cached []domain.ModelCatalogEntry + expiresAt time.Time + + sf singleflight.Group +} + +func NewListModelsService(catalog ports.ModelCatalog, logger ports.Logger) *ListModelsService { + return NewListModelsServiceWithRateProvider(catalog, logger, fixedEURToUSDRate(defaultEURToUSDRate), defaultEURToUSDRate) +} + +func NewListModelsServiceWithRateProvider(catalog ports.ModelCatalog, logger ports.Logger, eurToUSD EURToUSDRateProvider, fallback float64) *ListModelsService { + if eurToUSD == nil { + eurToUSD = fixedEURToUSDRate(defaultEURToUSDRate) + } + if fallback <= 0 { + fallback = defaultEURToUSDRate + } + return &ListModelsService{ + catalog: catalog, + logger: logger, + eurToUSD: eurToUSD, + fallback: fallback, + ttl: CatalogCacheTTL, + now: time.Now, + } +} + +func (s *ListModelsService) Execute(ctx context.Context) ([]domain.ModelCatalogEntry, error) { + if entries, ok := s.fromCache(); ok { + return entries, nil + } + + v, err, _ := s.sf.Do("list", func() (any, error) { + if entries, ok := s.fromCache(); ok { + return entries, nil + } + return s.refresh(ctx) + }) + if err != nil { + return nil, err + } + return v.([]domain.ModelCatalogEntry), nil +} + +func (s *ListModelsService) fromCache() ([]domain.ModelCatalogEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cached != nil && s.now().Before(s.expiresAt) { + return s.cached, true + } + return nil, false +} + +func (s *ListModelsService) refresh(ctx context.Context) ([]domain.ModelCatalogEntry, error) { + models, err := s.catalog.ListPublicModels(ctx) + if err != nil { + s.logger.Error("failed to list models", "error", err) + return nil, domain.ErrInternal("failed to list models") + } + + routers, err := s.catalog.ListRouters(ctx) + if err != nil { + s.logger.Error("failed to list routers", "error", 
err) + return nil, domain.ErrInternal("failed to list models") + } + + eurToUSDRate := s.fallback + if rate, err := s.eurToUSD.EURToUSD(ctx); err == nil && rate > 0 { + eurToUSDRate = rate + } else if err != nil { + s.logger.Warn("failed to fetch EUR->USD rate; using fallback", "error", err) + } + entries := make([]domain.ModelCatalogEntry, 0, len(models)+len(routers)) + for i := range models { + entries = append(entries, domain.ModelCatalogEntry{ + Kind: domain.CatalogKindPublicModel, + PublicModel: &models[i], + EURToUSDRate: eurToUSDRate, + }) + } + for i := range routers { + entries = append(entries, domain.ModelCatalogEntry{ + Kind: domain.CatalogKindRouter, + Router: &routers[i], + EURToUSDRate: eurToUSDRate, + }) + } + + s.mu.Lock() + s.cached = entries + s.expiresAt = s.now().Add(s.ttl) + s.mu.Unlock() + + return entries, nil +} diff --git a/apps/gateway/internal/application/services/list_models_service_test.go b/apps/gateway/internal/application/services/list_models_service_test.go new file mode 100644 index 0000000..48c71e9 --- /dev/null +++ b/apps/gateway/internal/application/services/list_models_service_test.go @@ -0,0 +1,195 @@ +package services + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +type countingCatalog struct { + models []domain.PublicModel + routers []domain.RouterEntry + modelsCalls atomic.Int64 + routersCalls atomic.Int64 + err error +} + +func (c *countingCatalog) ListPublicModels(_ context.Context) ([]domain.PublicModel, error) { + c.modelsCalls.Add(1) + if c.err != nil { + return nil, c.err + } + return c.models, nil +} + +func (c *countingCatalog) GetPublicModel(_ context.Context, _ string) (domain.PublicModel, error) { + return domain.PublicModel{}, nil +} + +func (c *countingCatalog) ListRouters(_ context.Context) ([]domain.RouterEntry, error) { + c.routersCalls.Add(1) + if c.err != nil { + return nil, c.err + } + return c.routers, nil 
+} + +func (c *countingCatalog) GetRouter(_ context.Context, _ string) (domain.RouterEntry, error) { + return domain.RouterEntry{}, nil +} + +type countingRateProvider struct { + calls atomic.Int64 + rate float64 +} + +func (p *countingRateProvider) EURToUSD(context.Context) (float64, error) { + p.calls.Add(1) + if p.rate <= 0 { + return defaultEURToUSDRate, nil + } + return p.rate, nil +} + +func TestListModelsService_CachesResults(t *testing.T) { + catalog := &countingCatalog{ + models: []domain.PublicModel{{PublicModelID: "openai/gpt-4.1-mini"}}, + } + svc := NewListModelsService(catalog, stubLogger{}) + + for i := 0; i < 10; i++ { + if _, err := svc.Execute(context.Background()); err != nil { + t.Fatalf("Execute returned error: %v", err) + } + } + + if got := catalog.modelsCalls.Load(); got != 1 { + t.Errorf("ListPublicModels called %d times, want 1", got) + } + if got := catalog.routersCalls.Load(); got != 1 { + t.Errorf("ListRouters called %d times, want 1", got) + } +} + +func TestListModelsService_CachesEURToUSDRateWithCatalog(t *testing.T) { + catalog := &countingCatalog{ + models: []domain.PublicModel{{PublicModelID: "openai/gpt-4.1-mini"}}, + } + rates := &countingRateProvider{rate: 1.12} + svc := NewListModelsServiceWithRateProvider(catalog, stubLogger{}, rates, defaultEURToUSDRate) + + now := time.Unix(1000, 0) + svc.now = func() time.Time { return now } + + first, err := svc.Execute(context.Background()) + if err != nil { + t.Fatalf("first Execute: %v", err) + } + for i := 0; i < 10; i++ { + entries, err := svc.Execute(context.Background()) + if err != nil { + t.Fatalf("cached Execute: %v", err) + } + if entries[0].EURToUSDRate != 1.12 { + t.Fatalf("cached rate = %v, want 1.12", entries[0].EURToUSDRate) + } + } + + if first[0].EURToUSDRate != 1.12 { + t.Fatalf("rate = %v, want 1.12", first[0].EURToUSDRate) + } + if got := rates.calls.Load(); got != 1 { + t.Fatalf("EURToUSD called %d times before TTL, want 1", got) + } + + rates.rate = 1.2 + now = 
now.Add(CatalogCacheTTL + time.Second) + entries, err := svc.Execute(context.Background()) + if err != nil { + t.Fatalf("post-TTL Execute: %v", err) + } + if entries[0].EURToUSDRate != 1.2 { + t.Fatalf("refreshed rate = %v, want 1.2", entries[0].EURToUSDRate) + } + if got := rates.calls.Load(); got != 2 { + t.Fatalf("EURToUSD called %d times after TTL refresh, want 2", got) + } +} + +func TestListModelsService_RefreshesAfterTTL(t *testing.T) { + catalog := &countingCatalog{ + models: []domain.PublicModel{{PublicModelID: "openai/gpt-4.1-mini"}}, + } + svc := NewListModelsService(catalog, stubLogger{}) + + now := time.Unix(1000, 0) + svc.now = func() time.Time { return now } + + if _, err := svc.Execute(context.Background()); err != nil { + t.Fatalf("first Execute: %v", err) + } + if _, err := svc.Execute(context.Background()); err != nil { + t.Fatalf("second Execute: %v", err) + } + if got := catalog.modelsCalls.Load(); got != 1 { + t.Fatalf("before TTL: ListPublicModels called %d times, want 1", got) + } + + now = now.Add(CatalogCacheTTL + time.Second) + if _, err := svc.Execute(context.Background()); err != nil { + t.Fatalf("post-TTL Execute: %v", err) + } + if got := catalog.modelsCalls.Load(); got != 2 { + t.Errorf("after TTL: ListPublicModels called %d times, want 2", got) + } +} + +func TestListModelsService_CoalescesConcurrentMisses(t *testing.T) { + catalog := &countingCatalog{ + models: []domain.PublicModel{{PublicModelID: "openai/gpt-4.1-mini"}}, + } + svc := NewListModelsService(catalog, stubLogger{}) + + var wg sync.WaitGroup + start := make(chan struct{}) + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + if _, err := svc.Execute(context.Background()); err != nil { + t.Errorf("Execute: %v", err) + } + }() + } + close(start) + wg.Wait() + + if got := catalog.modelsCalls.Load(); got != 1 { + t.Errorf("ListPublicModels called %d times under concurrency, want 1", got) + } +} + +func TestListModelsService_DoesNotCacheErrors(t 
*testing.T) { + catalog := &countingCatalog{err: errors.New("db down")} + svc := NewListModelsService(catalog, stubLogger{}) + + if _, err := svc.Execute(context.Background()); err == nil { + t.Fatal("expected error on first call") + } + + catalog.err = nil + catalog.models = []domain.PublicModel{{PublicModelID: "openai/gpt-4.1-mini"}} + + if _, err := svc.Execute(context.Background()); err != nil { + t.Fatalf("expected success after recovery, got: %v", err) + } + if got := catalog.modelsCalls.Load(); got != 2 { + t.Errorf("ListPublicModels called %d times, want 2 (errors must not cache)", got) + } +} diff --git a/apps/gateway/internal/application/services/pii_masking.go b/apps/gateway/internal/application/services/pii_masking.go new file mode 100644 index 0000000..2f06314 --- /dev/null +++ b/apps/gateway/internal/application/services/pii_masking.go @@ -0,0 +1,798 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "io" + "sort" + "strings" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +// maskRequest detects PII in the dynamic LLM-bound fields of `req` and +// rewrites those fields in place with deterministic placeholder tokens. +// It covers instructions, input content, assistant reasoning content, +// assistant tool-call arguments, and the OpenAI-compatible user field. +// The returned mapping must be passed to unmaskResult / piiUnmaskingStream so +// the upstream model's response can be restored to its original form. +// +// Returns (nil, nil) when the filter is disabled or there is nothing to mask. +// On filter failure, behavior depends on s.piiFailOpen: fail-closed (default) +// returns an error; fail-open logs and returns (nil, nil) so the request +// proceeds with the original text. 
+func (s *GenerateService) maskRequest(ctx context.Context, req *domain.GenerateRequest, piiMode string) (*domain.PIIMapping, error) { + mode, ok := domain.NormalizeAPIKeyPIIMode(piiMode) + if !ok || mode == domain.APIKeyPIIModeOff { + return nil, nil + } + if s.pii == nil || !s.pii.Enabled() { + return nil, nil + } + + original := cloneGenerateRequest(*req) + masking := &requestPIIMasker{ + service: s, + ctx: ctx, + piiMode: mode, + mapping: domain.NewPIIMapping(), + cache: make(map[string]string), + entityCounts: make(map[string]int), + surfaceCounts: make(map[string]int), + } + + maskPtr := func(surface string, ptr **string) error { + if ptr == nil || *ptr == nil || **ptr == "" { + return nil + } + masked, err := masking.maskText(surface, **ptr) + if err != nil { + return err + } + **ptr = masked + return nil + } + + maskJSONPtr := func(surface string, ptr *string) (string, error) { + if ptr == nil || *ptr == "" { + return "", nil + } + return masking.maskJSONAwareString(surface, *ptr) + } + + if err := maskPtr("instructions", &req.Instructions); err != nil { + return s.handlePIIMaskError(req, original, err) + } + if err := maskPtr("user", &req.User); err != nil { + return s.handlePIIMaskError(req, original, err) + } + for i := range req.Input { + item := &req.Input[i] + if item.Content != nil && *item.Content != "" { + var ( + masked string + err error + ) + if item.Role != nil && *item.Role == "tool" { + masked, err = masking.maskJSONAwareString("tool_result_content", *item.Content) + } else { + masked, err = masking.maskText("message_content", *item.Content) + } + if err != nil { + return s.handlePIIMaskError(req, original, err) + } + *item.Content = masked + } + if err := maskPtr("assistant_reasoning_content", &item.ReasoningContent); err != nil { + return s.handlePIIMaskError(req, original, err) + } + for j := range item.ToolCalls { + masked, err := maskJSONPtr("assistant_tool_call_arguments", &item.ToolCalls[j].ArgumentsJSON) + if err != nil { + return 
s.handlePIIMaskError(req, original, err) + } + if masked != "" { + item.ToolCalls[j].ArgumentsJSON = masked + } + } + } + + if masking.mapping.Len() == 0 { + return nil, nil + } + s.logger.Debug("pii masked", + "pii_mode", mode, + "tokens", masking.mapping.Len(), + "entity_counts", masking.entityCounts, + "surface_counts", masking.surfaceCounts, + ) + return masking.mapping, nil +} + +func (s *GenerateService) handlePIIMaskError(req *domain.GenerateRequest, original domain.GenerateRequest, err error) (*domain.PIIMapping, error) { + s.logger.Warn("pii filter error", "error", err, "fail_open", s.piiFailOpen) + if s.piiFailOpen { + *req = original + return nil, nil + } + return nil, domain.ErrInternal("pii filter unavailable") +} + +type requestPIIMasker struct { + service *GenerateService + ctx context.Context + piiMode string + mapping *domain.PIIMapping + cache map[string]string + entityCounts map[string]int + surfaceCounts map[string]int +} + +func (m *requestPIIMasker) maskText(surface, text string) (string, error) { + if text == "" { + return text, nil + } + if masked, ok := m.cache[text]; ok { + return masked, nil + } + entities, err := m.service.pii.Analyze(m.ctx, text, ports.PIIAnalyzeOptions{ + Language: m.service.piiLang, + Mode: m.piiMode, + }) + if err != nil { + return "", err + } + m.trackEntities(surface, entities) + masked := domain.ApplyMask(text, entities, m.mapping) + m.cache[text] = masked + return masked, nil +} + +func (m *requestPIIMasker) trackEntities(surface string, entities []domain.PIIEntity) { + for _, entity := range entities { + m.entityCounts[entity.Type]++ + if surface != "" { + m.surfaceCounts[surface]++ + } + } +} + +func (m *requestPIIMasker) maskJSONAwareString(surface, raw string) (string, error) { + if raw == "" { + return raw, nil + } + + var value any + dec := json.NewDecoder(strings.NewReader(raw)) + dec.UseNumber() + if err := dec.Decode(&value); err != nil { + return m.maskText(surface, raw) + } + if dec.Decode(&struct{}{}) 
// marshalJSONNoEscape encodes value as compact JSON without HTML escaping, so
// characters such as '<', '>' and '&' are written literally. The newline that
// json.Encoder appends after the document is removed. On an encoding error it
// returns the empty string and the error.
func marshalJSONNoEscape(value any) (string, error) {
	out := &strings.Builder{}

	encoder := json.NewEncoder(out)
	encoder.SetEscapeHTML(false)

	err := encoder.Encode(value)
	if err != nil {
		return "", err
	}

	encoded := out.String()
	if strings.HasSuffix(encoded, "\n") {
		encoded = encoded[:len(encoded)-1]
	}
	return encoded, nil
}
+func unmaskResult(result *domain.GenerateResult, mapping *domain.PIIMapping, logger ports.Logger, logFields ...any) { + if result == nil || mapping == nil || mapping.Len() == 0 { + return + } + unresolved := make(map[string][]string) + for i := range result.Output { + item := &result.Output[i] + if item.Content != nil && *item.Content != "" { + restored := domain.Unmask(*item.Content, mapping) + trackUnresolvedPIITokens(unresolved, mapping, "assistant_content", restored) + item.Content = &restored + } + if item.ReasoningContent != nil && *item.ReasoningContent != "" { + restored := domain.Unmask(*item.ReasoningContent, mapping) + trackUnresolvedPIITokens(unresolved, mapping, "assistant_reasoning_content", restored) + item.ReasoningContent = &restored + } + for j := range item.ToolCalls { + if item.ToolCalls[j].ArgumentsJSON != "" { + item.ToolCalls[j].ArgumentsJSON = domain.Unmask(item.ToolCalls[j].ArgumentsJSON, mapping) + trackUnresolvedPIITokens(unresolved, mapping, "assistant_tool_call_arguments", item.ToolCalls[j].ArgumentsJSON) + } + } + } + logUnresolvedPIITokenGroups(logger, unresolved, append(logFields, "stream", false)...) +} + +// piiUnmaskingStream wraps a GenerationStream and restores PII tokens in +// output content, reasoning, and tool-call argument deltas before they reach +// the client. It buffers partial tokens that span chunk boundaries +// (e.g. "[PER" + "SON_1]") so substitution always sees a complete token. 
// Recv returns the next event for the client, with PII placeholders restored.
//
// Events already queued in s.pending are drained first, and once the EOF latch
// (s.eofPending) is set every further call returns io.EOF without touching the
// wrapped stream. Otherwise one event is pulled from the wrapped stream and
// handled by kind:
//
//   - io.EOF: buffered text is flushed; see the comment at that branch.
//   - any other error: returned unchanged together with the event.
//   - completed events: delegated to rewriteCompleted.
//   - error events carrying an error: the error is passed through
//     sanitizeErrorWithPIIMapping and replaced only when the result is a
//     *domain.GatewayError.
//   - everything else: non-empty content, reasoning and tool-call argument
//     deltas are fed through feedPIIStreamBuffer. The emitted delta can be
//     shorter than the received one (even empty) because text that may be the
//     start of a placeholder is held back for a later event.
func (s *piiUnmaskingStream) Recv() (domain.StreamEvent, error) {
	// Deliver queued events (flushed tails, split completed events) in order.
	if len(s.pending) > 0 {
		event := s.pending[0]
		s.pending = s.pending[1:]
		return event, nil
	}
	// The wrapped stream already ended and its tails have been delivered.
	if s.eofPending {
		return domain.StreamEvent{}, io.EOF
	}

	event, err := s.inner.Recv()

	// When the upstream stream finishes, flush any text we held back so the
	// client doesn't lose a trailing fragment. The tail is returned with a nil
	// error first because the HTTP handler ignores events returned with EOF.
	if err == io.EOF {
		if tails := s.flushBufferedEvents(); len(tails) > 0 {
			s.reportUnresolvedToolArgs()
			// Latch EOF so it is reported once the queued tails are drained.
			s.eofPending = true
			s.pending = append(s.pending, tails[1:]...)
			return tails[0], nil
		}
		s.reportUnresolvedToolArgs()
		return event, err
	}

	if err != nil {
		return event, err
	}

	if event.Type == domain.StreamEventCompleted {
		return s.rewriteCompleted(event), nil
	}

	if event.Type == domain.StreamEventError && event.Error != nil {
		if sanitized, ok := sanitizeErrorWithPIIMapping(event.Error, s.mapping).(*domain.GatewayError); ok {
			event.Error = sanitized
		}
		return event, nil
	}

	if event.ContentDelta != nil && *event.ContentDelta != "" {
		emit := feedPIIStreamBuffer(&s.textBuf, *event.ContentDelta, s.mapping)
		s.trackRestoredText(emit)
		event.ContentDelta = &emit
	}
	if event.ReasoningDelta != nil && *event.ReasoningDelta != "" {
		emit := feedPIIStreamBuffer(&s.reasonBuf, *event.ReasoningDelta, s.mapping)
		s.trackRestoredReasoning(emit)
		event.ReasoningDelta = &emit
	}
	if event.ToolCallDelta != nil && event.ToolCallDelta.ArgumentsDelta != nil && *event.ToolCallDelta.ArgumentsDelta != "" {
		// Tool-call arguments are buffered per tool-call index.
		buf := s.toolArgBuffer(event.ToolCallDelta.Index)
		emit := feedPIIStreamBuffer(buf, *event.ToolCallDelta.ArgumentsDelta, s.mapping)
		s.trackRestoredToolArg(event.ToolCallDelta.Index, emit)
		event.ToolCallDelta.ArgumentsDelta = &emit
	}
	return event, nil
}
*event.ToolCallDelta.ArgumentsDelta, s.mapping) + s.trackRestoredToolArg(event.ToolCallDelta.Index, args) + event.ToolCallDelta.ArgumentsDelta = &args + } + + prefixes := completedDeltaPrefixEvents(event) + tails := s.flushBufferedEvents() + s.reportUnresolvedToolArgs() + if len(prefixes) == 0 && len(tails) == 0 { + return event + } + + event.ContentDelta = nil + event.ReasoningDelta = nil + event.ToolCallDelta = nil + + events := make([]domain.StreamEvent, 0, len(prefixes)+len(tails)+1) + events = append(events, prefixes...) + events = append(events, tails...) + events = append(events, event) + s.pending = append(s.pending, events[1:]...) + return events[0] +} + +// Close releases the wrapped stream. Buffered text is discarded — Recv already +// returns the tail on EOF, and a Close before EOF means the client gave up. +func (s *piiUnmaskingStream) Close() error { + s.closed = true + s.textBuf.Reset() + s.reasonBuf.Reset() + for _, buf := range s.toolArgBufs { + buf.Reset() + } + return s.inner.Close() +} + +func (s *piiUnmaskingStream) toolArgBuffer(index int) *strings.Builder { + if buf, ok := s.toolArgBufs[index]; ok { + return buf + } + buf := &strings.Builder{} + s.toolArgBufs[index] = buf + return buf +} + +func (s *piiUnmaskingStream) flushBufferedEvents() []domain.StreamEvent { + events := make([]domain.StreamEvent, 0, 2+len(s.toolArgBufs)) + if tail := flushPIIStreamBuffer(&s.textBuf, s.mapping); tail != "" { + s.trackRestoredText(tail) + events = append(events, contentDeltaEvent(tail)) + } + if tail := flushPIIStreamBuffer(&s.reasonBuf, s.mapping); tail != "" { + s.trackRestoredReasoning(tail) + events = append(events, reasoningDeltaEvent(tail)) + } + indexes := make([]int, 0, len(s.toolArgBufs)) + for index := range s.toolArgBufs { + indexes = append(indexes, index) + } + sort.Ints(indexes) + for _, index := range indexes { + buf := s.toolArgBufs[index] + if tail := flushPIIStreamBuffer(buf, s.mapping); tail != "" { + s.trackRestoredToolArg(index, tail) 
+ events = append(events, toolCallArgumentDeltaEvent(index, tail)) + } + } + return events +} + +func (s *piiUnmaskingStream) trackRestoredToolArg(index int, delta string) { + if delta == "" { + return + } + buf, ok := s.restoredToolArgs[index] + if !ok { + buf = &strings.Builder{} + s.restoredToolArgs[index] = buf + } + buf.WriteString(delta) +} + +func (s *piiUnmaskingStream) trackRestoredText(delta string) { + if delta != "" { + s.restoredText.WriteString(delta) + } +} + +func (s *piiUnmaskingStream) trackRestoredReasoning(delta string) { + if delta != "" { + s.restoredReason.WriteString(delta) + } +} + +func (s *piiUnmaskingStream) reportUnresolvedToolArgs() { + if s.reported { + return + } + s.reported = true + unresolved := make(map[string][]string) + trackUnresolvedPIITokens(unresolved, s.mapping, "assistant_content", s.restoredText.String()) + trackUnresolvedPIITokens(unresolved, s.mapping, "assistant_reasoning_content", s.restoredReason.String()) + for _, buf := range s.restoredToolArgs { + trackUnresolvedPIITokens(unresolved, s.mapping, "assistant_tool_call_arguments", buf.String()) + } + logUnresolvedPIITokenGroups(s.logger, unresolved, append(s.logFields, "stream", true)...) +} + +// feedPIIStreamBuffer appends `delta` to the internal buffer and returns the largest prefix +// that contains only complete (or no) placeholder tokens, with those tokens +// already replaced by their original values. +// +// The buffer holds back from the last unmatched '[' onward so we never split +// a token across two emitted chunks. +func feedPIIStreamBuffer(buf *strings.Builder, delta string, mapping *domain.PIIMapping) string { + buf.WriteString(delta) + full := buf.String() + + // Find the last '[' that has no matching ']' after it. Everything up to + // that index is safe to emit; everything from it onward stays buffered. 
// isBareTokenBoundary reports whether the byte at text[index] separates words.
// Positions outside the string count as boundaries. Inside the string, ASCII
// letters, ASCII digits and '_' are word bytes; every other byte is a boundary.
func isBareTokenBoundary(text string, index int) bool {
	if index < 0 || index >= len(text) {
		return true
	}
	switch c := text[index]; {
	case c >= 'A' && c <= 'Z',
		c >= 'a' && c <= 'z',
		c >= '0' && c <= '9',
		c == '_':
		return false
	default:
		return true
	}
}
+func flushPIIStreamBuffer(buf *strings.Builder, mapping *domain.PIIMapping) string { + if buf.Len() == 0 { + return "" + } + tail := buf.String() + buf.Reset() + return domain.Unmask(tail, mapping) +} + +func completedDeltaPrefixEvents(event domain.StreamEvent) []domain.StreamEvent { + events := make([]domain.StreamEvent, 0, 2) + if (event.ContentDelta != nil && *event.ContentDelta != "") || (event.ReasoningDelta != nil && *event.ReasoningDelta != "") { + events = append(events, domain.StreamEvent{ + Type: domain.StreamEventOutputTextDelta, + ContentDelta: event.ContentDelta, + ReasoningDelta: event.ReasoningDelta, + }) + } + if event.ToolCallDelta != nil && event.ToolCallDelta.ArgumentsDelta != nil && *event.ToolCallDelta.ArgumentsDelta != "" { + events = append(events, domain.StreamEvent{ + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: event.ToolCallDelta, + }) + } + return events +} + +func contentDeltaEvent(delta string) domain.StreamEvent { + return domain.StreamEvent{ + Type: domain.StreamEventOutputTextDelta, + ContentDelta: &delta, + } +} + +func reasoningDeltaEvent(delta string) domain.StreamEvent { + return domain.StreamEvent{ + Type: domain.StreamEventOutputTextDelta, + ReasoningDelta: &delta, + } +} + +func toolCallArgumentDeltaEvent(index int, delta string) domain.StreamEvent { + return domain.StreamEvent{ + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: &domain.ToolCallDelta{ + Index: index, + ArgumentsDelta: &delta, + }, + } +} + +func sanitizeErrorWithPIIMapping(err error, mapping *domain.PIIMapping) error { + if err == nil || mapping == nil || mapping.Len() == 0 { + return err + } + var gwErr *domain.GatewayError + if !errors.As(err, &gwErr) { + return errors.New(mapping.MaskKnownOriginals(err.Error())) + } + cp := *gwErr + cp.Message = mapping.MaskKnownOriginals(cp.Message) + if len(gwErr.Metadata) > 0 { + cp.Metadata = make(map[string]any, len(gwErr.Metadata)) + for key, value := range gwErr.Metadata { + cp.Metadata[key] = 
sanitizePIILogValue(value, mapping) + } + } + return &cp +} + +func trackUnresolvedPIITokens(groups map[string][]string, mapping *domain.PIIMapping, surface, text string) { + if groups == nil || mapping == nil || mapping.Len() == 0 || text == "" { + return + } + tokens := mapping.UnresolvedTokens(text) + if len(tokens) == 0 { + return + } + groups[surface] = mergeStringSets(groups[surface], tokens) +} + +func logUnresolvedPIITokenGroups(logger ports.Logger, groups map[string][]string, fields ...any) { + if logger == nil || len(groups) == 0 { + return + } + surfaces := make([]string, 0, len(groups)) + for surface := range groups { + surfaces = append(surfaces, surface) + } + sort.Strings(surfaces) + for _, surface := range surfaces { + tokens := groups[surface] + sort.Strings(tokens) + logFields := append([]any(nil), fields...) + logFields = append(logFields, + "surface", surface, + "unresolved_token_count", len(tokens), + "unresolved_tokens", tokens, + ) + logger.Warn("pii restoration unresolved tokens", logFields...) 
// mergeStringSets returns the distinct strings of existing followed by the
// distinct strings of incoming that were not already seen, keeping first
// occurrence order. The result is always a new, non-nil slice.
func mergeStringSets(existing []string, incoming []string) []string {
	total := len(existing) + len(incoming)
	merged := make([]string, 0, total)
	known := make(map[string]struct{}, total)

	for _, group := range [][]string{existing, incoming} {
		for _, token := range group {
			if _, dup := known[token]; dup {
				continue
			}
			known[token] = struct{}{}
			merged = append(merged, token)
		}
	}
	return merged
}
// cloneMap copies in by round-tripping it through encoding/json, so nested
// maps and slices in the result are independent of the input. Because of the
// round trip, values come back in their JSON-decoded form (for example,
// numbers become float64). If the map cannot be marshalled or decoded, a
// shallow key-by-key copy is returned instead. A nil input yields nil.
func cloneMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}

	// shallowCopy is the best-effort fallback for values JSON cannot represent.
	shallowCopy := func() map[string]any {
		dup := make(map[string]any, len(in))
		for k, v := range in {
			dup[k] = v
		}
		return dup
	}

	encoded, err := json.Marshal(in)
	if err != nil {
		return shallowCopy()
	}
	var decoded map[string]any
	if json.Unmarshal(encoded, &decoded) != nil {
		return shallowCopy()
	}
	return decoded
}
0000000..3721413 --- /dev/null +++ b/apps/gateway/internal/application/services/pii_masking_test.go @@ -0,0 +1,853 @@ +package services + +import ( + "context" + "io" + "strings" + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/apps/gateway/internal/application/ports" + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/shopspring/decimal" +) + +// fakePIIFilter is a test double for ports.PIIFilter. It returns the +// pre-configured entities when called and tracks invocation counts so tests +// can assert the masking pipeline ran. +type fakePIIFilter struct { + enabled bool + err error + calls int + lastText string + lastOpts ports.PIIAnalyzeOptions + // byText optionally maps an exact input string to entity spans, so tests + // can return different detections for different request fields. + byText map[string][]domain.PIIEntity + // entities is returned when byText does not match. + entities []domain.PIIEntity +} + +type loggedWarning struct { + msg string + fields []any +} + +type captureLogger struct { + warnings []loggedWarning +} + +func (l *captureLogger) Debug(string, ...any) {} +func (l *captureLogger) Info(string, ...any) {} +func (l *captureLogger) Error(string, ...any) {} +func (l *captureLogger) Warn(msg string, fields ...any) { + l.warnings = append(l.warnings, loggedWarning{msg: msg, fields: append([]any(nil), fields...)}) +} + +func (f *fakePIIFilter) Enabled() bool { return f.enabled } + +func (f *fakePIIFilter) Analyze(_ context.Context, text string, opts ports.PIIAnalyzeOptions) ([]domain.PIIEntity, error) { + f.calls++ + f.lastText = text + f.lastOpts = opts + if f.err != nil { + return nil, f.err + } + if f.byText != nil { + if v, ok := f.byText[text]; ok { + return v, nil + } + return nil, nil + } + return f.entities, nil +} + +// captureProvider records the exact GenerateRequest the service forwards so +// tests can assert that user-supplied PII never reaches the upstream call. 
+type captureProvider struct { + lastReq domain.GenerateRequest + respondMsg string + respondOutput []domain.OutputItem +} + +func (p *captureProvider) Generate(_ context.Context, req domain.GenerateRequest, model domain.PublicModel) (domain.GenerateResult, error) { + p.lastReq = req + msg := p.respondMsg + role := "assistant" + output := p.respondOutput + if output == nil { + output = []domain.OutputItem{ + {Type: domain.OutputItemTypeMessage, Role: &role, Content: &msg}, + } + } + return domain.GenerateResult{ + ID: "resp-1", + PublicModelID: req.PublicModelID, + ProviderName: model.ProviderConfig.ProviderName, + ProviderModelID: model.ProviderModelID, + Output: output, + }, nil +} + +func (p *captureProvider) StreamGenerate(context.Context, domain.GenerateRequest, domain.PublicModel) (ports.GenerationStream, error) { + return nil, nil +} + +func buildPIITestService(filter ports.PIIFilter, prov *captureProvider) (*GenerateService, *stubUsageRecorder) { + return buildPIITestServiceWithPIIMode(filter, prov, domain.APIKeyPIIModeBalanced) +} + +func buildPIITestServiceWithPIIMode(filter ports.PIIFilter, prov *captureProvider, piiMode string) (*GenerateService, *stubUsageRecorder) { + usage := &stubUsageRecorder{} + selected := domain.PublicModel{ + PublicModelID: "openai/gpt-4.1-mini", + ProviderModelID: "gpt-4.1-mini", + ProviderConfig: domain.ProviderConfig{ProviderName: "openai"}, + SupportsChatCompletions: true, + SupportsTools: true, + MaxContextWindow: 100000, + MaxOutputTokens: 16384, + InputPricePerMillion: decimal.NewFromFloat(0.75), + OutputPricePerMillion: decimal.NewFromFloat(4.50), + } + svc := NewGenerateService( + &stubAuthService{authCtx: domain.AuthContext{ + Account: domain.Account{ID: "acc1", Status: domain.AccountStatusActive}, + APIKey: domain.APIKey{ID: "key1", Active: true, PIIMode: piiMode}, + }}, + &stubBalanceChecker{balance: 1_000_000_000}, + &stubModelCatalog{model: selected}, + nil, + &captureProviderRegistry{p: prov}, + usage, + 
&stubCostReserver{}, + filter, + stubLogger{}, + ) + return svc, usage +} + +// captureProviderRegistry returns the captureProvider as a GenerationProvider. +type captureProviderRegistry struct { + p *captureProvider +} + +func (r *captureProviderRegistry) GetProvider(string) (ports.GenerationProvider, error) { + return r.p, nil +} + +func newMessageInput(text string) []domain.InputItem { + role := "user" + t := text + return []domain.InputItem{ + {Type: domain.InputItemTypeMessage, Role: &role, Content: &t}, + } +} + +// captureProvider implements ports.GenerationProvider directly so we can +// observe the exact request the service forwards upstream. + +func TestGenerateServiceExecute_MasksRequestAndUnmasksResponse(t *testing.T) { + prompt := "My name is John Smith, email john@x.com" + // Pre-computed byte offsets for "John Smith" (11..21) and "john@x.com" (29..39). + filter := &fakePIIFilter{ + enabled: true, + entities: []domain.PIIEntity{ + {Type: "PERSON", Start: 11, End: 21, Score: 0.99}, + {Type: "EMAIL", Start: 29, End: 39, Score: 0.99}, + }, + } + prov := &captureProvider{respondMsg: "Hi [PERSON_1], I see your email [EMAIL_1]."} + + svc, _ := buildPIITestService(filter, prov) + + result, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput(prompt), + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + + // The provider must have seen only placeholders. + gotContent := *prov.lastReq.Input[0].Content + if strings.Contains(gotContent, "John Smith") || strings.Contains(gotContent, "john@x.com") { + t.Fatalf("upstream prompt still contains PII: %q", gotContent) + } + if !strings.Contains(gotContent, "[PERSON_1]") || !strings.Contains(gotContent, "[EMAIL_1]") { + t.Fatalf("upstream prompt missing placeholders: %q", gotContent) + } + + // The client-facing result must have the original PII restored. 
// TestGenerateServiceExecute_MasksInstructionsContentReasoningAndToolArgs
// verifies that every free-text surface of a request (instructions, message
// content, reasoning content, and the string values inside tool-call argument
// JSON) is masked before it reaches the provider, and that placeholders in the
// provider's reply (message content and tool-call arguments) are restored to
// the original values in the result returned to the caller.
func TestGenerateServiceExecute_MasksInstructionsContentReasoningAndToolArgs(t *testing.T) {
	instructions := "Call Jane at jane@example.com"
	prompt := "Email Jane at jane@example.com"
	toolArgs := `{"email":"jane@example.com","note":"Call Jane"}`
	role := "user"
	assistantInputRole := "assistant"
	assistantRole := "assistant"
	reasoning := "I should contact Jane"
	// The fake answers by exact input text, so each key below is one string the
	// service is expected to hand to the filter. Start/End are byte offsets into
	// that key. The last two keys are the individual JSON string values of
	// toolArgs rather than the whole JSON document.
	filter := &fakePIIFilter{
		enabled: true,
		byText: map[string][]domain.PIIEntity{
			instructions: {
				{Type: "PERSON", Start: 5, End: 9, Score: 0.99},
				{Type: "EMAIL", Start: 13, End: 29, Score: 0.99},
			},
			prompt: {
				{Type: "PERSON", Start: 6, End: 10, Score: 0.99},
				{Type: "EMAIL", Start: 14, End: 30, Score: 0.99},
			},
			reasoning: {
				{Type: "PERSON", Start: 17, End: 21, Score: 0.99},
			},
			"jane@example.com": {
				{Type: "EMAIL", Start: 0, End: 16, Score: 0.99},
			},
			"Call Jane": {
				{Type: "PERSON", Start: 5, End: 9, Score: 0.99},
			},
		},
	}
	// The provider replies using placeholders only, in both the message text
	// and the tool-call arguments, so restoration is exercised on both.
	prov := &captureProvider{
		respondOutput: []domain.OutputItem{
			{
				Type:    domain.OutputItemTypeMessage,
				Role:    &assistantRole,
				Content: strPtr("Message [PERSON_1] at [EMAIL_1]."),
				ToolCalls: []domain.ToolCall{
					{ID: "call_1", Name: "send_email", ArgumentsJSON: `{"email":"[EMAIL_1]"}`},
				},
			},
		},
	}
	svc, _ := buildPIITestService(filter, prov)

	result, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{
		PublicModelID: "openai/gpt-4.1-mini",
		Instructions:  &instructions,
		Input: []domain.InputItem{
			{
				Type:    domain.InputItemTypeMessage,
				Role:    &role,
				Content: strPtr(prompt),
			},
			{
				Type:             domain.InputItemTypeMessage,
				Role:             &assistantInputRole,
				ReasoningContent: &reasoning,
				ToolCalls: []domain.ToolCall{
					{ID: "call_1", Name: "send_email", ArgumentsJSON: toolArgs},
				},
			},
		},
	}, "sk-test")
	if err != nil {
		t.Fatalf("Execute: %v", err)
	}

	// Request side: nothing forwarded upstream may contain the raw name or email.
	if got := *prov.lastReq.Instructions; strings.Contains(got, "Jane") || strings.Contains(got, "jane@example.com") {
		t.Fatalf("upstream instructions still contain PII: %q", got)
	}
	if got := *prov.lastReq.Input[0].Content; strings.Contains(got, "Jane") || strings.Contains(got, "jane@example.com") {
		t.Fatalf("upstream content still contains PII: %q", got)
	}
	if got := *prov.lastReq.Input[1].ReasoningContent; strings.Contains(got, "Jane") {
		t.Fatalf("upstream reasoning content still contains PII: %q", got)
	}
	// The tool arguments must remain the same JSON shape, with only the string
	// values rewritten. The expected placeholders carry index 1, i.e. the same
	// tokens the canned provider reply above uses for the name and the email.
	wantArgs := `{"email":"[EMAIL_1]","note":"Call [PERSON_1]"}`
	if got := prov.lastReq.Input[1].ToolCalls[0].ArgumentsJSON; got != wantArgs {
		t.Fatalf("tool-call arguments not masked:\n got = %q\nwant = %q", got, wantArgs)
	}
	// One Analyze call per scanned string: instructions, content, reasoning,
	// plus the two string values inside the tool-call JSON.
	if filter.calls != 5 {
		t.Fatalf("filter calls = %d, want 5 for instructions, content, reasoning, and JSON values", filter.calls)
	}

	// Response side: the caller must see the original values again.
	if got := *result.Output[0].Content; got != "Message Jane at jane@example.com." {
		t.Fatalf("response content not restored: %q", got)
	}
	if got := result.Output[0].ToolCalls[0].ArgumentsJSON; got != `{"email":"jane@example.com"}` {
		t.Fatalf("response tool-call arguments not restored, got %q", got)
	}
}
true, + byText: map[string][]domain.PIIEntity{ + args: { + {Type: "EMAIL", Start: 6, End: 22, Score: 0.99}, + {Type: "PERSON", Start: 33, End: 37, Score: 0.99}, + }, + }, + } + prov := &captureProvider{respondMsg: "ok"} + svc, _ := buildPIITestService(filter, prov) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: []domain.InputItem{ + { + Type: domain.InputItemTypeMessage, + Role: &assistantRole, + ToolCalls: []domain.ToolCall{ + {ID: "call_1", Name: "send_email", ArgumentsJSON: args}, + }, + }, + }, + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + + want := `email=[EMAIL_1] note=Call [PERSON_1]` + if got := prov.lastReq.Input[0].ToolCalls[0].ArgumentsJSON; got != want { + t.Fatalf("invalid JSON tool args not masked as plain text:\n got = %q\nwant = %q", got, want) + } +} + +func TestGenerateServiceExecute_StaticFieldsAreNotScanned(t *testing.T) { + description := "Send email to Jane" + filter := &fakePIIFilter{ + enabled: true, + byText: map[string][]domain.PIIEntity{ + description: { + {Type: "PERSON", Start: 14, End: 18, Score: 0.99}, + }, + }, + } + + prov := &captureProvider{respondMsg: "ok"} + svc, _ := buildPIITestService(filter, prov) + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput("hello"), + Tools: []domain.ToolDefinition{ + {Name: "send_email", Description: description, Parameters: map[string]any{"type": "object"}}, + }, + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + if got := prov.lastReq.Tools[0].Description; got != description { + t.Fatalf("static description should not be masked, got %q", got) + } +} + +func TestGenerateServiceExecute_FailsClosedWhenFilterErrors(t *testing.T) { + filter := &fakePIIFilter{enabled: true, err: io.ErrUnexpectedEOF} + prov := &captureProvider{} + svc, 
usage := buildPIITestService(filter, prov) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput("hello"), + }, "sk-test") + if err == nil { + t.Fatal("expected fail-closed error when filter errors") + } + if usage.failureCalls == 0 { + t.Fatalf("expected failure recorded") + } + if prov.lastReq.PublicModelID != "" { + t.Fatalf("provider should not have been called") + } +} + +func TestGenerateServiceExecute_FailOpenContinuesWhenFilterErrors(t *testing.T) { + filter := &fakePIIFilter{enabled: true, err: io.ErrUnexpectedEOF} + prov := &captureProvider{respondMsg: "ok"} + svc, _ := buildPIITestService(filter, prov) + svc.SetPIIOptions("en", true) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput("hello"), + }, "sk-test") + if err != nil { + t.Fatalf("Execute (fail-open): %v", err) + } + if got := *prov.lastReq.Input[0].Content; got != "hello" { + t.Fatalf("prompt should pass through on fail-open, got %q", got) + } +} + +func TestGenerateServiceExecute_BypassesWhenFilterDisabled(t *testing.T) { + filter := &fakePIIFilter{enabled: false} + prov := &captureProvider{respondMsg: "ok"} + svc, _ := buildPIITestService(filter, prov) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput("My name is John Smith"), + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + if filter.calls != 0 { + t.Fatalf("filter should not be called when disabled") + } + if got := *prov.lastReq.Input[0].Content; got != "My name is John Smith" { + t.Fatalf("prompt should pass through, got %q", got) + } +} + +func TestGenerateServiceExecute_BypassesWhenAPIKeyPIIModeOff(t *testing.T) { + filter := &fakePIIFilter{enabled: true, 
entities: []domain.PIIEntity{{Type: "PERSON", Start: 11, End: 21, Score: 0.99}}} + prov := &captureProvider{respondMsg: "ok"} + svc, _ := buildPIITestServiceWithPIIMode(filter, prov, domain.APIKeyPIIModeOff) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput("My name is John Smith"), + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + if filter.calls != 0 { + t.Fatalf("filter should not be called when pii_mode=off, calls=%d", filter.calls) + } + if got := *prov.lastReq.Input[0].Content; got != "My name is John Smith" { + t.Fatalf("prompt should pass through, got %q", got) + } +} + +func TestGenerateServiceExecute_PassesAPIKeyPIIModeToFilter(t *testing.T) { + filter := &fakePIIFilter{enabled: true, entities: []domain.PIIEntity{{Type: "EMAIL_ADDRESS", Start: 6, End: 22, Score: 0.99}}} + prov := &captureProvider{respondMsg: "Email [EMAIL_ADDRESS_1]"} + svc, _ := buildPIITestServiceWithPIIMode(filter, prov, domain.APIKeyPIIModeLow) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Input: newMessageInput("email jane@example.com"), + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + if filter.lastOpts.Mode != domain.APIKeyPIIModeLow { + t.Fatalf("filter mode = %q, want low", filter.lastOpts.Mode) + } +} + +func TestGenerateServiceExecute_SkipsEmptyPIIFields(t *testing.T) { + instructions := "" + content := "" + filter := &fakePIIFilter{enabled: true, entities: []domain.PIIEntity{{Type: "PERSON", Start: 0, End: 4}}} + prov := &captureProvider{respondMsg: "ok"} + svc, _ := buildPIITestService(filter, prov) + + _, _, err := svc.Execute(context.Background(), domain.EndpointChatCompletions, domain.GenerateRequest{ + PublicModelID: "openai/gpt-4.1-mini", + Instructions: &instructions, + Input: []domain.InputItem{ + {Type: 
domain.InputItemTypeMessage, Content: &content}, + }, + }, "sk-test") + if err != nil { + t.Fatalf("Execute: %v", err) + } + if filter.calls != 0 { + t.Fatalf("filter calls = %d, want 0 for empty fields", filter.calls) + } + if got := *prov.lastReq.Input[0].Content; got != "" { + t.Fatalf("empty content should pass through, got %q", got) + } +} + +// --- streaming unmask wrapper tests --- + +// scriptedStream replays a fixed sequence of events to the wrapper under test. +type scriptedStream struct { + events []domain.StreamEvent + errs []error + idx int + closed bool +} + +func (s *scriptedStream) Recv() (domain.StreamEvent, error) { + if s.idx >= len(s.events) { + return domain.StreamEvent{}, io.EOF + } + ev := s.events[s.idx] + var err error + if s.idx < len(s.errs) { + err = s.errs[s.idx] + } + s.idx++ + return ev, err +} + +func (s *scriptedStream) Close() error { s.closed = true; return nil } + +func textDelta(s string) domain.StreamEvent { + v := s + return domain.StreamEvent{Type: domain.StreamEventOutputTextDelta, ContentDelta: &v} +} + +func completedEvent() domain.StreamEvent { + reason := "stop" + return domain.StreamEvent{Type: domain.StreamEventCompleted, FinishReason: &reason} +} + +func toolCallDelta(args string) domain.StreamEvent { + return domain.StreamEvent{ + Type: domain.StreamEventToolCallDelta, + ToolCallDelta: &domain.ToolCallDelta{ + Index: 0, + ArgumentsDelta: &args, + }, + } +} + +func collectStream(t *testing.T, st ports.GenerationStream) string { + t.Helper() + var b strings.Builder + for { + ev, err := st.Recv() + if err == io.EOF { + return b.String() + } + if err != nil { + t.Fatalf("Recv error: %v", err) + } + if ev.ContentDelta != nil { + b.WriteString(*ev.ContentDelta) + } + } +} + +func collectStreamUntilCompleted(t *testing.T, st ports.GenerationStream) string { + t.Helper() + var b strings.Builder + for { + ev, err := st.Recv() + if err == io.EOF { + return b.String() + } + if err != nil { + t.Fatalf("Recv error: %v", err) + } 
+ if ev.ContentDelta != nil { + b.WriteString(*ev.ContentDelta) + } + if ev.Type == domain.StreamEventCompleted { + for { + if _, err := st.Recv(); err != nil { + return b.String() + } + } + } + } +} + +func TestPIIUnmaskingStream_PassthroughWithoutTokens(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + inner := &scriptedStream{events: []domain.StreamEvent{textDelta("hello "), textDelta("world")}} + got := collectStream(t, newPIIUnmaskingStream(inner, m)) + if got != "hello world" { + t.Fatalf("got %q, want hello world", got) + } +} + +func TestPIIUnmaskingStream_RestoresCompleteToken(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + inner := &scriptedStream{events: []domain.StreamEvent{textDelta("hi [PERSON_1] there")}} + got := collectStream(t, newPIIUnmaskingStream(inner, m)) + if got != "hi John there" { + t.Fatalf("got %q", got) + } +} + +func TestPIIUnmaskingStream_BuffersTokenSplitAcrossChunks(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + inner := &scriptedStream{events: []domain.StreamEvent{ + textDelta("hi [PER"), + textDelta("SON_1] there"), + }} + got := collectStream(t, newPIIUnmaskingStream(inner, m)) + if got != "hi John there" { + t.Fatalf("got %q", got) + } +} + +func TestPIIUnmaskingStream_KeepsUnknownTokensVerbatim(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + inner := &scriptedStream{events: []domain.StreamEvent{textDelta("see [GHOST_5] and [PERSON_1]")}} + got := collectStream(t, newPIIUnmaskingStream(inner, m)) + if got != "see [GHOST_5] and John" { + t.Fatalf("got %q", got) + } +} + +func TestPIIUnmaskingStream_FlushesTailOnEOF(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + // Last chunk leaves an unterminated `[PER` in the buffer. 
+ inner := &scriptedStream{events: []domain.StreamEvent{textDelta("end with [PER")}} + got := collectStream(t, newPIIUnmaskingStream(inner, m)) + if got != "end with [PER" { + t.Fatalf("got %q", got) + } +} + +func TestPIIUnmaskingStream_FlushesTailBeforeCompleted(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + inner := &scriptedStream{events: []domain.StreamEvent{ + textDelta("hi [PER"), + completedEvent(), + }} + got := collectStreamUntilCompleted(t, newPIIUnmaskingStream(inner, m)) + if got != "hi [PER" { + t.Fatalf("got %q", got) + } +} + +func TestPIIUnmaskingStream_RestoresToolCallArgumentDeltas(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL", "jane@example.com") + inner := &scriptedStream{events: []domain.StreamEvent{ + toolCallDelta(`{"email":"[EMAIL_1]"}`), + }} + st := newPIIUnmaskingStream(inner, m) + + ev, err := st.Recv() + if err != nil { + t.Fatalf("Recv: %v", err) + } + if ev.ToolCallDelta == nil || ev.ToolCallDelta.ArgumentsDelta == nil { + t.Fatalf("expected tool-call argument delta, got %#v", ev) + } + if got := *ev.ToolCallDelta.ArgumentsDelta; got != `{"email":"jane@example.com"}` { + t.Fatalf("tool-call argument delta not restored, got %q", got) + } +} + +func TestPIIUnmaskingStream_RestoresBracketlessToolCallArgumentAlias(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + inner := &scriptedStream{events: []domain.StreamEvent{ + toolCallDelta(`{"email":"EMAIL_ADDRESS_1"}`), + }} + st := newPIIUnmaskingStream(inner, m) + + ev, err := st.Recv() + if err != nil { + t.Fatalf("Recv: %v", err) + } + if got := *ev.ToolCallDelta.ArgumentsDelta; got != `{"email":"jane@example.com"}` { + t.Fatalf("tool-call argument delta not restored, got %q", got) + } +} + +func TestPIIUnmaskingStream_BuffersBracketlessAliasSplitAcrossToolCallArgumentDeltas(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + inner := 
&scriptedStream{events: []domain.StreamEvent{ + toolCallDelta(`{"email":"EMAIL_ADD`), + toolCallDelta(`RESS_1"}`), + }} + st := newPIIUnmaskingStream(inner, m) + + ev, err := st.Recv() + if err != nil { + t.Fatalf("Recv first: %v", err) + } + if got := *ev.ToolCallDelta.ArgumentsDelta; got != `{"email":"` { + t.Fatalf("first tool-call argument delta = %q", got) + } + + ev, err = st.Recv() + if err != nil { + t.Fatalf("Recv second: %v", err) + } + if got := *ev.ToolCallDelta.ArgumentsDelta; got != `jane@example.com"}` { + t.Fatalf("second tool-call argument delta = %q", got) + } +} + +func TestPIIUnmaskingStream_BuffersTokenSplitAcrossToolCallArgumentDeltas(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL", "jane@example.com") + inner := &scriptedStream{events: []domain.StreamEvent{ + toolCallDelta(`{"email":"[EMA`), + toolCallDelta(`IL_1]"}`), + }} + st := newPIIUnmaskingStream(inner, m) + + ev, err := st.Recv() + if err != nil { + t.Fatalf("Recv first: %v", err) + } + if got := *ev.ToolCallDelta.ArgumentsDelta; got != `{"email":"` { + t.Fatalf("first tool-call argument delta = %q", got) + } + + ev, err = st.Recv() + if err != nil { + t.Fatalf("Recv second: %v", err) + } + if got := *ev.ToolCallDelta.ArgumentsDelta; got != `jane@example.com"}` { + t.Fatalf("second tool-call argument delta = %q", got) + } +} + +func TestUnmaskResult_LogsUnresolvedToolCallArgumentTokens(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + logger := &captureLogger{} + role := "assistant" + result := domain.GenerateResult{ + Output: []domain.OutputItem{ + { + Type: domain.OutputItemTypeMessage, + Role: &role, + ToolCalls: []domain.ToolCall{ + {ID: "call_1", Name: "send_email", ArgumentsJSON: `{"email":"EMAIL_ADDRESS_2"}`}, + }, + }, + }, + } + + unmaskResult(&result, m, logger, "request_id", "req-1") + + if len(logger.warnings) != 1 { + t.Fatalf("warnings = %d, want 1", len(logger.warnings)) + } + if logger.warnings[0].msg != "pii 
restoration unresolved tokens" { + t.Fatalf("warning msg = %q", logger.warnings[0].msg) + } + if got := result.Output[0].ToolCalls[0].ArgumentsJSON; got != `{"email":"EMAIL_ADDRESS_2"}` { + t.Fatalf("arguments changed unexpectedly: %q", got) + } +} + +func TestPIIUnmaskingStream_LogsUnresolvedToolCallArgumentTokensOnCompleted(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + logger := &captureLogger{} + inner := &scriptedStream{events: []domain.StreamEvent{ + toolCallDelta(`{"email":"EMAIL_ADDRESS_2"}`), + completedEvent(), + }} + st := newPIIUnmaskingStreamWithLogger(inner, m, logger, "request_id", "req-1") + + for { + _, err := st.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Recv: %v", err) + } + } + + if len(logger.warnings) != 1 { + t.Fatalf("warnings = %d, want 1", len(logger.warnings)) + } + if logger.warnings[0].msg != "pii restoration unresolved tokens" { + t.Fatalf("warning msg = %q", logger.warnings[0].msg) + } +} + +func TestSanitizeErrorWithPIIMapping_MasksKnownOriginalsInMessageAndMetadata(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL", "jane@example.com") + err := domain.ErrProviderError(502, "provider echoed jane@example.com").WithMeta( + "upstream_error", `bad request for jane@example.com`, + "nested", map[string]any{"body": "jane@example.com"}, + ) + + sanitized, ok := sanitizeErrorWithPIIMapping(err, m).(*domain.GatewayError) + if !ok { + t.Fatalf("expected GatewayError") + } + if strings.Contains(sanitized.Message, "jane@example.com") { + t.Fatalf("message still contains PII: %q", sanitized.Message) + } + if got := sanitized.Metadata["upstream_error"]; got != `bad request for [EMAIL_1]` { + t.Fatalf("upstream_error = %#v", got) + } + nested := sanitized.Metadata["nested"].(map[string]any) + if got := nested["body"]; got != "[EMAIL_1]" { + t.Fatalf("nested body = %#v", got) + } +} + +func TestPIIUnmaskingStream_CloseDelegates(t *testing.T) { + m := 
// strPtr returns a pointer to a fresh string holding v, for populating
// optional *string fields in test fixtures.
func strPtr(v string) *string {
	p := new(string)
	*p = v
	return p
}
"github.com/dappnode/dappnode-nexus-gateway/pkg/observability/logger" +) + +func main() { + port := envOr("PORT", "8080") + dbURL := envOr("DATABASE_URL", "postgres://nexus:nexus@localhost:5432/nexus?sslmode=disable") + logLevel := envOr("LOG_LEVEL", "info") + routerURL := envOr("ROUTER_URL", "http://localhost:8083") + providerTimeout := 300 * time.Second + + zapLogger, err := logger.NewZapLogger(logLevel) + if err != nil { + log.Fatalf("failed to create logger: %v", err) + } + defer zapLogger.Sync() + + zapLogger.Info("starting gateway", "port", port) + + ctx := context.Background() + pool, err := pgxpool.New(ctx, dbURL) + if err != nil { + log.Fatalf("failed to connect to database: %v", err) + } + defer pool.Close() + + if err := pool.Ping(ctx); err != nil { + log.Fatalf("failed to ping database: %v", err) + } + zapLogger.Info("database connected") + + apiKeyRepo := postgres.NewAPIKeyRepo(pool) + modelCatalogRepo := postgres.NewModelCatalogRepo(pool) + usageRepo := postgres.NewUsageRepo(pool) + tinfoilProofRepo := postgres.NewTinfoilProofRepo(pool) + + authService := apikeys.NewService(apiKeyRepo, zapLogger) + + providerRegistry := registry.NewRegistry() + providerRegistry.Register("anthropic", anthropic.NewAdapter(providerTimeout)) + providerRegistry.Register("tinfoil", tinfoilprovider.NewAdapter(providerTimeout, zapLogger)) + // Any provider not explicitly registered falls back to the OpenAI-compatible adapter. + // This allows adding new providers (e.g. novita, mistral) via DB only — no code changes. 
+ providerRegistry.SetDefault(openai.NewAdapter(providerTimeout, zapLogger)) + + costReserver := billing.NewInMemoryReserver() + routerClient := routerclient.NewClient(routerURL, 5*time.Second) + eurToUSDFallbackRate := envFloatOr("EUR_TO_USD_FALLBACK_RATE", 1.08) + fxProvider := fx.NewFrankfurter(2 * time.Second) + + piiFilter, piiLang, piiFailOpen := buildPIIFilter(zapLogger) + + listModelsSvc := services.NewListModelsServiceWithRateProvider(modelCatalogRepo, zapLogger, fxProvider, eurToUSDFallbackRate) + generateSvc := services.NewGenerateService(authService, usageRepo, modelCatalogRepo, routerClient, providerRegistry, usageRepo, costReserver, piiFilter, zapLogger) + generateSvc.SetPIIOptions(piiLang, piiFailOpen) + generateSvc.SetTinfoilProofRepository(tinfoilProofRepo) + chatCompletionsSvc := services.NewChatCompletionsService(generateSvc, zapLogger) + + healthHandler := handlers.NewHealthHandler() + modelsHandler := handlers.NewModelsHandler(listModelsSvc) + chatHandler := handlers.NewChatCompletionsHandler(chatCompletionsSvc, zapLogger) + tinfoilHandler := handlers.NewTinfoilHandler(authService, tinfoilProofRepo, zapLogger) + + router := gwhttp.NewRouter(healthHandler, modelsHandler, chatHandler, tinfoilHandler, zapLogger) + + metricsMux := http.NewServeMux() + metricsMux.Handle("/metrics", metrics.Handler()) + + srv := &http.Server{ + Addr: ":" + port, + Handler: router, + ReadTimeout: 30 * time.Second, + WriteTimeout: 300 * time.Second, + IdleTimeout: 120 * time.Second, + } + + go func() { + metricsAddr := envOr("METRICS_PORT", "9090") + zapLogger.Info("metrics server starting", "port", metricsAddr) + if err := http.ListenAndServe(":"+metricsAddr, metricsMux); err != nil { + zapLogger.Error("metrics server error", "error", err) + } + }() + + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + zapLogger.Info("shutting down") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 
30*time.Second) + defer cancel() + srv.Shutdown(shutdownCtx) + }() + + zapLogger.Info("gateway listening", "addr", srv.Addr) + if err := srv.ListenAndServe(); err != http.ErrServerClosed { + log.Fatalf("server error: %v", err) + } + zapLogger.Info("gateway stopped") +} + +func envOr(key, defaultVal string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultVal +} + +func envFloatOr(key string, defaultVal float64) float64 { + if v := os.Getenv(key); v != "" { + if parsed, err := strconv.ParseFloat(v, 64); err == nil && parsed > 0 { + return parsed + } + } + return defaultVal +} + +// buildPIIFilter wires the Presidio adapter from environment variables. +// PII_FILTER_ENABLED defaults to true; per-key pii_mode still decides whether +// a request is masked. Setting it to "false" / "0" returns the noop filter so +// masking becomes a no-op even for privacy-enabled keys. +func buildPIIFilter(log *logger.ZapLogger) (ports.PIIFilter, string, bool) { + enabled := strings.ToLower(envOr("PII_FILTER_ENABLED", "true")) + if enabled == "false" || enabled == "0" || enabled == "no" { + log.Info("pii filter disabled") + return presidio.NewNoopFilter(), "en", false + } + + url := envOr("PRESIDIO_ANALYZER_URL", "http://presidio-analyzer:3000") + language := envOr("PII_FILTER_LANGUAGE", "en") + failOpen := strings.EqualFold(envOr("PII_FILTER_FAIL_OPEN", "false"), "true") + + threshold := 0.4 + if v := os.Getenv("PII_FILTER_SCORE_THRESHOLD"); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + threshold = f + } + } + + timeout := 1500 * time.Millisecond + if v := os.Getenv("PII_FILTER_TIMEOUT_MS"); v != "" { + if ms, err := strconv.Atoi(v); err == nil && ms > 0 { + timeout = time.Duration(ms) * time.Millisecond + } + } + + log.Info("pii filter enabled", + "analyzer_url", url, + "language", language, + "score_threshold", threshold, + "timeout_ms", timeout.Milliseconds(), + "fail_open", failOpen, + ) + return presidio.NewAdapter(presidio.Config{ + 
BaseURL: url, + DefaultLanguage: language, + ScoreThreshold: threshold, + Timeout: timeout, + Logger: log, + }), language, failOpen +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5b91caf --- /dev/null +++ b/go.mod @@ -0,0 +1,102 @@ +module github.com/dappnode/dappnode-nexus-gateway + +go 1.26.4 + +require ( + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.8.0 + github.com/prometheus/client_golang v1.23.2 + github.com/shopspring/decimal v1.4.0 + github.com/tinfoilsh/tinfoil-go v0.13.1 + go.uber.org/zap v1.27.1 + golang.org/x/sync v0.20.0 +) + +require ( + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/blang/semver v3.5.1+incompatible // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cyberphone/json-canonicalization v0.0.0-20241213102144-19d51d7fe467 // indirect + github.com/digitorus/pkcs7 v0.0.0-20250730155240-ffadbf3f398c // indirect + github.com/digitorus/timestamp v0.0.0-20250524132541-c45532741eea // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/analysis v0.24.1 // indirect + github.com/go-openapi/errors v0.22.6 // indirect + github.com/go-openapi/jsonpointer v0.22.4 // indirect + github.com/go-openapi/jsonreference v0.21.4 // indirect + github.com/go-openapi/loads v0.23.2 // indirect + github.com/go-openapi/runtime v0.29.2 // indirect + github.com/go-openapi/spec v0.22.3 // indirect + github.com/go-openapi/strfmt v0.25.0 // indirect + github.com/go-openapi/swag v0.25.4 // indirect + github.com/go-openapi/swag/cmdutils v0.25.4 // indirect + github.com/go-openapi/swag/conv v0.25.4 // indirect + github.com/go-openapi/swag/fileutils v0.25.4 // indirect + github.com/go-openapi/swag/jsonname v0.25.4 // indirect + github.com/go-openapi/swag/jsonutils v0.25.4 // indirect + 
github.com/go-openapi/swag/loading v0.25.4 // indirect + github.com/go-openapi/swag/mangling v0.25.4 // indirect + github.com/go-openapi/swag/netutils v0.25.4 // indirect + github.com/go-openapi/swag/stringutils v0.25.4 // indirect + github.com/go-openapi/swag/typeutils v0.25.4 // indirect + github.com/go-openapi/swag/yamlutils v0.25.4 // indirect + github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/google/certificate-transparency-go v1.3.2 // indirect + github.com/google/go-containerregistry v0.20.7 // indirect + github.com/google/go-sev-guest v0.14.1 // indirect + github.com/google/go-tdx-guest v0.3.1 // indirect + github.com/google/logger v1.1.1 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect + github.com/in-toto/attestation v1.1.2 // indirect + github.com/in-toto/in-toto-golang v0.9.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/openai/openai-go/v3 v3.16.0 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.17.0 // indirect + github.com/secure-systems-lab/go-securesystemslib v0.10.0 // indirect + github.com/shibumi/go-pathspec v1.3.0 // indirect + github.com/sigstore/protobuf-specs v0.5.0 // indirect + github.com/sigstore/rekor v1.5.0 // indirect + github.com/sigstore/rekor-tiles/v2 v2.0.1 // indirect + github.com/sigstore/sigstore v1.10.4 // indirect + github.com/sigstore/sigstore-go v1.1.4 // indirect + github.com/sigstore/timestamp-authority/v2 v2.0.3 // indirect + github.com/sirupsen/logrus 
v1.9.4 // indirect + github.com/theupdateframework/go-tuf/v2 v2.4.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/tinfoilsh/encrypted-http-body-protocol v0.2.3 // indirect + github.com/transparency-dev/formats v0.0.0-20251027093029-9ba98ff6507f // indirect + github.com/transparency-dev/merkle v0.0.2 // indirect + go.mongodb.org/mongo-driver v1.17.6 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel v1.39.0 // indirect + go.opentelemetry.io/otel/metric v1.39.0 // indirect + go.opentelemetry.io/otel/trace v1.39.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/crypto v0.52.0 // indirect + golang.org/x/mod v0.35.0 // indirect + golang.org/x/net v0.55.0 // indirect + golang.org/x/sys v0.45.0 // indirect + golang.org/x/term v0.43.0 // indirect + golang.org/x/text v0.37.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect + google.golang.org/grpc v1.79.3 // indirect + google.golang.org/protobuf v1.36.11 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..95bb7d0 --- /dev/null +++ b/go.sum @@ -0,0 +1,434 @@ +cloud.google.com/go v0.121.6 h1:waZiuajrI28iAf40cWgycWNgaXPO06dupuS+sgibK6c= +cloud.google.com/go v0.121.6/go.mod h1:coChdst4Ea5vUpiALcYKXEpR1S9ZgXbhEzzMcMR66vI= +cloud.google.com/go/auth v0.18.0 h1:wnqy5hrv7p3k7cShwAU/Br3nzod7fxoqG+k0VZ+/Pk0= +cloud.google.com/go/auth v0.18.0/go.mod h1:wwkPM1AgE1f2u6dG443MiWoD8C3BtOywNsUMcUTVDRo= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod 
h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= +cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= +cloud.google.com/go/kms v1.23.2 h1:4IYDQL5hG4L+HzJBhzejUySoUOheh3Lk5YT4PCyyW6k= +cloud.google.com/go/kms v1.23.2/go.mod h1:rZ5kK0I7Kn9W4erhYVoIRPtpizjunlrfU4fUkumUp8g= +cloud.google.com/go/longrunning v0.7.0 h1:FV0+SYF1RIj59gyoWDRi45GiYUMM3K1qO51qoboQT1E= +cloud.google.com/go/longrunning v0.7.0/go.mod h1:ySn2yXmjbK9Ba0zsQqunhDkYi0+9rlXIwnoAf+h+TPY= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/AdamKorcz/go-fuzz-headers-1 v0.0.0-20230919221257-8b5d3ce2d11d h1:zjqpY4C7H15HjRPEenkS4SAn3Jy2eRRjkjZbGR30TOg= +github.com/AdamKorcz/go-fuzz-headers-1 v0.0.0-20230919221257-8b5d3ce2d11d/go.mod h1:XNqJ7hv2kY++g8XEHREpi+JqZo3+0l+CH2egBVN4yqM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0 h1:E4MgwLBGeVB5f2MdcIVD3ELVAWpr+WD6MUe1i+tM/PA= 
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0/go.mod h1:Y2b/1clN4zsAoUd/pgNAQHjLDnTis/6ROkUfyob6psM= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0/go.mod h1:ucUjca2JtSZboY8IoUqyQyuuXvwbMBVwFOm0vdQPNhA= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0= +github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= +github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.5 h1:pz3duhAfUgnxbtVhIK39PGF/AHYyrzGEyRD9Og0QrE8= +github.com/aws/aws-sdk-go-v2/config v1.32.5/go.mod h1:xmDjzSUs/d0BB7ClzYPAZMmgQdrodNjPPhd6bGASwoE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5 h1:xMo63RlqP3ZZydpJDMBsH9uJ10hgHYfQFIk1cHDXrR4= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5/go.mod h1:hhbH6oRcou+LpXfA/0vPElh/e0M3aFeOblE1sssAAEk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= 
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/kms v1.49.1 h1:U0asSZ3ifpuIehDPkRI2rxHbmFUMplDA2VeR9Uogrmw= +github.com/aws/aws-sdk-go-v2/service/kms v1.49.1/go.mod h1:NZo9WJqQ0sxQ1Yqu1IwCHQFQunTms2MlVgejg16S1rY= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 h1:eYnlt6QxnFINKzwxP5/Ucs1vkG7VT3Iezmvfgc2waUw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= 
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= +github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb h1:EDmT6Q9Zs+SbUoc7Ik9EfrFqcylYqgPZ9ANSbTAntnE= +github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb/go.mod h1:ZjrT6AXHbDs86ZSdt/osfBi5qfexBrKUdONk989Wnk4= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= +github.com/cyberphone/json-canonicalization v0.0.0-20241213102144-19d51d7fe467 h1:uX1JmpONuD549D73r6cgnxyUu18Zb7yHAy5AYU0Pm4Q= +github.com/cyberphone/json-canonicalization v0.0.0-20241213102144-19d51d7fe467/go.mod h1:uzvlm1mxhHkdfqitSA92i7Se+S9ksOn3a3qmv/kyOCw= +github.com/danieljoos/wincred v1.2.0 
h1:ozqKHaLK0W/ii4KVbbvluM91W2H3Sh0BncbUNPS7jLE= +github.com/danieljoos/wincred v1.2.0/go.mod h1:FzQLLMKBFdvu+osBrnFODiv32YGwCfx0SkRa/eYHgec= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/digitorus/pkcs7 v0.0.0-20230713084857-e76b763bdc49/go.mod h1:SKVExuS+vpu2l9IoOc0RwqE7NYnb0JlcFHFnEJkVDzc= +github.com/digitorus/pkcs7 v0.0.0-20250730155240-ffadbf3f398c h1:g349iS+CtAvba7i0Ee9EP1TlTZ9w+UncBY6HSmsFZa0= +github.com/digitorus/pkcs7 v0.0.0-20250730155240-ffadbf3f398c/go.mod h1:mCGGmWkOQvEuLdIRfPIpXViBfpWto4AhwtJlAvo62SQ= +github.com/digitorus/timestamp v0.0.0-20250524132541-c45532741eea h1:ALRwvjsSP53QmnN3Bcj0NpR8SsFLnskny/EIMebAk1c= +github.com/digitorus/timestamp v0.0.0-20250524132541-c45532741eea/go.mod h1:GvWntX9qiTlOud0WkQ6ewFm0LPy5JUR1Xo0Ngbd1w6Y= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= +github.com/go-chi/chi/v5 v5.2.4 h1:WtFKPHwlywe8Srng8j2BhOD9312j9cGUxG1SP4V2cR4= +github.com/go-chi/chi/v5 v5.2.4/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 
+github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.24.1 h1:Xp+7Yn/KOnVWYG8d+hPksOYnCYImE3TieBa7rBOesYM= +github.com/go-openapi/analysis v0.24.1/go.mod h1:dU+qxX7QGU1rl7IYhBC8bIfmWQdX4Buoea4TGtxXY84= +github.com/go-openapi/errors v0.22.6 h1:eDxcf89O8odEnohIXwEjY1IB4ph5vmbUsBMsFNwXWPo= +github.com/go-openapi/errors v0.22.6/go.mod h1:z9S8ASTUqx7+CP1Q8dD8ewGH/1JWFFLX/2PmAYNQLgk= +github.com/go-openapi/jsonpointer v0.22.4 h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4= +github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80= +github.com/go-openapi/jsonreference v0.21.4 h1:24qaE2y9bx/q3uRK/qN+TDwbok1NhbSmGjjySRCHtC8= +github.com/go-openapi/jsonreference v0.21.4/go.mod h1:rIENPTjDbLpzQmQWCj5kKj3ZlmEh+EFVbz3RTUh30/4= +github.com/go-openapi/loads v0.23.2 h1:rJXAcP7g1+lWyBHC7iTY+WAF0rprtM+pm8Jxv1uQJp4= +github.com/go-openapi/loads v0.23.2/go.mod h1:IEVw1GfRt/P2Pplkelxzj9BYFajiWOtY2nHZNj4UnWY= +github.com/go-openapi/runtime v0.29.2 h1:UmwSGWNmWQqKm1c2MGgXVpC2FTGwPDQeUsBMufc5Yj0= +github.com/go-openapi/runtime v0.29.2/go.mod h1:biq5kJXRJKBJxTDJXAa00DOTa/anflQPhT0/wmjuy+0= +github.com/go-openapi/spec v0.22.3 h1:qRSmj6Smz2rEBxMnLRBMeBWxbbOvuOoElvSvObIgwQc= +github.com/go-openapi/spec v0.22.3/go.mod h1:iIImLODL2loCh3Vnox8TY2YWYJZjMAKYyLH2Mu8lOZs= +github.com/go-openapi/strfmt v0.25.0 h1:7R0RX7mbKLa9EYCTHRcCuIPcaqlyQiWNPTXwClK0saQ= +github.com/go-openapi/strfmt v0.25.0/go.mod h1:nNXct7OzbwrMY9+5tLX4I21pzcmE6ccMGXl3jFdPfn8= +github.com/go-openapi/swag v0.25.4 h1:OyUPUFYDPDBMkqyxOTkqDYFnrhuhi9NR6QVUvIochMU= +github.com/go-openapi/swag v0.25.4/go.mod h1:zNfJ9WZABGHCFg2RnY0S4IOkAcVTzJ6z2Bi+Q4i6qFQ= +github.com/go-openapi/swag/cmdutils v0.25.4 h1:8rYhB5n6WawR192/BfUu2iVlxqVR9aRgGJP6WaBoW+4= +github.com/go-openapi/swag/cmdutils v0.25.4/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0= +github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4= 
+github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU= +github.com/go-openapi/swag/fileutils v0.25.4 h1:2oI0XNW5y6UWZTC7vAxC8hmsK/tOkWXHJQH4lKjqw+Y= +github.com/go-openapi/swag/fileutils v0.25.4/go.mod h1:cdOT/PKbwcysVQ9Tpr0q20lQKH7MGhOEb6EwmHOirUk= +github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= +github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag= +github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA= +github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM= +github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s= +github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE= +github.com/go-openapi/swag/mangling v0.25.4 h1:2b9kBJk9JvPgxr36V23FxJLdwBrpijI26Bx5JH4Hp48= +github.com/go-openapi/swag/mangling v0.25.4/go.mod h1:6dxwu6QyORHpIIApsdZgb6wBk/DPU15MdyYj/ikn0Hg= +github.com/go-openapi/swag/netutils v0.25.4 h1:Gqe6K71bGRb3ZQLusdI8p/y1KLgV4M/k+/HzVSqT8H0= +github.com/go-openapi/swag/netutils v0.25.4/go.mod h1:m2W8dtdaoX7oj9rEttLyTeEFFEBvnAx9qHd5nJEBzYg= +github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8= +github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0= +github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw= +github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE= +github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw= 
+github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= +github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= +github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= +github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= +github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/certificate-transparency-go v1.3.2 h1:9ahSNZF2o7SYMaKaXhAumVEzXB2QaayzII9C8rv7v+A= +github.com/google/certificate-transparency-go v1.3.2/go.mod h1:H5FpMUaGa5Ab2+KCYsxg6sELw3Flkl7pGZzWdBoYLXs= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-configfs-tsm v0.2.2 
h1:YnJ9rXIOj5BYD7/0DNnzs8AOp7UcvjfTvt215EWcs98= +github.com/google/go-configfs-tsm v0.2.2/go.mod h1:EL1GTDFMb5PZQWDviGfZV9n87WeGTR/JUg13RfwkgRo= +github.com/google/go-containerregistry v0.20.7 h1:24VGNpS0IwrOZ2ms2P1QE3Xa5X9p4phx0aUgzYzHW6I= +github.com/google/go-containerregistry v0.20.7/go.mod h1:Lx5LCZQjLH1QBaMPeGwsME9biPeo1lPx6lbGj/UmzgM= +github.com/google/go-sev-guest v0.14.1 h1:j/DXy9jk1qSW/dEV9vDiQnhAVFD1zqnWNVu6p1J0Jgo= +github.com/google/go-sev-guest v0.14.1/go.mod h1:SK9vW+uyfuzYdVN0m8BShL3OQCtXZe/JPF7ZkpD3760= +github.com/google/go-tdx-guest v0.3.1 h1:gl0KvjdsD4RrJzyLefDOvFOUH3NAJri/3qvaL5m83Iw= +github.com/google/go-tdx-guest v0.3.1/go.mod h1:/rc3d7rnPykOPuY8U9saMyEps0PZDThLk/RygXm04nE= +github.com/google/logger v1.1.1 h1:+6Z2geNxc9G+4D4oDO9njjjn2d0wN5d7uOo0vOIW1NQ= +github.com/google/logger v1.1.1/go.mod h1:BkeJZ+1FhQ+/d087r4dzojEg1u2ZX+ZqG1jTUrLM+zQ= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/trillian v1.7.2 h1:EPBxc4YWY4Ak8tcuhyFleY+zYlbCDCa4Sn24e1Ka8Js= +github.com/google/trillian v1.7.2/go.mod h1:mfQJW4qRH6/ilABtPYNBerVJAJ/upxHLX81zxNQw05s= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.9 h1:TOpi/QG8iDcZlkQlGlFUti/ZtyLkliXvHDcyUIMuFrU= +github.com/googleapis/enterprise-certificate-proxy v0.3.9/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y= +github.com/googleapis/gax-go/v2 v2.16.0/go.mod h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14= +github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= +github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= 
+github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= +github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/hcl v1.0.1-vault-7 
h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= +github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= +github.com/hashicorp/vault/api v1.22.0 h1:+HYFquE35/B74fHoIeXlZIP2YADVboaPjaSicHEZiH0= +github.com/hashicorp/vault/api v1.22.0/go.mod h1:IUZA2cDvr4Ok3+NtK2Oq/r+lJeXkeCrHRmqdyWfpmGM= +github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef h1:A9HsByNhogrvm9cWb28sjiS3i7tcKCkflWFEkHfuAgM= +github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef/go.mod h1:lADxMC39cJJqL93Duh1xhAs4I2Zs8mKS89XWXFGp9cs= +github.com/in-toto/attestation v1.1.2 h1:MBFn6lsMq6dptQZJBhalXTcWMb/aJy3V+GX3VYj/V1E= +github.com/in-toto/attestation v1.1.2/go.mod h1:gYFddHMZj3DiQ0b62ltNi1Vj5rC879bTmBbrv9CRHpM= +github.com/in-toto/in-toto-golang v0.9.0 h1:tHny7ac4KgtsfrG6ybU8gVOZux2H8jN05AXJ9EBM1XU= +github.com/in-toto/in-toto-golang v0.9.0/go.mod h1:xsBVrVsHNsB61++S6Dy2vWosKhuA3lUTQd+eF9HdeMo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b h1:ZGiXF8sz7PDk6RgkP+A/SFfUD0ZR/AgG6SpRNEDKZy8= 
+github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b/go.mod h1:hQmNrgofl+IY/8L+n20H6E6PWBBTokdsv+q49j0QhsU= +github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= +github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= +github.com/jmespath/go-jmespath v0.4.1-0.20220621161143-b0104c826a24 h1:liMMTbpW34dhU4az1GN0pTPADwNmvoRSeoZ6PItiqnY= +github.com/jmespath/go-jmespath v0.4.1-0.20220621161143-b0104c826a24/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= +github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/letsencrypt/boulder v0.20251110.0 h1:J8MnKICeilO91dyQ2n5eBbab24neHzUpYMUIOdOtbjc= +github.com/letsencrypt/boulder v0.20251110.0/go.mod h1:ogKCJQwll82m7OVHWyTuf8eeFCjuzdRQlgnZcCl0V+8= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod 
h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= +github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/openai/openai-go/v3 v3.16.0 h1:VdqS+GFZgAvEOBcWNyvLVwPlYEIboW5xwiUCcLrVf8c= +github.com/openai/openai-go/v3 v3.16.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod 
h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= +github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/sassoftware/relic v7.2.1+incompatible h1:Pwyh1F3I0r4clFJXkSI8bOyJINGqpgjJU3DYAZeI05A= +github.com/sassoftware/relic v7.2.1+incompatible/go.mod h1:CWfAxv73/iLZ17rbyhIEq3K9hs5w6FpNMdUT//qR+zk= +github.com/sassoftware/relic/v7 v7.6.2 h1:rS44Lbv9G9eXsukknS4mSjIAuuX+lMq/FnStgmZlUv4= +github.com/sassoftware/relic/v7 v7.6.2/go.mod h1:kjmP0IBVkJZ6gXeAu35/KCEfca//+PKM6vTAsyDPY+k= +github.com/secure-systems-lab/go-securesystemslib v0.10.0 h1:l+H5ErcW0PAehBNrBxoGv1jjNpGYdZ9RcheFkB2WI14= +github.com/secure-systems-lab/go-securesystemslib v0.10.0/go.mod h1:MRKONWmRoFzPNQ9USRF9i1mc7MvAVvF1LlW8X5VWDvk= +github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= +github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/shibumi/go-pathspec v1.3.0 h1:QUyMZhFo0Md5B8zV8x2tesohbb5kfbpTi9rBnKh5dkI= +github.com/shibumi/go-pathspec v1.3.0/go.mod h1:Xutfslp817l2I1cZvgcfeMQJG5QnU2lh5tVaaMCl3jE= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= +github.com/sigstore/protobuf-specs v0.5.0 h1:F8YTI65xOHw70NrvPwJ5PhAzsvTnuJMGLkA4FIkofAY= +github.com/sigstore/protobuf-specs v0.5.0/go.mod h1:+gXR+38nIa2oEupqDdzg4qSBT0Os+sP7oYv6alWewWc= +github.com/sigstore/rekor v1.5.0 h1:rL7SghHd5HLCtsCrxw0yQg+NczGvM75EjSPPWuGjaiQ= 
+github.com/sigstore/rekor v1.5.0/go.mod h1:D7JoVCUkxwQOpPDNYeu+CE8zeBC18Y5uDo6tF8s2rcQ= +github.com/sigstore/rekor-tiles/v2 v2.0.1 h1:1Wfz15oSRNGF5Dzb0lWn5W8+lfO50ork4PGIfEKjZeo= +github.com/sigstore/rekor-tiles/v2 v2.0.1/go.mod h1:Pjsbhzj5hc3MKY8FfVTYHBUHQEnP0ozC4huatu4x7OU= +github.com/sigstore/sigstore v1.10.4 h1:ytOmxMgLdcUed3w1SbbZOgcxqwMG61lh1TmZLN+WeZE= +github.com/sigstore/sigstore v1.10.4/go.mod h1:tDiyrdOref3q6qJxm2G+JHghqfmvifB7hw+EReAfnbI= +github.com/sigstore/sigstore-go v1.1.4 h1:wTTsgCHOfqiEzVyBYA6mDczGtBkN7cM8mPpjJj5QvMg= +github.com/sigstore/sigstore-go v1.1.4/go.mod h1:2U/mQOT9cjjxrtIUeKDVhL+sHBKsnWddn8URlswdBsg= +github.com/sigstore/sigstore/pkg/signature/kms/aws v1.10.3 h1:D/FRl5J9UYAJPGZRAJbP0dH78pfwWnKsyCSBwFBU8CI= +github.com/sigstore/sigstore/pkg/signature/kms/aws v1.10.3/go.mod h1:2GIWuNvTRMvrzd0Nl8RNqxrt9H7X0OBStwOSzGYRjYw= +github.com/sigstore/sigstore/pkg/signature/kms/azure v1.10.3 h1:k5VMLf/ms7hh6MLgVoorM0K+hSMwZLXoywlxh4CXqP8= +github.com/sigstore/sigstore/pkg/signature/kms/azure v1.10.3/go.mod h1:S1Bp3dmP7jYlXcGLAxG81wRbE01NIZING8ZIy0dJlAI= +github.com/sigstore/sigstore/pkg/signature/kms/gcp v1.10.0 h1:iUEf5MZYOuXGnXxdF/WrarJrk0DTVHqeIOjYdtpVXtc= +github.com/sigstore/sigstore/pkg/signature/kms/gcp v1.10.0/go.mod h1:i6vg5JfEQix46R1rhQlrKmUtJoeH91drltyYOJEk1T4= +github.com/sigstore/sigstore/pkg/signature/kms/hashivault v1.10.3 h1:lJSdaC/aOlFHlvqmmV696n1HdXLMLEKGwpNZMV0sKts= +github.com/sigstore/sigstore/pkg/signature/kms/hashivault v1.10.3/go.mod h1:b2rV9qPbt/jv/Yy75AIOZThP8j+pe1ZdLEjOwmjPdoA= +github.com/sigstore/timestamp-authority/v2 v2.0.3 h1:sRyYNtdED/ttLCMdaYnwpf0zre1A9chvjTnCmWWxN8Y= +github.com/sigstore/timestamp-authority/v2 v2.0.3/go.mod h1:mDaHxkt3HmZYoIlwYj4QWo0RUr7VjYU52aVO5f5Qb3I= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= +github.com/spf13/cobra v1.10.2 
h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/theupdateframework/go-tuf v0.7.0 h1:CqbQFrWo1ae3/I0UCblSbczevCCbS31Qvs5LdxRWqRI= +github.com/theupdateframework/go-tuf v0.7.0/go.mod h1:uEB7WSY+7ZIugK6R1hiBMBjQftaFzn7ZCDJcp1tCUug= +github.com/theupdateframework/go-tuf/v2 v2.4.1 h1:K6ewW064rKZCPkRo1W/CTbTtm/+IB4+coG1iNURAGCw= +github.com/theupdateframework/go-tuf/v2 v2.4.1/go.mod h1:Nex2enPVYDFCklrnbTzl3OVwD7fgIAj0J5++z/rvCj8= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 
h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/tinfoilsh/encrypted-http-body-protocol v0.2.3 h1:+pTW3gfvgZnlt6sOEHwdHAtUOOvq1YO7rhtAcGVD6W4= +github.com/tinfoilsh/encrypted-http-body-protocol v0.2.3/go.mod h1:THDK0GFNny7Pcc+nO3AQi4f6Wf1cDkfLatmHAUlIn5s= +github.com/tinfoilsh/tinfoil-go v0.13.1 h1:rgfF8/t5b7JQ9ijX5Wa8XRSM0ODe9i1/0uoZUq7R1MM= +github.com/tinfoilsh/tinfoil-go v0.13.1/go.mod h1:guFi02dl4Ff8s0GJSsLKJcvpZg00EfqnQJ6MUFszda0= +github.com/tink-crypto/tink-go-awskms/v2 v2.1.0 h1:N9UxlsOzu5mttdjhxkDLbzwtEecuXmlxZVo/ds7JKJI= +github.com/tink-crypto/tink-go-awskms/v2 v2.1.0/go.mod h1:PxSp9GlOkKL9rlybW804uspnHuO9nbD98V/fDX4uSis= +github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0 h1:3B9i6XBXNTRspfkTC0asN5W0K6GhOSgcujNiECNRNb0= +github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0/go.mod h1:jY5YN2BqD/KSCHM9SqZPIpJNG/u3zwfLXHgws4x2IRw= +github.com/tink-crypto/tink-go-hcvault/v2 v2.3.0 h1:6nAX1aRGnkg2SEUMwO5toB2tQkP0Jd6cbmZ/K5Le1V0= +github.com/tink-crypto/tink-go-hcvault/v2 v2.3.0/go.mod h1:HOC5NWW1wBI2Vke1FGcRBvDATkEYE7AUDiYbXqi2sBw= +github.com/tink-crypto/tink-go/v2 v2.5.0 h1:B8KLF6AofxdBIE4UJIaFbmoj5/1ehEtt7/MmzfI4Zpw= +github.com/tink-crypto/tink-go/v2 v2.5.0/go.mod h1:2WbBA6pfNsAfBwDCggboaHeB2X29wkU8XHtGwh2YIk8= +github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 h1:e/5i7d4oYZ+C1wj2THlRK+oAhjeS/TRQwMfkIuet3w0= +github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399/go.mod h1:LdwHTNJT99C5fTAzDz0ud328OgXz+gierycbcIx2fRs= +github.com/transparency-dev/formats v0.0.0-20251027093029-9ba98ff6507f h1:Kf4JGP264MtiqoBx2oZca4mS+fPfbmZVNAcDq6WVlQU= +github.com/transparency-dev/formats v0.0.0-20251027093029-9ba98ff6507f/go.mod h1:g85IafeFJZLxlzZCDRu4JLpfS7HKzR+Hw9qRh3bVzDI= 
+github.com/transparency-dev/merkle v0.0.2 h1:Q9nBoQcZcgPamMkGn7ghV8XiTZ/kRxn1yCG81+twTK4= +github.com/transparency-dev/merkle v0.0.2/go.mod h1:pqSy+OXefQ1EDUVmAJ8MUhHB9TXGuzVAT58PqBoHz1A= +github.com/zalando/go-keyring v0.2.3 h1:v9CUu9phlABObO4LPWycf+zwMG7nlbb3t/B5wa97yms= +github.com/zalando/go-keyring v0.2.3/go.mod h1:HL4k+OXQfJUWaMnqyuSOc0drfGPX2b51Du6K+MRgZMk= +go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss= +go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace 
v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.step.sm/crypto v0.74.0 h1:/APBEv45yYR4qQFg47HA8w1nesIGcxh44pGyQNw6JRA= +go.step.sm/crypto v0.74.0/go.mod h1:UoXqCAJjjRgzPte0Llaqen7O9P7XjPmgjgTHQGkKCDk= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= +golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 
h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20210426230700-d19ff857e887/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/api v0.260.0 h1:XbNi5E6bOVEj/uLXQRlt6TKuEzMD7zvW/6tNwltE4P4= +google.golang.org/api v0.260.0/go.mod h1:Shj1j0Phr/9sloYrKomICzdYgsSDImpTxME8rGLaZ/o= +google.golang.org/genproto v0.0.0-20251202230838-ff82c1b0f217 h1:GvESR9BIyHUahIb0NcTum6itIWtdoglGX+rnGxm2934= +google.golang.org/genproto v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:yJ2HH4EHEDTd3JiLmhds6NkJ17ITVYOdV3m3VKOnws0= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b h1:Mv8VFug0MP9e5vUxfBcE3vUkV6CImK3cMNMIDFjmzxU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.79.3 
h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= +k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= +sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= +sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= +software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= +software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/pkg/domain/account.go b/pkg/domain/account.go new file mode 100644 index 0000000..e74e7b9 --- /dev/null +++ b/pkg/domain/account.go @@ -0,0 +1,200 @@ +package domain + +import "time" + +// Account status constants. +const ( + AccountStatusActive = "active" + AccountStatusInactive = "inactive" +) + +// Account plan constants. +const ( + AccountPlanFree = "free" + AccountPlanPro20Monthly = "pro_20_monthly" +) + +// Account represents a gateway account. 
+type Account struct { + ID string + AuthgearSubjectID *string + ExternalCustomerID *string + Email *string + Name *string + Status string + Plan string + SubscriptionCancelAtPeriodEnd bool + SubscriptionCancelAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time + // LastUsedAt is the most recent `last_used_at` across the account's API keys. + // Populated only by list queries that aggregate key usage; nil elsewhere. + LastUsedAt *time.Time +} + +// IsActive returns true if the account status is "active". +func (a Account) IsActive() bool { + return a.Status == AccountStatusActive +} + +// SubscriptionStatusMessage returns a user-facing summary of subscription cancellation state. +func (a Account) SubscriptionStatusMessage() *string { + if !a.SubscriptionCancelAtPeriodEnd { + return nil + } + if a.SubscriptionCancelAt == nil { + message := "scheduled to cancel at period end" + return &message + } + + message := "scheduled to cancel on " + a.SubscriptionCancelAt.UTC().Format("January 2, 2006") + return &message +} + +// APIKey represents a hashed API key record. +type APIKey struct { + ID string + AccountID string + Name *string + KeyPrefix string + KeyHash string + Active bool + PIIMode string + LastUsedAt *time.Time + ExpiresAt *time.Time + RevokedAt *time.Time + CreatedAt time.Time +} + +// IsKeyActive returns true if the key is marked active. +func (k APIKey) IsKeyActive() bool { + return k.Active +} + +// API key PII masking levels. The level is stored per key and can be changed +// explicitly from the control plane. +const ( + APIKeyPIIModeOff = "off" + APIKeyPIIModeLow = "low" + APIKeyPIIModeBalanced = "balanced" + APIKeyPIIModeHigh = "high" +) + +// DefaultAPIKeyPIIMode is used when callers omit pii_mode on key creation. +const DefaultAPIKeyPIIMode = APIKeyPIIModeOff + +// NormalizeAPIKeyPIIMode returns a valid mode, defaulting empty input to off. 
+func NormalizeAPIKeyPIIMode(mode string) (string, bool) { + if mode == "" { + return DefaultAPIKeyPIIMode, true + } + switch mode { + case APIKeyPIIModeOff, APIKeyPIIModeLow, APIKeyPIIModeBalanced, APIKeyPIIModeHigh: + return mode, true + default: + return "", false + } +} + +// CreateAPIKeyParams holds parameters for creating a new API key. +type CreateAPIKeyParams struct { + AccountID string + Name *string + PIIMode string + ExpiresAt *time.Time +} + +// CreateAPIKeyResult is returned after key creation, including the raw key shown only once. +type CreateAPIKeyResult struct { + APIKey APIKey + RawKey string +} + +// AccountBalance represents the billing balance for an account. +type AccountBalance struct { + AccountID string + Currency string + MonthlyCreditTotalMicrocents int64 + MonthlyCreditUsedMicrocents int64 + PrepaidCreditMicrocents int64 + BillingPeriodStart time.Time + BillingPeriodEnd time.Time + UpdatedAt time.Time +} + +// RemainingMonthlyMicrocents returns unused monthly credits. +func (b AccountBalance) RemainingMonthlyMicrocents() int64 { + remaining := b.MonthlyCreditTotalMicrocents - b.MonthlyCreditUsedMicrocents + if remaining < 0 { + return 0 + } + return remaining +} + +// ActiveRemainingMonthlyMicrocents returns the unused monthly credits when the billing period is active. +func (b AccountBalance) ActiveRemainingMonthlyMicrocents(now time.Time) int64 { + if now.Before(b.BillingPeriodStart) || !now.Before(b.BillingPeriodEnd) { + return 0 + } + return b.RemainingMonthlyMicrocents() +} + +// TotalAvailableMicrocents returns the total available credits (monthly remaining + prepaid). +func (b AccountBalance) TotalAvailableMicrocents() int64 { + return b.RemainingMonthlyMicrocents() + maxInt64(0, b.PrepaidCreditMicrocents) +} + +// SpendableCreditMicrocentsAt returns the spendable credits at the provided time. 
+func (b AccountBalance) SpendableCreditMicrocentsAt(now time.Time) int64 { + return b.ActiveRemainingMonthlyMicrocents(now) + maxInt64(0, b.PrepaidCreditMicrocents) +} + +// ApplyUsageCharge applies a usage cost against monthly credits first, then prepaid credits. +func (b AccountBalance) ApplyUsageCharge(costMicrocents int64, now time.Time) (AccountBalance, UsageCharge) { + if costMicrocents <= 0 { + return b, UsageCharge{CostMicrocents: costMicrocents} + } + + updated := b + remainingCost := costMicrocents + + chargedMonthly := minInt64(remainingCost, updated.ActiveRemainingMonthlyMicrocents(now)) + if chargedMonthly > 0 { + updated.MonthlyCreditUsedMicrocents += chargedMonthly + remainingCost -= chargedMonthly + } + + chargedPrepaid := minInt64(remainingCost, maxInt64(0, updated.PrepaidCreditMicrocents)) + if chargedPrepaid > 0 { + updated.PrepaidCreditMicrocents -= chargedPrepaid + remainingCost -= chargedPrepaid + } + + return updated, UsageCharge{ + CostMicrocents: costMicrocents, + ChargedMonthlyMicrocents: chargedMonthly, + ChargedPrepaidMicrocents: chargedPrepaid, + UnchargedCostMicrocents: remainingCost, + SpendableCreditMicrocents: chargedMonthly + chargedPrepaid, + } +} + +func minInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func maxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +// AuthContext carries the authenticated account and API key for a request. 
+type AuthContext struct { + Account Account + APIKey APIKey +} diff --git a/pkg/domain/account_test.go b/pkg/domain/account_test.go new file mode 100644 index 0000000..4837b67 --- /dev/null +++ b/pkg/domain/account_test.go @@ -0,0 +1,32 @@ +package domain_test + +import ( + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func TestNormalizeAPIKeyPIIMode(t *testing.T) { + tests := []struct { + name string + mode string + want string + wantOK bool + }{ + {name: "empty defaults off", mode: "", want: domain.APIKeyPIIModeOff, wantOK: true}, + {name: "off", mode: domain.APIKeyPIIModeOff, want: domain.APIKeyPIIModeOff, wantOK: true}, + {name: "low", mode: domain.APIKeyPIIModeLow, want: domain.APIKeyPIIModeLow, wantOK: true}, + {name: "balanced", mode: domain.APIKeyPIIModeBalanced, want: domain.APIKeyPIIModeBalanced, wantOK: true}, + {name: "high", mode: domain.APIKeyPIIModeHigh, want: domain.APIKeyPIIModeHigh, wantOK: true}, + {name: "invalid", mode: "everything", wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := domain.NormalizeAPIKeyPIIMode(tt.mode) + if ok != tt.wantOK || got != tt.want { + t.Fatalf("NormalizeAPIKeyPIIMode(%q) = %q,%v want %q,%v", tt.mode, got, ok, tt.want, tt.wantOK) + } + }) + } +} diff --git a/pkg/domain/billing.go b/pkg/domain/billing.go new file mode 100644 index 0000000..e442d49 --- /dev/null +++ b/pkg/domain/billing.go @@ -0,0 +1,44 @@ +package domain + +import "time" + +const ( + LedgerBalanceTypeMonthly = "monthly" + LedgerBalanceTypePrepaid = "prepaid" +) + +const ( + LedgerSourceTypeUsageEvent = "usage_event" + LedgerSourceTypeStripeInvoice = "stripe_invoice" + LedgerSourceTypeStripeCheckoutSession = "stripe_checkout_session" + LedgerSourceTypePromoCode = "promo_code" +) + +// UsageCharge describes how a usage cost was charged against account credits. 
+type UsageCharge struct { + CostMicrocents int64 + ChargedMonthlyMicrocents int64 + ChargedPrepaidMicrocents int64 + UnchargedCostMicrocents int64 + SpendableCreditMicrocents int64 +} + +// HasCharge reports whether any funds were actually charged. +func (c UsageCharge) HasCharge() bool { + return c.ChargedMonthlyMicrocents > 0 || c.ChargedPrepaidMicrocents > 0 +} + +// AccountLedgerEntry records an idempotent billing mutation for an account. +type AccountLedgerEntry struct { + ID string + AccountID string + SourceType string + SourceID string + // SourceLabel is optionally populated by query paths that hydrate a + // human-readable source, such as a redeemed promo code. + SourceLabel *string + BalanceType string + AmountMicrocents int64 + Description *string + CreatedAt time.Time +} diff --git a/pkg/domain/domain_test.go b/pkg/domain/domain_test.go new file mode 100644 index 0000000..3076561 --- /dev/null +++ b/pkg/domain/domain_test.go @@ -0,0 +1,239 @@ +package domain_test + +import ( + "testing" + "time" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" + "github.com/shopspring/decimal" +) + +func TestGatewayError(t *testing.T) { + err := domain.ErrInvalidAPIKey("test message") + if err.HTTPStatus != 401 { + t.Errorf("expected HTTP 401, got %d", err.HTTPStatus) + } + if err.Type != domain.ErrTypeAuthentication { + t.Errorf("expected type %s, got %s", domain.ErrTypeAuthentication, err.Type) + } + if err.Code != domain.ErrCodeInvalidAPIKey { + t.Errorf("expected code %s, got %s", domain.ErrCodeInvalidAPIKey, err.Code) + } + if err.Error() == "" { + t.Error("expected non-empty error message") + } +} + +func TestErrorConstructors(t *testing.T) { + tests := []struct { + name string + err *domain.GatewayError + wantStatus int + wantType string + wantCode string + }{ + {"InvalidAPIKey", domain.ErrInvalidAPIKey("bad"), 401, domain.ErrTypeAuthentication, domain.ErrCodeInvalidAPIKey}, + {"InactiveAPIKey", domain.ErrInactiveAPIKey(), 401, 
domain.ErrTypeAuthentication, domain.ErrCodeInactiveAPIKey}, + {"InactiveAccount", domain.ErrInactiveAccount(), 403, domain.ErrTypePermission, domain.ErrCodeInactiveAccount}, + {"UnsupportedModel", domain.ErrUnsupportedModel("foo"), 404, domain.ErrTypeInvalidRequest, domain.ErrCodeUnsupportedModel}, + {"UnsupportedEndpoint", domain.ErrUnsupportedEndpoint("m", "e"), 422, domain.ErrTypeInvalidRequest, domain.ErrCodeUnsupportedEndpoint}, + {"UnsupportedFeature", domain.ErrUnsupportedFeature("f"), 422, domain.ErrTypeInvalidRequest, domain.ErrCodeUnsupportedFeature}, + {"ProviderUnavailable", domain.ErrProviderUnavailable("p"), 503, domain.ErrTypeProvider, domain.ErrCodeProviderUnavailable}, + {"ProviderTimeout", domain.ErrProviderTimeout("p"), 504, domain.ErrTypeProvider, domain.ErrCodeProviderTimeout}, + {"InvalidField", domain.ErrInvalidField("f"), 400, domain.ErrTypeInvalidRequest, domain.ErrCodeInvalidField}, + {"InsufficientBalance", domain.ErrInsufficientBalance(), 402, domain.ErrTypePermission, domain.ErrCodeInsufficientBalance}, + {"UnknownField", domain.ErrUnknownField("f"), 400, domain.ErrTypeInvalidRequest, domain.ErrCodeUnknownField}, + {"Internal", domain.ErrInternal("i"), 500, domain.ErrTypeInternal, domain.ErrCodeInternalError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.HTTPStatus != tt.wantStatus { + t.Errorf("HTTPStatus = %d, want %d", tt.err.HTTPStatus, tt.wantStatus) + } + if tt.err.Type != tt.wantType { + t.Errorf("Type = %s, want %s", tt.err.Type, tt.wantType) + } + if tt.err.Code != tt.wantCode { + t.Errorf("Code = %s, want %s", tt.err.Code, tt.wantCode) + } + }) + } +} + +func TestAccountIsActive(t *testing.T) { + a := domain.Account{Status: "active"} + if !a.IsActive() { + t.Error("expected active") + } + a.Status = "inactive" + if a.IsActive() { + t.Error("expected inactive") + } +} + +func TestPublicModelSupportsEndpoint(t *testing.T) { + m := domain.PublicModel{ + SupportsChatCompletions: true, + } + 
if !m.SupportsEndpoint(domain.EndpointChatCompletions) { + t.Error("should support chat_completions") + } + if m.SupportsEndpoint("unknown") { + t.Error("should not support unknown") + } +} + +func TestPublicModelSupportsStreamForEndpoint(t *testing.T) { + m := domain.PublicModel{ + SupportsChatCompletionsStream: true, + } + if !m.SupportsStreamForEndpoint(domain.EndpointChatCompletions) { + t.Error("should support chat_completions stream") + } + m.SupportsChatCompletionsStream = false + if m.SupportsStreamForEndpoint(domain.EndpointChatCompletions) { + t.Error("should not support chat_completions stream when disabled") + } +} + +func TestAccountBalanceApplyUsageCharge_MonthlyBeforePrepaid(t *testing.T) { + now := time.Date(2026, 3, 24, 12, 0, 0, 0, time.UTC) + balance := domain.AccountBalance{ + MonthlyCreditTotalMicrocents: 1000, + MonthlyCreditUsedMicrocents: 800, + PrepaidCreditMicrocents: 500, + BillingPeriodStart: now.Add(-time.Hour), + BillingPeriodEnd: now.Add(time.Hour), + } + + updated, charge := balance.ApplyUsageCharge(400, now) + if charge.ChargedMonthlyMicrocents != 200 { + t.Fatalf("monthly charge = %d, want 200", charge.ChargedMonthlyMicrocents) + } + if charge.ChargedPrepaidMicrocents != 200 { + t.Fatalf("prepaid charge = %d, want 200", charge.ChargedPrepaidMicrocents) + } + if updated.MonthlyCreditUsedMicrocents != 1000 { + t.Fatalf("monthly used = %d, want 1000", updated.MonthlyCreditUsedMicrocents) + } + if updated.PrepaidCreditMicrocents != 300 { + t.Fatalf("prepaid credits = %d, want 300", updated.PrepaidCreditMicrocents) + } +} + +func TestAccountBalanceApplyUsageCharge_IgnoresExpiredMonthlyCredits(t *testing.T) { + now := time.Date(2026, 3, 24, 12, 0, 0, 0, time.UTC) + balance := domain.AccountBalance{ + MonthlyCreditTotalMicrocents: 1000, + MonthlyCreditUsedMicrocents: 100, + PrepaidCreditMicrocents: 500, + BillingPeriodStart: now.Add(-48 * time.Hour), + BillingPeriodEnd: now.Add(-24 * time.Hour), + } + + updated, charge := 
balance.ApplyUsageCharge(300, now) + if charge.ChargedMonthlyMicrocents != 0 { + t.Fatalf("monthly charge = %d, want 0", charge.ChargedMonthlyMicrocents) + } + if charge.ChargedPrepaidMicrocents != 300 { + t.Fatalf("prepaid charge = %d, want 300", charge.ChargedPrepaidMicrocents) + } + if updated.MonthlyCreditUsedMicrocents != 100 { + t.Fatalf("monthly used = %d, want unchanged 100", updated.MonthlyCreditUsedMicrocents) + } + if updated.PrepaidCreditMicrocents != 200 { + t.Fatalf("prepaid credits = %d, want 200", updated.PrepaidCreditMicrocents) + } +} + +func TestAccountBalanceApplyUsageCharge_DoesNotGoNegative(t *testing.T) { + now := time.Date(2026, 3, 24, 12, 0, 0, 0, time.UTC) + balance := domain.AccountBalance{ + MonthlyCreditTotalMicrocents: 500, + MonthlyCreditUsedMicrocents: 300, + PrepaidCreditMicrocents: 100, + BillingPeriodStart: now.Add(-time.Hour), + BillingPeriodEnd: now.Add(time.Hour), + } + + updated, charge := balance.ApplyUsageCharge(1000, now) + if charge.ChargedMonthlyMicrocents != 200 { + t.Fatalf("monthly charge = %d, want 200", charge.ChargedMonthlyMicrocents) + } + if charge.ChargedPrepaidMicrocents != 100 { + t.Fatalf("prepaid charge = %d, want 100", charge.ChargedPrepaidMicrocents) + } + if charge.UnchargedCostMicrocents != 700 { + t.Fatalf("uncharged cost = %d, want 700", charge.UnchargedCostMicrocents) + } + if updated.MonthlyCreditUsedMicrocents != 500 { + t.Fatalf("monthly used = %d, want 500", updated.MonthlyCreditUsedMicrocents) + } + if updated.PrepaidCreditMicrocents != 0 { + t.Fatalf("prepaid credits = %d, want 0", updated.PrepaidCreditMicrocents) + } +} + +func TestPublicModel_MaxCostMicrocents(t *testing.T) { + // GPT-5.4 Mini seed values: €0.75 input / €4.50 output per 1M tokens + // InputPricePerMillion = 75_000_000 / 100_000_000 = 0.75 + // OutputPricePerMillion = 450_000_000 / 100_000_000 = 4.50 + model := domain.PublicModel{ + MaxContextWindow: 1_000_000, + MaxOutputTokens: 16_384, + InputPricePerMillion: 
decimal.NewFromFloat(0.75), + OutputPricePerMillion: decimal.NewFromFloat(4.50), + } + + // No override: uses model max output tokens + // inputCost = 0.75 * 1_000_000 / 1_000_000 = 0.75 + // outputCost = 4.50 * 16_384 / 1_000_000 = 0.073728 + // total = 0.823728 euros → 82_372_800 microcents + got := model.MaxCostMicrocents(nil) + want := int64(82_372_800) + if got != want { + t.Errorf("MaxCostMicrocents(nil) = %d, want %d", got, want) + } + + // With override smaller than model max + override := 4096 + // outputCost = 4.50 * 4096 / 1_000_000 = 0.018432 + // total = 0.75 + 0.018432 = 0.768432 → 76_843_200 + got = model.MaxCostMicrocents(&override) + want = int64(76_843_200) + if got != want { + t.Errorf("MaxCostMicrocents(&4096) = %d, want %d", got, want) + } + + // Override larger than model max — ignored + override = 100_000 + got = model.MaxCostMicrocents(&override) + want = int64(82_372_800) // same as nil + if got != want { + t.Errorf("MaxCostMicrocents(&100000) = %d, want %d", got, want) + } + + // Override zero — ignored + override = 0 + got = model.MaxCostMicrocents(&override) + want = int64(82_372_800) + if got != want { + t.Errorf("MaxCostMicrocents(&0) = %d, want %d", got, want) + } +} + +func TestPublicModel_MaxCostMicrocents_MinimumOne(t *testing.T) { + // Model with zero pricing should still return at least 1 microcent + model := domain.PublicModel{ + MaxContextWindow: 100, + MaxOutputTokens: 100, + InputPricePerMillion: decimal.Zero, + OutputPricePerMillion: decimal.Zero, + } + + got := model.MaxCostMicrocents(nil) + if got != 1 { + t.Errorf("MaxCostMicrocents with zero pricing = %d, want 1", got) + } +} diff --git a/pkg/domain/error.go b/pkg/domain/error.go new file mode 100644 index 0000000..048a3b0 --- /dev/null +++ b/pkg/domain/error.go @@ -0,0 +1,180 @@ +package domain + +import ( + "errors" + "fmt" +) + +// Error types as defined in the spec. 
const (
	ErrTypeInvalidRequest = "invalid_request_error"
	ErrTypeAuthentication = "authentication_error"
	ErrTypePermission     = "permission_error"
	ErrTypeProvider       = "provider_error"
	ErrTypeInternal       = "internal_error"
)

// Machine-readable error codes stored in GatewayError.Code, as defined in the spec.
const (
	ErrCodeInvalidAPIKey         = "invalid_api_key"
	ErrCodeInactiveAPIKey        = "inactive_api_key"
	ErrCodeInactiveAccount       = "inactive_account"
	ErrCodeUnsupportedModel      = "unsupported_model"
	ErrCodeUnsupportedEndpoint   = "unsupported_endpoint"
	ErrCodeUnsupportedFeature    = "unsupported_feature"
	ErrCodeContextLengthExceeded = "context_length_exceeded"
	ErrCodeProviderUnavailable   = "provider_unavailable"
	ErrCodeProviderTimeout       = "provider_timeout"
	ErrCodeToolSchemaInvalid     = "tool_schema_invalid"
	ErrCodeToolMessageInvalid    = "tool_message_invalid"
	ErrCodeInvalidField          = "invalid_field"
	ErrCodeUnknownField          = "unknown_field"
	ErrCodeInsufficientBalance   = "insufficient_balance"
	ErrCodeClientCanceled        = "client_canceled"
	ErrCodeInternalError         = "internal_error"
)

// GatewayError is the canonical error value of the gateway. It satisfies the
// error interface through its Error method.
type GatewayError struct {
	// HTTPStatus is the HTTP status code associated with the error.
	HTTPStatus int
	// Type is the coarse error category (see the ErrType* constants).
	Type string
	// Code is the specific error code (see the ErrCode* constants).
	Code string
	// Message is the human-readable description.
	Message string
	// Metadata carries structured diagnostic fields for logging (never serialized to clients).
	Metadata map[string]any
}

// Error renders the error as "<type>: <code>: <message>".
func (e *GatewayError) Error() string {
	const layout = "%s: %s: %s"
	return fmt.Sprintf(layout, e.Type, e.Code, e.Message)
}

// WithMeta returns a shallow copy of the error with additional metadata fields merged in.
+func (e *GatewayError) WithMeta(fields ...any) *GatewayError { + cp := *e + if cp.Metadata == nil { + cp.Metadata = make(map[string]any, len(fields)/2) + } else { + m := make(map[string]any, len(cp.Metadata)+len(fields)/2) + for k, v := range cp.Metadata { + m[k] = v + } + cp.Metadata = m + } + for i := 0; i+1 < len(fields); i += 2 { + if key, ok := fields[i].(string); ok { + cp.Metadata[key] = fields[i+1] + } + } + return &cp +} + +// LogFields returns the metadata as a flat slice of key-value pairs for structured logging. +func (e *GatewayError) LogFields() []any { + if len(e.Metadata) == 0 { + return nil + } + fields := make([]any, 0, len(e.Metadata)*2) + for k, v := range e.Metadata { + fields = append(fields, k, v) + } + return fields +} + +// Error constructors. + +func ErrInvalidAPIKey(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 401, Type: ErrTypeAuthentication, Code: ErrCodeInvalidAPIKey, Message: msg} +} + +func ErrInactiveAPIKey() *GatewayError { + return &GatewayError{HTTPStatus: 401, Type: ErrTypeAuthentication, Code: ErrCodeInactiveAPIKey, Message: "API key is inactive"} +} + +func ErrInactiveAccount() *GatewayError { + return &GatewayError{HTTPStatus: 403, Type: ErrTypePermission, Code: ErrCodeInactiveAccount, Message: "Account is inactive"} +} + +func ErrUnsupportedModel(model string) *GatewayError { + return &GatewayError{HTTPStatus: 404, Type: ErrTypeInvalidRequest, Code: ErrCodeUnsupportedModel, Message: fmt.Sprintf("Model '%s' is not available", model)} +} + +func ErrUnsupportedEndpoint(model, endpoint string) *GatewayError { + return &GatewayError{HTTPStatus: 422, Type: ErrTypeInvalidRequest, Code: ErrCodeUnsupportedEndpoint, Message: fmt.Sprintf("Model '%s' does not support endpoint '%s'", model, endpoint)} +} + +func ErrUnsupportedFeature(feature string) *GatewayError { + return &GatewayError{HTTPStatus: 422, Type: ErrTypeInvalidRequest, Code: ErrCodeUnsupportedFeature, Message: fmt.Sprintf("Feature '%s' is not supported by 
the selected model", feature)} +} + +func ErrProviderUnavailable(provider string) *GatewayError { + return &GatewayError{HTTPStatus: 503, Type: ErrTypeProvider, Code: ErrCodeProviderUnavailable, Message: fmt.Sprintf("Provider '%s' is unavailable", provider)} +} + +func ErrProviderTimeout(provider string) *GatewayError { + return &GatewayError{HTTPStatus: 504, Type: ErrTypeProvider, Code: ErrCodeProviderTimeout, Message: fmt.Sprintf("Provider '%s' timed out", provider)} +} + +func ErrClientCanceled() *GatewayError { + return &GatewayError{HTTPStatus: 499, Type: ErrTypeInvalidRequest, Code: ErrCodeClientCanceled, Message: "Client canceled the request"} +} + +func ErrInvalidField(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 400, Type: ErrTypeInvalidRequest, Code: ErrCodeInvalidField, Message: msg} +} + +func ErrInsufficientBalance() *GatewayError { + return &GatewayError{HTTPStatus: 402, Type: ErrTypePermission, Code: ErrCodeInsufficientBalance, Message: "Account has no spendable balance"} +} + +func ErrUnknownField(field string) *GatewayError { + return &GatewayError{HTTPStatus: 400, Type: ErrTypeInvalidRequest, Code: ErrCodeUnknownField, Message: fmt.Sprintf("Unknown field: '%s'", field)} +} + +func ErrToolSchemaInvalid(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 400, Type: ErrTypeInvalidRequest, Code: ErrCodeToolSchemaInvalid, Message: msg} +} + +func ErrToolMessageInvalid(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 400, Type: ErrTypeInvalidRequest, Code: ErrCodeToolMessageInvalid, Message: msg} +} + +func ErrInternal(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 500, Type: ErrTypeInternal, Code: ErrCodeInternalError, Message: msg} +} + +func ErrProviderError(status int, msg string) *GatewayError { + return &GatewayError{HTTPStatus: status, Type: ErrTypeProvider, Code: ErrCodeProviderUnavailable, Message: msg} +} + +// Control-plane error codes. 
+const ( + ErrCodeNotFound = "not_found" + ErrCodeForbidden = "forbidden" + ErrCodeConflict = "conflict" + ErrCodeAlreadyExists = "already_exists" +) + +func ErrNotFound(resource, id string) *GatewayError { + return &GatewayError{HTTPStatus: 404, Type: ErrTypeInvalidRequest, Code: ErrCodeNotFound, Message: fmt.Sprintf("%s '%s' not found", resource, id)} +} + +func ErrForbidden(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 403, Type: ErrTypePermission, Code: ErrCodeForbidden, Message: msg} +} + +func ErrConflict(msg string) *GatewayError { + return &GatewayError{HTTPStatus: 409, Type: ErrTypeInvalidRequest, Code: ErrCodeConflict, Message: msg} +} + +func ErrAlreadyExists(resource, id string) *GatewayError { + return &GatewayError{HTTPStatus: 409, Type: ErrTypeInvalidRequest, Code: ErrCodeAlreadyExists, Message: fmt.Sprintf("%s '%s' already exists", resource, id)} +} + +// IsAlreadyExists returns true if the error is a GatewayError with code already_exists. +func IsAlreadyExists(err error) bool { + var gwErr *GatewayError + if errors.As(err, &gwErr) { + return gwErr.Code == ErrCodeAlreadyExists + } + return false +} diff --git a/pkg/domain/generation.go b/pkg/domain/generation.go new file mode 100644 index 0000000..6b84845 --- /dev/null +++ b/pkg/domain/generation.go @@ -0,0 +1,65 @@ +package domain + +// Endpoint constants. +const ( + EndpointChatCompletions = "chat_completions" +) + +// ResponseTextConfig configures structured text output. +type ResponseTextConfig struct { + FormatType *string + JSONSchema map[string]any +} + +// GenerateRequest is the canonical internal generation request. +type GenerateRequest struct { + PublicModelID string + RequestedModelID string + RouterID *string + RoutedPublicModelID *string + // Routing decision metadata. Set when the request was resolved via a + // router; nil for direct model requests. 
+ MatchedCategory *string + RoutingScore *float32 + RoutingCategoryScores []RoutingCategoryScore + DecisionReason *string + FallbackUsed *bool + Input []InputItem + Instructions *string + MaxOutputTokens *int + Temperature *float64 + ReasoningEffort *string + TopP *float64 + Stop []string + Stream bool + Tools []ToolDefinition + ToolChoice *ToolChoice + ParallelToolCalls *bool + User *string + Metadata map[string]any + TextConfig *ResponseTextConfig + ProviderOptions map[string]any + + // Pass-through parameters (forwarded to providers that support them) + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + Seed *int `json:"seed,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + Store *bool `json:"store,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` +} + +// GenerateResult is the canonical internal generation result. +type GenerateResult struct { + ID string + CreatedUnix int64 + PublicModelID string + ProviderName string + ProviderModelID string + Output []OutputItem + FinishReason *string + Usage *Usage + TinfoilProof *TinfoilTransportProof +} diff --git a/pkg/domain/input.go b/pkg/domain/input.go new file mode 100644 index 0000000..b31478b --- /dev/null +++ b/pkg/domain/input.go @@ -0,0 +1,17 @@ +package domain + +// InputItemType constants. +const ( + InputItemTypeMessage = "message" + InputItemTypeText = "text" +) + +// InputItem is a canonical input item for generation requests. 
+type InputItem struct { + Type string // "message" or "text" + Role *string + Content *string + ReasoningContent *string + ToolCalls []ToolCall + ToolCallID *string +} diff --git a/pkg/domain/model.go b/pkg/domain/model.go new file mode 100644 index 0000000..55cb099 --- /dev/null +++ b/pkg/domain/model.go @@ -0,0 +1,97 @@ +package domain + +import "github.com/shopspring/decimal" + +const ( + ProofModeNone = "none" + ProofModeTinfoilAttestedTransport = "tinfoil_attested_transport" +) + +// ProviderConfig holds upstream provider connection details. +type ProviderConfig struct { + ID string + ProviderName string + BaseURL string + APIKeySecretRef string + OrganizationRef *string + Active bool +} + +// PublicModel is the fully-resolved model exposed by the gateway. +type PublicModel struct { + ID string + PublicModelID string + DisplayName string + Description *string + ProviderModelID string + UpstreamModelName string + ProviderConfig ProviderConfig + SupportsChatCompletions bool + SupportsChatCompletionsStream bool + SupportsTools bool + SupportsParallelToolCalls bool + SupportsStructuredOutput bool + SupportsReasoning bool + ProofMode string + MaxContextWindow int + MaxOutputTokens int + InputPricePerMillion decimal.Decimal + OutputPricePerMillion decimal.Decimal + CacheReadPricePerMillion *decimal.Decimal + CacheWritePricePerMillion *decimal.Decimal + Currency string + Active bool +} + +func (m PublicModel) EffectiveProofMode() string { + if m.ProofMode != "" && m.ProofMode != ProofModeNone { + return m.ProofMode + } + return ProofModeNone +} + +func ProofModeEnabled(mode string) bool { + return mode != "" && mode != ProofModeNone +} + +// SupportsEndpoint checks if the model supports the given endpoint. 
+func (m PublicModel) SupportsEndpoint(endpoint string) bool { + switch endpoint { + case EndpointChatCompletions: + return m.SupportsChatCompletions + default: + return false + } +} + +// SupportsStreamForEndpoint checks if the model supports streaming for the given endpoint. +func (m PublicModel) SupportsStreamForEndpoint(endpoint string) bool { + switch endpoint { + case EndpointChatCompletions: + return m.SupportsChatCompletionsStream + default: + return false + } +} + +// MaxCostMicrocents returns the worst-case cost in microcents for a single request, +// based on the model's maximum context window (input) and maximum output tokens. +// If maxOutputTokensOverride is non-nil and smaller than the model's limit, it is used instead. +func (m PublicModel) MaxCostMicrocents(maxOutputTokensOverride *int) int64 { + maxInputTokens := int64(m.MaxContextWindow) + maxOutputTokens := int64(m.MaxOutputTokens) + if maxOutputTokensOverride != nil && *maxOutputTokensOverride > 0 && int64(*maxOutputTokensOverride) < maxOutputTokens { + maxOutputTokens = int64(*maxOutputTokensOverride) + } + + million := decimal.NewFromInt(1_000_000) + toMicrocents := decimal.NewFromInt(100_000_000) + + inputCost := m.InputPricePerMillion.Mul(decimal.NewFromInt(maxInputTokens)).Div(million) + outputCost := m.OutputPricePerMillion.Mul(decimal.NewFromInt(maxOutputTokens)).Div(million) + totalMicrocents := inputCost.Add(outputCost).Mul(toMicrocents).Ceil().IntPart() + if totalMicrocents < 1 { + totalMicrocents = 1 + } + return totalMicrocents +} diff --git a/pkg/domain/output.go b/pkg/domain/output.go new file mode 100644 index 0000000..2f87b72 --- /dev/null +++ b/pkg/domain/output.go @@ -0,0 +1,16 @@ +package domain + +// OutputItemType constants. +const ( + OutputItemTypeMessage = "message" + OutputItemTypeText = "text" +) + +// OutputItem is a canonical output item from a generation result. 
+type OutputItem struct { + Type string // "message" or "text" + Role *string + Content *string + ReasoningContent *string + ToolCalls []ToolCall +} diff --git a/pkg/domain/pii.go b/pkg/domain/pii.go new file mode 100644 index 0000000..8208256 --- /dev/null +++ b/pkg/domain/pii.go @@ -0,0 +1,422 @@ +package domain + +import ( + "fmt" + "regexp" + "sort" + "strings" +) + +// PIIEntity is a single span of detected personally-identifiable information. +// +// Start / End are byte offsets into the original UTF-8 string (Presidio uses +// character offsets for its `analyzer` API; the adapter converts them to byte +// offsets before constructing PIIEntity values). +type PIIEntity struct { + Type string // PRESIDIO entity type, e.g. PERSON, EMAIL_ADDRESS, PHONE_NUMBER + Start int // inclusive byte offset + End int // exclusive byte offset + Score float32 // detector confidence in [0, 1] +} + +// PIIMapping is a request-scoped, deterministic mapping between placeholder +// tokens (e.g. `[PERSON_1]`) and the original PII values they replace. +// +// The same original value mapped under the same type always reuses the same +// placeholder, so repeated PII inside a single request collapses cleanly. +type PIIMapping struct { + // tokenToOriginal preserves the substitution mapping used during masking. + tokenToOriginal map[string]string + // tokenAliases maps tolerated model rewrites back to the canonical token. + // For example, models may strip brackets and return `EMAIL_ADDRESS_1` + // instead of `[EMAIL_ADDRESS_1]`. + tokenAliases map[string]string + // normalizedTokenAliases maps case/separator-normalized placeholder forms + // back to the canonical token for bracketed model rewrites. + normalizedTokenAliases map[string]string + // originalToToken accelerates dedup when the same value appears multiple times. + originalToToken map[string]string // key = type|original + // counters tracks the next available index per entity type. 
+ counters map[string]int +} + +// NewPIIMapping returns an empty mapping ready for use. +func NewPIIMapping() *PIIMapping { + return &PIIMapping{ + tokenToOriginal: make(map[string]string), + tokenAliases: make(map[string]string), + normalizedTokenAliases: make(map[string]string), + originalToToken: make(map[string]string), + counters: make(map[string]int), + } +} + +// Token returns (or assigns) the deterministic placeholder for a given (type, value) pair. +func (m *PIIMapping) Token(entityType, original string) string { + if m == nil { + return original + } + key := entityType + "|" + original + if tok, ok := m.originalToToken[key]; ok { + return tok + } + m.counters[entityType]++ + index := m.counters[entityType] + tok := fmt.Sprintf("[%s_%d]", entityType, index) + m.originalToToken[key] = tok + m.tokenToOriginal[tok] = original + m.registerTokenAliases(tok, entityType, index) + return tok +} + +// Len returns the number of distinct placeholders in the mapping. +func (m *PIIMapping) Len() int { + if m == nil { + return 0 + } + return len(m.tokenToOriginal) +} + +// Original returns the original value for a token, if present. 
+func (m *PIIMapping) Original(token string) (string, bool) { + if m == nil { + return "", false + } + if v, ok := m.tokenToOriginal[token]; ok { + return v, true + } + canonical, ok := m.tokenAliases[token] + if !ok { + canonical, ok = m.normalizedTokenAliases[normalizePlaceholderToken(token)] + if !ok { + return "", false + } + } + v, ok := m.tokenToOriginal[canonical] + return v, ok +} + +func (m *PIIMapping) registerTokenAliases(canonical, entityType string, index int) { + if m == nil { + return + } + m.registerTokenAlias(strings.Trim(canonical, "[]"), canonical) + + short := placeholderTypeAlias(entityType) + if short != "" && short != entityType { + m.registerTokenAlias(fmt.Sprintf("[%s_%d]", short, index), canonical) + m.registerTokenAlias(fmt.Sprintf("%s_%d", short, index), canonical) + } +} + +func (m *PIIMapping) registerTokenAlias(alias, canonical string) { + if alias == "" || alias == canonical { + return + } + if existing, ok := m.tokenAliases[alias]; ok && existing != canonical { + return + } + m.tokenAliases[alias] = canonical + normalized := normalizePlaceholderToken(alias) + if normalized == "" { + return + } + if existing, ok := m.normalizedTokenAliases[normalized]; ok && existing != canonical { + return + } + m.normalizedTokenAliases[normalized] = canonical +} + +func placeholderTypeAlias(entityType string) string { + switch entityType { + case "EMAIL_ADDRESS": + return "EMAIL" + case "PHONE_NUMBER": + return "PHONE" + case "CREDIT_CARD": + return "CARD" + case "IBAN_CODE": + return "IBAN" + case "IP_ADDRESS": + return "IP" + case "US_SSN": + return "SSN" + case "US_DRIVER_LICENSE": + return "DRIVER_LICENSE" + default: + return entityType + } +} + +// MaskKnownOriginals replaces original values already present in the mapping +// with their placeholder tokens. It is used for diagnostics/logging paths where +// an upstream provider may echo request text back in an error body. 
+func (m *PIIMapping) MaskKnownOriginals(text string) string { + if m == nil || len(m.tokenToOriginal) == 0 || text == "" { + return text + } + + type pair struct { + token string + original string + } + pairs := make([]pair, 0, len(m.tokenToOriginal)) + for token, original := range m.tokenToOriginal { + if original == "" { + continue + } + pairs = append(pairs, pair{token: token, original: original}) + } + if len(pairs) == 0 { + return text + } + sort.SliceStable(pairs, func(i, j int) bool { + return len(pairs[i].original) > len(pairs[j].original) + }) + + out := text + for _, p := range pairs { + out = strings.ReplaceAll(out, p.original, p.token) + } + return out +} + +// ApplyMask rewrites `text` by replacing every entity span with its placeholder +// token from `mapping`. Entities that overlap or fall outside the bounds of +// text are silently skipped. The mapping is mutated to record any new tokens +// allocated during this call. +func ApplyMask(text string, entities []PIIEntity, mapping *PIIMapping) string { + if len(entities) == 0 || mapping == nil || text == "" { + return text + } + + // Sort by Start ascending so we can drop overlaps deterministically. + sorted := make([]PIIEntity, 0, len(entities)) + sorted = append(sorted, entities...) + sort.SliceStable(sorted, func(i, j int) bool { + if sorted[i].Start == sorted[j].Start { + return sorted[i].End > sorted[j].End // prefer longer span on tie + } + return sorted[i].Start < sorted[j].Start + }) + + // Drop invalid / overlapping spans. We keep the first occurrence and skip + // any later span that starts before the previous one ended. 
+ kept := sorted[:0] + lastEnd := -1 + for _, e := range sorted { + if e.Start < 0 || e.End > len(text) || e.End <= e.Start { + continue + } + if e.Start < lastEnd { + continue + } + kept = append(kept, e) + lastEnd = e.End + } + if len(kept) == 0 { + return text + } + + var b strings.Builder + b.Grow(len(text)) + cursor := 0 + for _, e := range kept { + b.WriteString(text[cursor:e.Start]) + b.WriteString(mapping.Token(e.Type, text[e.Start:e.End])) + cursor = e.End + } + b.WriteString(text[cursor:]) + return b.String() +} + +// Unmask replaces every `[TYPE_N]` placeholder found in `text` with its +// original value from `mapping`. It also tolerates common model rewrites such +// as bracket stripping (`EMAIL_ADDRESS_1`) and short aliases (`EMAIL_1`). +// Tokens that are not in the mapping (e.g. hallucinated by an upstream model) +// are passed through verbatim. +func Unmask(text string, mapping *PIIMapping) string { + if mapping == nil || mapping.Len() == 0 || text == "" { + return text + } + var b strings.Builder + b.Grow(len(text)) + bareAliases := mapping.bareAliasesByLength() + i := 0 + for i < len(text) { + if text[i] == '[' { + close := strings.IndexByte(text[i:], ']') + if close >= 0 { + end := i + close + 1 + tok := text[i:end] + if original, ok := mapping.Original(tok); ok { + b.WriteString(original) + } else { + b.WriteString(tok) + } + i = end + continue + } + // Unterminated bracket — flush rest verbatim. 
+ b.WriteString(text[i:]) + break + } + + if isBareTokenBoundary(text, i-1) { + if alias, original, ok := mapping.matchBareAlias(text[i:], bareAliases); ok && isBareTokenBoundary(text, i+len(alias)) { + b.WriteString(original) + i += len(alias) + continue + } + } + + b.WriteByte(text[i]) + i++ + } + return b.String() +} + +func (m *PIIMapping) bareAliasesByLength() []string { + aliases := make([]string, 0, len(m.tokenAliases)+len(m.tokenToOriginal)) + for token := range m.tokenToOriginal { + aliases = append(aliases, strings.Trim(token, "[]")) + } + for alias := range m.tokenAliases { + if !strings.HasPrefix(alias, "[") { + aliases = append(aliases, alias) + } + } + sort.SliceStable(aliases, func(i, j int) bool { + if len(aliases[i]) == len(aliases[j]) { + return aliases[i] < aliases[j] + } + return len(aliases[i]) > len(aliases[j]) + }) + return aliases +} + +// BareTokenAliases returns tolerated bracketless placeholder forms sorted from +// longest to shortest. It is used by streaming restorers to avoid emitting a +// partial alias before the next chunk arrives. 
+func (m *PIIMapping) BareTokenAliases() []string { + if m == nil || m.Len() == 0 { + return nil + } + return m.bareAliasesByLength() +} + +func (m *PIIMapping) matchBareAlias(text string, aliases []string) (string, string, bool) { + for _, alias := range aliases { + if len(text) < len(alias) || !strings.EqualFold(text[:len(alias)], alias) { + continue + } + if original, ok := m.Original(alias); ok { + return alias, original, true + } + } + return "", "", false +} + +func isBareTokenBoundary(text string, index int) bool { + if index < 0 || index >= len(text) { + return true + } + c := text[index] + return !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_') +} + +var placeholderCandidateRE = regexp.MustCompile(`\[[A-Za-z][A-Za-z0-9_ -]*[ _-][0-9]+\]|\b[A-Z][A-Z0-9_]*_[0-9]+\b`) + +// UnresolvedTokens returns placeholder-looking tokens that remain after an +// attempted restore and whose entity type matches this request's PII mapping. +// The tokens themselves are synthetic and safe to log; original values are not +// returned. 
+func (m *PIIMapping) UnresolvedTokens(text string) []string { + if m == nil || m.Len() == 0 || text == "" { + return nil + } + knownTypes := m.knownPlaceholderTypes() + if len(knownTypes) == 0 { + return nil + } + + seen := make(map[string]struct{}) + out := make([]string, 0) + for _, token := range placeholderCandidateRE.FindAllString(text, -1) { + if _, ok := m.Original(token); ok { + continue + } + typ := placeholderTokenType(token) + if _, ok := knownTypes[typ]; !ok { + continue + } + if _, ok := seen[token]; ok { + continue + } + seen[token] = struct{}{} + out = append(out, token) + } + sort.Strings(out) + return out +} + +func (m *PIIMapping) knownPlaceholderTypes() map[string]struct{} { + types := make(map[string]struct{}) + for token := range m.tokenToOriginal { + if typ := placeholderTokenType(token); typ != "" { + types[typ] = struct{}{} + } + } + for alias := range m.tokenAliases { + if typ := placeholderTokenType(alias); typ != "" { + types[typ] = struct{}{} + } + } + return types +} + +func placeholderTokenType(token string) string { + token = normalizePlaceholderToken(token) + idx := strings.LastIndexByte(token, '_') + if idx <= 0 || idx == len(token)-1 { + return "" + } + return token[:idx] +} + +func normalizePlaceholderToken(token string) string { + token = strings.TrimSpace(token) + token = strings.TrimPrefix(token, "[") + token = strings.TrimSuffix(token, "]") + token = strings.TrimSpace(token) + if token == "" { + return "" + } + + var b strings.Builder + b.Grow(len(token)) + lastWasSep := false + for i := 0; i < len(token); i++ { + c := token[i] + switch { + case c >= 'a' && c <= 'z': + b.WriteByte(c - 'a' + 'A') + lastWasSep = false + case (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'): + b.WriteByte(c) + lastWasSep = false + case c == '_' || c == '-' || c == ' ': + if b.Len() > 0 && !lastWasSep { + b.WriteByte('_') + lastWasSep = true + } + default: + return "" + } + } + out := strings.Trim(b.String(), "_") + if out == "" { + return "" + } 
+ return out +} diff --git a/pkg/domain/pii_test.go b/pkg/domain/pii_test.go new file mode 100644 index 0000000..9b78747 --- /dev/null +++ b/pkg/domain/pii_test.go @@ -0,0 +1,177 @@ +package domain_test + +import ( + "strings" + "testing" + + "github.com/dappnode/dappnode-nexus-gateway/pkg/domain" +) + +func TestPIIMapping_DeterministicNumberingAndDedup(t *testing.T) { + m := domain.NewPIIMapping() + if got := m.Token("PERSON", "John Smith"); got != "[PERSON_1]" { + t.Fatalf("first PERSON token = %q, want [PERSON_1]", got) + } + if got := m.Token("PERSON", "John Smith"); got != "[PERSON_1]" { + t.Fatalf("repeated PERSON token = %q, want [PERSON_1]", got) + } + if got := m.Token("PERSON", "Jane Doe"); got != "[PERSON_2]" { + t.Fatalf("second PERSON token = %q, want [PERSON_2]", got) + } + if got := m.Token("EMAIL", "a@b.com"); got != "[EMAIL_1]" { + t.Fatalf("first EMAIL token = %q, want [EMAIL_1]", got) + } + if m.Len() != 3 { + t.Fatalf("mapping size = %d, want 3", m.Len()) + } + if v, ok := m.Original("[PERSON_2]"); !ok || v != "Jane Doe" { + t.Fatalf("Original([PERSON_2]) = %q,%v want Jane Doe,true", v, ok) + } +} + +func TestApplyMask_BasicAndDedup(t *testing.T) { + text := "Hi John, email John at john@x.com or +1-415-555-0124." + entities := []domain.PIIEntity{ + {Type: "PERSON", Start: 3, End: 7, Score: 0.99}, // John + {Type: "PERSON", Start: 15, End: 19, Score: 0.99}, // John + {Type: "EMAIL", Start: 23, End: 33, Score: 0.99}, // john@x.com + {Type: "PHONE", Start: 37, End: len(text) - 1, Score: 0.99}, // +1-415-555-0124 + } + m := domain.NewPIIMapping() + got := domain.ApplyMask(text, entities, m) + want := "Hi [PERSON_1], email [PERSON_1] at [EMAIL_1] or [PHONE_1]." 
+ if got != want { + t.Fatalf("ApplyMask:\n got = %q\nwant = %q", got, want) + } + if m.Len() != 3 { + t.Fatalf("mapping size = %d, want 3 (PERSON_1 dedup)", m.Len()) + } +} + +func TestApplyMask_DropsOverlapsAndInvalidSpans(t *testing.T) { + text := "hello world" + entities := []domain.PIIEntity{ + {Type: "A", Start: 0, End: 5}, // "hello" + {Type: "B", Start: 3, End: 8}, // overlaps with A — drop + {Type: "C", Start: -1, End: 2}, // invalid — drop + {Type: "D", Start: 6, End: 100}, // out of range — drop + {Type: "E", Start: 6, End: 11}, // "world" + } + got := domain.ApplyMask(text, entities, domain.NewPIIMapping()) + want := "[A_1] [E_1]" + if got != want { + t.Fatalf("ApplyMask:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnmask_RestoresOriginalAndKeepsUnknownTokens(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + m.Token("EMAIL", "john@x.com") + got := domain.Unmask("hi [PERSON_1], your email [EMAIL_1] (and [GHOST_5]).", m) + want := "hi John, your email john@x.com (and [GHOST_5])." 
+ if got != want { + t.Fatalf("Unmask:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnmask_RestoresBracketlessAndShortAliases(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + m.Token("PHONE_NUMBER", "+1-415-555-0100") + + got := domain.Unmask("email=EMAIL_ADDRESS_1 alt=[EMAIL_1] phone=PHONE_1", m) + want := "email=jane@example.com alt=jane@example.com phone=+1-415-555-0100" + if got != want { + t.Fatalf("Unmask aliases:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnmask_RestoresSafePlaceholderRewrites(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + m.Token("PERSON", "Jane Doe") + + got := domain.Unmask("email=[Email Address 1] short=[email-1] person=[Person_1] bare=EMAIL_address_1", m) + want := "email=jane@example.com short=jane@example.com person=Jane Doe bare=jane@example.com" + if got != want { + t.Fatalf("Unmask rewrites:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnmask_DoesNotRestoreBareNaturalLanguagePhrase(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + + got := domain.Unmask("leave bare words email address 1 alone, restore [email address 1]", m) + want := "leave bare words email address 1 alone, restore jane@example.com" + if got != want { + t.Fatalf("Unmask natural phrase:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnmask_DoesNotRestoreBareAliasInsideLongerIdentifier(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + + got := domain.Unmask("keep XEMAIL_ADDRESS_1Y but restore EMAIL_ADDRESS_1", m) + want := "keep XEMAIL_ADDRESS_1Y but restore jane@example.com" + if got != want { + t.Fatalf("Unmask boundary handling:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnresolvedTokens_FindsKnownPIIPlaceholderTypesOnly(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("EMAIL_ADDRESS", "jane@example.com") + m.Token("PERSON", 
"Jane") + + got := m.UnresolvedTokens(`{"email":"EMAIL_ADDRESS_2","person":"[Person 9]","id":"ORDER_ID_1"}`) + want := []string{"EMAIL_ADDRESS_2", "[Person 9]"} + if strings.Join(got, ",") != strings.Join(want, ",") { + t.Fatalf("UnresolvedTokens = %#v, want %#v", got, want) + } +} + +func TestPIIMapping_MaskKnownOriginals(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + m.Token("EMAIL", "john@x.com") + + got := m.MaskKnownOriginals("provider echoed John ") + want := "provider echoed [PERSON_1] <[EMAIL_1]>" + if got != want { + t.Fatalf("MaskKnownOriginals:\n got = %q\nwant = %q", got, want) + } +} + +func TestUnmask_HandlesUnterminatedAndEmpty(t *testing.T) { + m := domain.NewPIIMapping() + m.Token("PERSON", "John") + if got := domain.Unmask("", m); got != "" { + t.Fatalf("Unmask(\"\") = %q, want empty", got) + } + if got := domain.Unmask("no tokens here", m); got != "no tokens here" { + t.Fatalf("Unmask passthrough = %q", got) + } + if got := domain.Unmask("hi [PERSON_1 partial", m); got != "hi [PERSON_1 partial" { + t.Fatalf("Unmask unterminated = %q", got) + } +} + +func TestApplyMask_RoundTripWithUnmask(t *testing.T) { + text := "Contact Alice at alice@example.com about the meeting with Alice." 
+ entities := []domain.PIIEntity{ + {Type: "PERSON", Start: 8, End: 13}, // Alice + {Type: "EMAIL", Start: 17, End: 36}, // alice@example.com + {Type: "PERSON", Start: 58, End: 63}, // Alice + } + m := domain.NewPIIMapping() + masked := domain.ApplyMask(text, entities, m) + if strings.Contains(masked, "Alice") || strings.Contains(masked, "alice@example.com") { + t.Fatalf("masked text still contains PII: %q", masked) + } + if got := domain.Unmask(masked, m); got != text { + t.Fatalf("round trip failed:\n got = %q\nwant = %q", got, text) + } +} diff --git a/pkg/domain/pricing.go b/pkg/domain/pricing.go new file mode 100644 index 0000000..a03a20d --- /dev/null +++ b/pkg/domain/pricing.go @@ -0,0 +1,14 @@ +package domain + +import "time" + +// ModelPricing represents pricing for a provider model at a point in time. +type ModelPricing struct { + ID string + ProviderModelID string + InputPricePer1MTokensMicrocents int64 + OutputPricePer1MTokensMicrocents int64 + CacheReadPricePer1MTokensMicrocents *int64 + CacheWritePricePer1MTokensMicrocents *int64 + EffectiveFrom time.Time +} diff --git a/pkg/domain/promo.go b/pkg/domain/promo.go new file mode 100644 index 0000000..51d143e --- /dev/null +++ b/pkg/domain/promo.go @@ -0,0 +1,127 @@ +package domain + +import ( + "regexp" + "strings" + "time" +) + +// Promo code configuration constants. +const ( + PromoCodeMinLength = 4 + PromoCodeMaxLength = 64 + + // NewUserGraceWindow is how long after account creation a "new users only" + // promo can still be redeemed. The auto-redeem path fires immediately after + // Authgear signup (account is auto-provisioned moments before), but a slack + // window keeps the experience resilient to clock skew, slow UI loads, and + // manual "redeem on Billing" attempts by a brand-new user. + NewUserGraceWindow = 24 * time.Hour +) + +// promoCodePattern matches normalized promo codes: uppercase letters, digits, +// underscores and hyphens. 
+var promoCodePattern = regexp.MustCompile(`^[A-Z0-9_-]+$`) + +// PromoCode represents an operator-issued code that grants prepaid credit. +type PromoCode struct { + ID string + Code string + AmountCents int64 + Currency string + MaxRedemptions *int + NewUsersOnly bool + Active bool + ExpiresAt *time.Time + Description *string + CreatedAt time.Time + UpdatedAt time.Time + // RedemptionCount is the number of redemptions recorded for this code. + // Populated only by list/get queries that aggregate redemptions; zero elsewhere. + RedemptionCount int +} + +// RemainingRedemptions reports how many more redemptions are allowed, or nil +// when MaxRedemptions is unset (unlimited). +func (p PromoCode) RemainingRedemptions() *int { + if p.MaxRedemptions == nil { + return nil + } + remaining := *p.MaxRedemptions - p.RedemptionCount + if remaining < 0 { + remaining = 0 + } + return &remaining +} + +// IsExpiredAt reports whether the promo code is past its expiry at the given time. +func (p PromoCode) IsExpiredAt(now time.Time) bool { + return p.ExpiresAt != nil && !now.Before(*p.ExpiresAt) +} + +// PromoRedemption records a single successful redemption of a promo code. +type PromoRedemption struct { + ID string + PromoCodeID string + AccountID string + AccountName *string + AmountMicrocents int64 + CreatedAt time.Time +} + +// RedeemResult is returned by the redemption flow. Applied is false when the +// account had already redeemed this code (idempotent success) — in that case +// Redemption holds the original redemption row. +type RedeemResult struct { + Applied bool + Redemption PromoRedemption + PromoCode PromoCode +} + +// NormalizePromoCode uppercases and trims the input and validates its shape. +// Returns the normalized code and whether it is valid. 
+func NormalizePromoCode(raw string) (string, bool) { + code := strings.ToUpper(strings.TrimSpace(raw)) + if len(code) < PromoCodeMinLength || len(code) > PromoCodeMaxLength { + return "", false + } + if !promoCodePattern.MatchString(code) { + return "", false + } + return code, true +} + +// IsNewUserAt reports whether accountCreatedAt falls inside the new-user grace +// window ending at now. Used to enforce PromoCode.NewUsersOnly. +func IsNewUserAt(accountCreatedAt, now time.Time) bool { + return now.Before(accountCreatedAt.Add(NewUserGraceWindow)) +} + +// Promo error constructors. These map to the control-plane error codes. +// 410 Gone signals "this code no longer accepts redemptions" (deactivated or +// expired); 409 Conflict signals "you already redeemed this code" or "you are +// not eligible"; 422 signals the cap was hit. + +func ErrPromoNotFound(code string) *GatewayError { + return &GatewayError{HTTPStatus: 404, Type: ErrTypeInvalidRequest, Code: ErrCodeNotFound, Message: "promo code '" + code + "' not found"} +} + +func ErrPromoInactive(code string) *GatewayError { + return &GatewayError{HTTPStatus: 410, Type: ErrTypeInvalidRequest, Code: ErrCodeConflict, Message: "promo code '" + code + "' is no longer active"} +} + +func ErrPromoExpired(code string) *GatewayError { + return &GatewayError{HTTPStatus: 410, Type: ErrTypeInvalidRequest, Code: ErrCodeConflict, Message: "promo code '" + code + "' has expired"} +} + +func ErrPromoMaxRedemptionsReached(code string) *GatewayError { + return &GatewayError{HTTPStatus: 422, Type: ErrTypeInvalidRequest, Code: ErrCodeConflict, Message: "promo code '" + code + "' has reached its maximum redemptions"} +} + +func ErrPromoNotEligibleNewUsersOnly(code string) *GatewayError { + return &GatewayError{HTTPStatus: 409, Type: ErrTypeInvalidRequest, Code: ErrCodeConflict, Message: "promo code '" + code + "' is only available to new users"} +} + +func ErrPromoAlreadyRedeemed(code string) *GatewayError { + return 
&GatewayError{HTTPStatus: 409, Type: ErrTypeInvalidRequest, Code: ErrCodeConflict, Message: "promo code '" + code + "' has already been redeemed"} +} diff --git a/pkg/domain/promo_test.go b/pkg/domain/promo_test.go new file mode 100644 index 0000000..95f5692 --- /dev/null +++ b/pkg/domain/promo_test.go @@ -0,0 +1,95 @@ +package domain + +import ( + "testing" + "time" +) + +func TestNormalizePromoCode(t *testing.T) { + cases := []struct { + in string + want string + wantOk bool + }{ + {"WELCOME5", "WELCOME5", true}, + {" welcome5 ", "WELCOME5", true}, + {"Welcome-5", "WELCOME-5", true}, + {"BONUS_CODE-3", "BONUS_CODE-3", true}, + {"ABC", "", false}, // too short + {"", "", false}, // empty + {"WITH SPACE", "", false}, // space not allowed + {"BAD!CODE", "", false}, // invalid char + {"LOWER1234", "LOWER1234", true}, + {"a_b-c1234", "A_B-C1234", true}, + } + for _, c := range cases { + got, ok := NormalizePromoCode(c.in) + if got != c.want || ok != c.wantOk { + t.Errorf("NormalizePromoCode(%q) = (%q,%v), want (%q,%v)", c.in, got, ok, c.want, c.wantOk) + } + } +} + +func TestNormalizePromoCode_TooLong(t *testing.T) { + long := "" + for i := 0; i < PromoCodeMaxLength+1; i++ { + long += "A" + } + if _, ok := NormalizePromoCode(long); ok { + t.Errorf("NormalizePromoCode accepted a code longer than %d chars", PromoCodeMaxLength) + } +} + +func TestPromoCodeRemainingRedemptions(t *testing.T) { + // unlimited + unlimited := PromoCode{MaxRedemptions: nil, RedemptionCount: 5} + if r := unlimited.RemainingRedemptions(); r != nil { + t.Errorf("unlimited code returned remaining %v, want nil", *r) + } + + // 100 / 5 used → 95 + max := 100 + p := PromoCode{MaxRedemptions: &max, RedemptionCount: 5} + r := p.RemainingRedemptions() + if r == nil || *r != 95 { + t.Errorf("remaining = %v, want 95", r) + } + + // exhausted clamps to 0 + exhausted := PromoCode{MaxRedemptions: &max, RedemptionCount: 150} + r = exhausted.RemainingRedemptions() + if r == nil || *r != 0 { + 
t.Errorf("remaining = %v, want 0", r) + } +} + +func TestPromoCodeIsExpiredAt(t *testing.T) { + expiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + p := PromoCode{ExpiresAt: &expiry} + if p.IsExpiredAt(expiry) != true { + t.Error("code should be expired at expiry time") + } + if p.IsExpiredAt(expiry.Add(-time.Second)) != false { + t.Error("code should not be expired one second before expiry") + } + noExpiry := PromoCode{} + if noExpiry.IsExpiredAt(time.Now()) != false { + t.Error("code with nil expiry should never be expired") + } +} + +func TestIsNewUserAt(t *testing.T) { + created := time.Date(2026, 6, 22, 12, 0, 0, 0, time.UTC) + // Within window + if !IsNewUserAt(created, created.Add(time.Hour)) { + t.Error("1h after creation should be new user") + } + // Exactly at window boundary → no longer new (exclusive) + if IsNewUserAt(created, created.Add(NewUserGraceWindow)) { + t.Error("at grace boundary should not be new user") + } + // Well after window + if IsNewUserAt(created, created.Add(48*time.Hour)) { + t.Error("48h after creation should not be new user") + } +} diff --git a/pkg/domain/provider_catalog.go b/pkg/domain/provider_catalog.go new file mode 100644 index 0000000..bd4cf27 --- /dev/null +++ b/pkg/domain/provider_catalog.go @@ -0,0 +1,16 @@ +package domain + +// ProviderCatalogEntry is one model offered by a provider, as discovered from +// the provider's own model-listing API. Prices are the provider's raw rate in +// USD per 1M tokens, before any FX conversion or markup. 
+type ProviderCatalogEntry struct { + ProviderModelName string + Title string + Description string + ContextSize int64 + MaxOutputTokens int64 + InputUSDPer1M float64 + OutputUSDPer1M float64 + SupportsTools bool + SupportsReasoning bool +} diff --git a/pkg/domain/provider_model.go b/pkg/domain/provider_model.go new file mode 100644 index 0000000..18b64f3 --- /dev/null +++ b/pkg/domain/provider_model.go @@ -0,0 +1,12 @@ +package domain + +import "time" + +// ProviderModel represents an upstream provider's model configuration. +type ProviderModel struct { + ID string + ProviderName string + ProviderModelName string + Active bool + CreatedAt time.Time +} diff --git a/pkg/domain/public_model_cp.go b/pkg/domain/public_model_cp.go new file mode 100644 index 0000000..266a70b --- /dev/null +++ b/pkg/domain/public_model_cp.go @@ -0,0 +1,29 @@ +package domain + +import "time" + +// PublicModelEntry represents a public model as managed by the control plane. +type PublicModelEntry struct { + ID string + DisplayName string + ProviderModelID string + Active bool + Description *string + MaxContextWindow int + MaxOutputTokens int + SupportsChatCompletions bool + SupportsChatCompletionsStream bool + SupportsTools bool + SupportsParallelToolCalls bool + SupportsStructuredOutput bool + SupportsReasoning bool + ProofMode string + CreatedAt time.Time +} + +func (m PublicModelEntry) EffectiveProofMode() string { + if m.ProofMode != "" && m.ProofMode != ProofModeNone { + return m.ProofMode + } + return ProofModeNone +} diff --git a/pkg/domain/router.go b/pkg/domain/router.go new file mode 100644 index 0000000..aa4b499 --- /dev/null +++ b/pkg/domain/router.go @@ -0,0 +1,94 @@ +package domain + +import ( + "encoding/json" + "time" +) + +// CatalogKind identifies the kind of user-facing model catalog entry. 
+type CatalogKind string + +const ( + CatalogKindPublicModel CatalogKind = "public_model" + CatalogKindRouter CatalogKind = "router" +) + +// RouterStrategyType identifies how a router chooses a concrete public model. +type RouterStrategyType string + +const ( + RouterStrategyFallback RouterStrategyType = "fallback" + RouterStrategyEmbedding RouterStrategyType = "embedding" +) + +// RouteRequest is the gateway-to-router request contract. +// +// The gateway passes the full canonical generation request along with the +// router_id. The router service is a black box: it loads its strategy from +// the DB and extracts whatever features that strategy needs from the request +// internally. When a new strategy needs new features, only the router +// service changes — the gateway stays untouched. +type RouteRequest struct { + RouterID string + Request GenerateRequest +} + +// RouteDecision is the router service output contract. +type RouteDecision struct { + PublicModelID string + Category *string + Score *float32 + CategoryScores []RoutingCategoryScore + FallbackUsed bool + Reason string +} + +// RoutingCategoryScore is the per-category embedding score snapshot returned +// by the router and persisted with routed usage events. +type RoutingCategoryScore struct { + Category string `json:"category"` + Score float32 `json:"score"` + Threshold float32 `json:"threshold"` + PassedThreshold bool `json:"passed_threshold"` + Selected bool `json:"selected"` +} + +// RouterEntry is a user-facing router configuration managed by the control plane. +type RouterEntry struct { + ID string + RouterID string + DisplayName string + Description *string + FallbackPublicModelID string + StrategyType RouterStrategyType + StrategyConfig json.RawMessage + Active bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// ModelCatalogEntry is returned by catalog APIs that mix concrete public models and routers. 
+type ModelCatalogEntry struct { + Kind CatalogKind + PublicModel *PublicModel + Router *RouterEntry + EURToUSDRate float64 +} + +// RouterCategory is a strategy-specific routing target managed by the router +// service. The control plane treats categories as opaque metadata it can +// proxy on behalf of the admin UI. +type RouterCategory struct { + ID string + RouterID string + Name string + PublicModelID string + Threshold float32 +} + +// RouterCategoryInput is the create/update payload for a router category. +type RouterCategoryInput struct { + Name string + PublicModelID string + Threshold float32 +} diff --git a/pkg/domain/stream.go b/pkg/domain/stream.go new file mode 100644 index 0000000..062fc7c --- /dev/null +++ b/pkg/domain/stream.go @@ -0,0 +1,24 @@ +package domain + +// StreamEventType constants. +const ( + StreamEventOutputTextDelta = "output_text_delta" + StreamEventOutputMessageDelta = "output_message_delta" + StreamEventToolCallDelta = "tool_call_delta" + StreamEventCompleted = "completed" + StreamEventError = "error" +) + +// StreamEvent is a canonical stream event normalized from provider stream formats. +type StreamEvent struct { + Type string // output_text_delta, output_message_delta, tool_call_delta, completed, error + ProviderResponseID string + ChoiceIndex *int + Role *string + ContentDelta *string + ReasoningDelta *string + ToolCallDelta *ToolCallDelta + FinishReason *string + Usage *Usage + Error *GatewayError +} diff --git a/pkg/domain/tee.go b/pkg/domain/tee.go new file mode 100644 index 0000000..cbfb084 --- /dev/null +++ b/pkg/domain/tee.go @@ -0,0 +1,56 @@ +package domain + +import ( + "encoding/json" + "time" +) + +const ( + ProofStatusVerified = "verified" + ProofStatusFailed = "failed" + ProofStatusNotAvailable = "not_available" + ProofStatusUnsupported = "unsupported" +) + +// TinfoilTransportProof is the permanent, safe proof record for one response +// generated over Tinfoil's attested encrypted transport. 
It contains no raw +// prompt, decrypted request body, raw response, or decrypted response body. +type TinfoilTransportProof struct { + ID string + AccountID string + APIKeyID string + Provider string + PublicModelID string + UpstreamModelID string + ProviderResponseID string + EnclaveHost *string + ConfigRepo *string + Digest *string + CodeFingerprint *string + EnclaveFingerprint *string + TLSPublicKey *string + HPKEPublicKey *string + TransportMode *string + SDKVersion *string + Status string + FailureReason *string + VerificationEvidenceJSON json.RawMessage + CreatedAt time.Time + VerifiedAt *time.Time +} + +// TinfoilProofListParams describes a user-facing proof history query. +type TinfoilProofListParams struct { + Offset int + Limit int + Status string + Query string +} + +// TinfoilTransportProofRecord enriches a proof with safe API key context for +// dashboard history views. It never contains raw API key material. +type TinfoilTransportProofRecord struct { + Proof TinfoilTransportProof + APIKeyName *string + APIKeyPrefix *string +} diff --git a/pkg/domain/tool.go b/pkg/domain/tool.go new file mode 100644 index 0000000..d5ca835 --- /dev/null +++ b/pkg/domain/tool.go @@ -0,0 +1,38 @@ +package domain + +// ToolDefinition describes a function tool available to the model. +type ToolDefinition struct { + Name string + Description string + Parameters map[string]any + Strict bool +} + +// ToolCall represents a tool invocation returned by the model. +type ToolCall struct { + ID string + Name string + ArgumentsJSON string +} + +// ToolCallDelta represents a partial tool call in a stream event. +type ToolCallDelta struct { + Index int + ID *string + Name *string + ArgumentsDelta *string +} + +// ToolChoiceMode constants. +const ( + ToolChoiceNone = "none" + ToolChoiceAuto = "auto" + ToolChoiceRequired = "required" + ToolChoiceFunction = "function" +) + +// ToolChoice represents the tool_choice parameter. 
+type ToolChoice struct { + Mode string // "none", "auto", "required", "function" + FunctionName *string // set only when Mode == "function" +} diff --git a/pkg/domain/usage.go b/pkg/domain/usage.go new file mode 100644 index 0000000..021a152 --- /dev/null +++ b/pkg/domain/usage.go @@ -0,0 +1,10 @@ +package domain + +// Usage holds token usage information for a generation. +type Usage struct { + PromptTokens int64 + CompletionTokens int64 + TotalTokens int64 + CacheCreationTokens int64 + CacheReadTokens int64 +} diff --git a/pkg/domain/usage_query.go b/pkg/domain/usage_query.go new file mode 100644 index 0000000..bbe2a84 --- /dev/null +++ b/pkg/domain/usage_query.go @@ -0,0 +1,80 @@ +package domain + +import "time" + +// UsageEvent represents a single usage event as stored in the database. +type UsageEvent struct { + ID string + RequestID *string + AccountID string + APIKeyID string + PublicModelID *string + ProviderModelID *string + RouterID *string + RoutedPublicModelID *string + PublicModelName *string + ProviderName *string + ProviderRequestID *string + InputTokens *int64 + OutputTokens *int64 + CacheCreationTokens *int64 + CacheReadTokens *int64 + InputPricePer1MTokensMicrocents *int64 + OutputPricePer1MTokensMicrocents *int64 + CacheReadPricePer1MTokensMicrocents *int64 + CacheWritePricePer1MTokensMicrocents *int64 + CostMicrocents *int64 + ChargedMonthlyMicrocents *int64 + ChargedPrepaidMicrocents *int64 + LatencyMs *int64 + CreatedAt time.Time + // AccountEmail is populated by list queries that JOIN accounts; nil elsewhere. + AccountEmail *string +} + +// UsageSummary holds aggregated usage data for an account or globally. +type UsageSummary struct { + TotalRequests int64 + TotalInputTokens int64 + TotalOutputTokens int64 + TotalCostMicrocents int64 + AvgLatencyMs float64 + From time.Time + To time.Time +} + +// UsageTimeBucket holds usage data aggregated for a single time bucket. 
+type UsageTimeBucket struct { + Bucket time.Time + TotalRequests int64 + TotalInputTokens int64 + TotalOutputTokens int64 + TotalCostMicrocents int64 +} + +// UsageEventRoutingDetails is the minimal projection used to surface the +// router decision behind a single usage event. RoutingScore is snapshotted +// at decision time. Threshold is the current configured threshold of the +// matched category (looked up at read time, not snapshotted), and is nil +// when no category matched or the category was since deleted. +type UsageEventRoutingDetails struct { + EventID string + RouterID *string + MatchedCategory *string + RoutingScore *float32 + CategoryScores []RoutingCategoryScore + MatchedThreshold *float32 + DecisionReason *string + FallbackUsed *bool +} + +// UsageByDimension holds usage data aggregated by an arbitrary grouping key +// (e.g. model name, provider name, or account ID). +type UsageByDimension struct { + Key string + Label string + TotalRequests int64 + TotalInputTokens int64 + TotalOutputTokens int64 + TotalCostMicrocents int64 +} diff --git a/pkg/fx/frankfurter.go b/pkg/fx/frankfurter.go new file mode 100644 index 0000000..7dc9036 --- /dev/null +++ b/pkg/fx/frankfurter.go @@ -0,0 +1,117 @@ +// Package fx provides cached exchange-rate clients. +package fx + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" +) + +const frankfurterURL = "https://api.frankfurter.app/latest" + +// CacheTTL bounds how stale a cached rate may be. ECB publishes once per +// business day, so an hour is comfortably fresh while sparing the API. +const CacheTTL = time.Hour + +// Frankfurter fetches and caches exchange rates from the Frankfurter API. 
+type Frankfurter struct { + httpClient *http.Client + + mu sync.Mutex + cached map[rateKey]cachedRate +} + +type rateKey struct { + from string + to string +} + +type cachedRate struct { + rate float64 + fetchedAt time.Time +} + +// NewFrankfurter builds an FX rate provider with the given request timeout. +func NewFrankfurter(timeout time.Duration) *Frankfurter { + return &Frankfurter{ + httpClient: &http.Client{Timeout: timeout}, + cached: make(map[rateKey]cachedRate), + } +} + +// USDToEUR returns how many euros one US dollar is worth. +func (f *Frankfurter) USDToEUR(ctx context.Context) (float64, error) { + return f.Rate(ctx, "USD", "EUR") +} + +// EURToUSD returns how many US dollars one euro is worth. +func (f *Frankfurter) EURToUSD(ctx context.Context) (float64, error) { + return f.Rate(ctx, "EUR", "USD") +} + +// Rate returns the exchange rate from one currency to another. +func (f *Frankfurter) Rate(ctx context.Context, from, to string) (float64, error) { + key := rateKey{from: from, to: to} + + f.mu.Lock() + defer f.mu.Unlock() + + if cached, ok := f.cached[key]; ok && cached.rate > 0 && time.Since(cached.fetchedAt) < CacheTTL { + return cached.rate, nil + } + + rate, err := f.fetch(ctx, from, to) + if err != nil { + // Fall back to a stale rate rather than failing callers that can + // tolerate a slightly old conversion. 
+ if cached, ok := f.cached[key]; ok && cached.rate > 0 { + return cached.rate, nil + } + return 0, err + } + + f.cached[key] = cachedRate{rate: rate, fetchedAt: time.Now()} + return rate, nil +} + +func (f *Frankfurter) fetch(ctx context.Context, from, to string) (float64, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, frankfurterURL, nil) + if err != nil { + return 0, fmt.Errorf("fx: build request: %w", err) + } + q := req.URL.Query() + q.Set("from", from) + q.Set("to", to) + req.URL.RawQuery = q.Encode() + req.Header.Set("Accept", "application/json") + + resp, err := f.httpClient.Do(req) + if err != nil { + return 0, fmt.Errorf("fx: request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) + if err != nil { + return 0, fmt.Errorf("fx: read response: %w", err) + } + if resp.StatusCode >= 400 { + return 0, fmt.Errorf("fx: unexpected status %d", resp.StatusCode) + } + + var parsed struct { + Rates map[string]float64 `json:"rates"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return 0, fmt.Errorf("fx: decode response: %w", err) + } + rate := parsed.Rates[to] + if rate <= 0 { + return 0, fmt.Errorf("fx: missing or invalid %s rate", to) + } + return rate, nil +} diff --git a/pkg/observability/logger/zap.go b/pkg/observability/logger/zap.go new file mode 100644 index 0000000..ac2a60b --- /dev/null +++ b/pkg/observability/logger/zap.go @@ -0,0 +1,76 @@ +package logger + +import ( + "fmt" + "strings" + + "go.uber.org/zap" +) + +// ZapLogger wraps zap.SugaredLogger to implement the Logger port. 
+type ZapLogger struct { + sugar *zap.SugaredLogger + verbose bool +} + +func NewZapLogger(level string) (*ZapLogger, error) { + cfg := zap.NewProductionConfig() + cfg.Level = zap.NewAtomicLevel() + verbose := false + + switch level { + case "debug": + cfg.Level.SetLevel(zap.DebugLevel) + verbose = true + case "info": + cfg.Level.SetLevel(zap.InfoLevel) + case "warn": + cfg.Level.SetLevel(zap.WarnLevel) + case "error": + cfg.Level.SetLevel(zap.ErrorLevel) + default: + cfg.Level.SetLevel(zap.InfoLevel) + } + + l, err := cfg.Build(zap.AddCallerSkip(1)) + if err != nil { + return nil, err + } + + return &ZapLogger{ + sugar: l.Sugar(), + verbose: verbose, + }, nil +} + +func (l *ZapLogger) Debug(msg string, fields ...any) { + l.sugar.Debugw(msg, fields...) +} + +func (l *ZapLogger) Info(msg string, fields ...any) { + l.sugar.Infow(msg, fields...) +} + +func (l *ZapLogger) Warn(msg string, fields ...any) { + l.sugar.Warnw(msg, fields...) +} + +func (l *ZapLogger) Error(msg string, fields ...any) { + l.sugar.Errorw(msg, fields...) +} + +func (l *ZapLogger) Printf(format string, args ...any) { + msg := strings.TrimSpace(fmt.Sprintf(format, args...)) + if msg == "" { + return + } + l.sugar.Info(msg) +} + +func (l *ZapLogger) Verbose() bool { + return l.verbose +} + +func (l *ZapLogger) Sync() { + l.sugar.Sync() +}