From 91c956ebab1a3f70f5f50a44df2d0d48946d2584 Mon Sep 17 00:00:00 2001 From: Mallory Hill Date: Wed, 24 Jun 2026 17:14:54 -0400 Subject: [PATCH] HYPERFLEET-1259 - fix: SQL injection protection and allowlist for order --- CHANGELOG.md | 1 + pkg/db/sql_helpers.go | 93 ++++--- pkg/db/sql_helpers_test.go | 174 ++++++++++++ pkg/services/generic.go | 6 +- test/integration/clusters_test.go | 2 +- test/integration/order_field_mapping_test.go | 277 +++++++++++++++++++ 6 files changed, 515 insertions(+), 38 deletions(-) create mode 100644 test/integration/order_field_mapping_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e10eb730..2b09e2f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Restricted `order` query parameter to only allow specific whitelisted fields ([#244](https://github.com/openshift-hyperfleet/hyperfleet-api/pull/244)) - Validated adapter status conditions in handler layer ([#88](https://github.com/openshift-hyperfleet/hyperfleet-api/pull/88)) - Removed org prefix from image.repository default ([#86](https://github.com/openshift-hyperfleet/hyperfleet-api/pull/86)) - Addressed revive linter violations from enabled linting standard ([#85](https://github.com/openshift-hyperfleet/hyperfleet-api/pull/85)) diff --git a/pkg/db/sql_helpers.go b/pkg/db/sql_helpers.go index b1e0bdd4..2c52ef6e 100755 --- a/pkg/db/sql_helpers.go +++ b/pkg/db/sql_helpers.go @@ -553,49 +553,74 @@ func IdentWalk(n *tsl.Node, check func(string) (string, error)) (*tsl.Node, erro } } -// cleanOrderBy takes the orderBy arg and cleans it. -func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) { - var orderField string +// orderAllowedFields defines the whitelist of fields that are allowed to be ordered. +// This prevents SQL injection and restricts invalid order queries. +var orderAllowedFields = map[string]bool{ + "id": true, + "name": true, + "created_time": true, + "updated_time": true, + "deleted_time": true, + "kind": true, + "created_by": true, + "updated_by": true, + "deleted_by": true, + "generation": true, + "href": true, +} - trimedName := strings.Trim(userArg, " ") +// orderPattern matches valid order syntax: field name (letters, digits, underscore) followed by optional asc/desc. +// This regex rejects SQL injection attempts (semicolons, parentheses, dashes, comments, etc). +var orderPattern = regexp.MustCompile(`^[a-z_][a-z_]*(\s+(asc|desc))?$`) + +// ArgsToOrder validates and cleans order arguments against the allowed fields whitelist. +// Returns a cleaned list of order clauses in the format ["field direction", ...] +// Empty or whitespace-only strings are silently skipped. +func ArgsToOrder(args []string) (cleanedOrderList []string, err *errors.ServiceError) { + for _, val := range args { + // Accept args with trailing and leading spaces + trimVal := strings.TrimSpace(val) + + // Skip empty strings silently + if trimVal == "" { + continue + } - order := strings.Split(trimedName, " ") - direction := "none valid" + // Check for SQL injection attempts before parsing + if !orderPattern.MatchString(trimVal) { + return nil, errors.BadRequest("invalid order format '%s': expected 'field' or 'field asc|desc'", val) + } - if len(order) == 1 { - orderField, err = getField(order[0], disallowedFields) - direction = "asc" - } else if len(order) == 2 { - orderField, err = getField(order[0], disallowedFields) - direction = order[1] - } - if err != nil || (direction != "asc" && direction != "desc") { - err = errors.BadRequest("bad order value '%s'", userArg) - return - } + // Each value should be "" or " asc|desc" + splitVal := strings.Fields(trimVal) + lenVal := len(splitVal) - orderBy = fmt.Sprintf("%s %s", orderField, direction) - return -} + var field, direction string -// ArgsToOrderBy returns cleaned orderBy list. -func ArgsToOrderBy( - orderByArgs []string, - disallowedFields map[string]string, -) (orderBy []string, err *errors.ServiceError) { - var order string - if len(orderByArgs) != 0 { - orderBy = []string{} - for _, o := range orderByArgs { - order, err = cleanOrderBy(o, disallowedFields) - if err != nil { - return + switch lenVal { + case 2: + field = splitVal[0] + direction = splitVal[1] + if direction != "asc" && direction != "desc" { + return nil, errors.BadRequest("invalid sort direction '%s': must be 'asc' or 'desc'", direction) } + case 1: + field = splitVal[0] + direction = "asc" + default: + return nil, errors.BadRequest("invalid order format '%s': expected 'field' or 'field asc|desc'", val) + } - orderBy = append(orderBy, order) + // Validate field against orderAllowedFields + if !orderAllowedFields[field] { + return nil, errors.BadRequest("field '%s' is not allowed for ordering", field) } + + cleanedValue := fmt.Sprintf("%s %s", field, direction) + cleanedOrderList = append(cleanedOrderList, cleanedValue) } - return + + return cleanedOrderList, nil } func GetTableName(g2 *gorm.DB) string { diff --git a/pkg/db/sql_helpers_test.go b/pkg/db/sql_helpers_test.go index 0561cf36..bb9ae1c1 100644 --- a/pkg/db/sql_helpers_test.go +++ b/pkg/db/sql_helpers_test.go @@ -834,3 +834,177 @@ func TestConditionStatusValidation(t *testing.T) { }) } } + +func TestArgsToOrder(t *testing.T) { + tests := []struct { + name string + errorContains string + input []string + expected []string + expectError bool + }{ + { + name: "single field with asc", + input: []string{"name asc"}, + expected: []string{"name asc"}, + }, + { + name: "single field with desc", + input: []string{"created_time desc"}, + expected: []string{"created_time desc"}, + }, + { + name: "single field without direction defaults to asc", + input: []string{"created_time"}, + expected: []string{"created_time asc"}, + }, + { + name: "multiple fields", + input: []string{"name asc", "created_time desc"}, + expected: []string{"name asc", "created_time desc"}, + }, + { + name: "field with extra spaces", + input: []string{" name asc "}, + expected: []string{"name asc"}, + }, + { + name: "field with tabs and spaces", + input: []string{"name \t desc"}, + expected: []string{"name desc"}, + }, + { + name: "all allowed fields", + input: []string{"id", "name", "created_time", "updated_time", "kind"}, + expected: []string{"id asc", "name asc", "created_time asc", "updated_time asc", "kind asc"}, + }, + { + name: "invalid direction", + input: []string{"name ascending"}, + expectError: true, + errorContains: "invalid order format", + }, + { + name: "SQL injection attempt - semicolon", + input: []string{"name; DROP TABLE resources"}, + expectError: true, + errorContains: "invalid order format", + }, + { + name: "SQL injection attempt - comment", + input: []string{"name-- asc"}, + expectError: true, + errorContains: "invalid order format", + }, + { + name: "uppercase field name", + input: []string{"NAME asc"}, + expectError: true, + errorContains: "invalid order format", + }, + { + name: "uppercase direction", + input: []string{"name ASC"}, + expectError: true, + errorContains: "invalid order format", + }, + { + name: "empty string in array is skipped", + input: []string{""}, + expected: nil, + }, + { + name: "empty string in array with field at end", + input: []string{"", "", "", "", "", "kind asc", "href desc"}, + expected: []string{"kind asc", "href desc"}, + }, + { + name: "whitespace only string is skipped", + input: []string{" "}, + expected: nil, + }, + { + name: "mixed valid and empty strings and tabs", + input: []string{"name asc", "", "created_time desc", " ", "\t"}, + expected: []string{"name asc", "created_time desc"}, + }, + { + name: "mixed valid and invalid field", + input: []string{"created_time desc", "name", "wrong_field"}, + expectError: true, + errorContains: "not allowed for ordering", + }, + { + name: "field not in whitelist", + input: []string{"custom_field asc"}, + expectError: true, + errorContains: "not allowed for ordering", + }, + { + name: "deleted_time field", + input: []string{"deleted_time desc"}, + expected: []string{"deleted_time desc"}, + }, + { + name: "generation field", + input: []string{"generation asc"}, + expected: []string{"generation asc"}, + }, + { + name: "too many parts", + input: []string{"name asc extra"}, + expectError: true, + errorContains: "invalid order format", + }, + { + name: "empty array", + input: []string{}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + RegisterTestingT(t) + + result, err := ArgsToOrder(tt.input) + + if tt.expectError { + Expect(err).ToNot(BeNil(), "expected error but got nil") + if tt.errorContains != "" { + Expect(err.Reason).To(ContainSubstring(tt.errorContains)) + } + } else { + Expect(err).To(BeNil(), "unexpected error: %v", err) + Expect(result).To(Equal(tt.expected)) + } + }) + } +} + +func TestArgsToOrder_SecurityValidation(t *testing.T) { + RegisterTestingT(t) + + // SQL injection attempts that should all fail + injectionAttempts := []struct { + name string + input string + }{ + {"union injection", "name UNION SELECT password FROM users"}, + {"comment injection", "name--"}, + {"semicolon terminator", "name; DROP TABLE resources;"}, + {"quote escape", "name' OR '1'='1"}, + {"parentheses", "name) OR (1=1"}, + {"wildcard", "name*"}, + {"backtick", "name`"}, + } + + for _, tt := range injectionAttempts { + t.Run(tt.name, func(t *testing.T) { + RegisterTestingT(t) + + _, err := ArgsToOrder([]string{tt.input}) + Expect(err).ToNot(BeNil(), "injection attempt '%s' should be rejected", tt.input) + }) + } +} diff --git a/pkg/services/generic.go b/pkg/services/generic.go index 8c048acb..25db050a 100755 --- a/pkg/services/generic.go +++ b/pkg/services/generic.go @@ -142,12 +142,12 @@ func (s *sqlGenericService) buildPreload(listCtx *listContext, d *dao.GenericDao func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao) (bool, *errors.ServiceError) { if len(listCtx.args.Order) != 0 { - orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.Order, *listCtx.disallowedFields) + cleanedOrderList, serviceErr := db.ArgsToOrder(listCtx.args.Order) if serviceErr != nil { return false, serviceErr } - for _, orderByArg := range orderByArgs { - (*d).OrderBy(orderByArg) + for _, orderArg := range cleanedOrderList { + (*d).OrderBy(orderArg) } } return false, nil diff --git a/test/integration/clusters_test.go b/test/integration/clusters_test.go index 3d462a6b..35e72367 100644 --- a/test/integration/clusters_test.go +++ b/test/integration/clusters_test.go @@ -688,7 +688,7 @@ func TestClusterList_DefaultSorting(t *testing.T) { t.Logf("✓ Default sorting works: clusters sorted by created_time desc") } -// TestClusterList_OrderByName tests custom sorting by name +// TestClusterList_OrderName tests custom sorting by name func TestClusterList_OrderName(t *testing.T) { h, client := test.RegisterIntegration(t) diff --git a/test/integration/order_field_mapping_test.go b/test/integration/order_field_mapping_test.go new file mode 100644 index 00000000..497fc585 --- /dev/null +++ b/test/integration/order_field_mapping_test.go @@ -0,0 +1,277 @@ +package integration + +import ( + "fmt" + "net/http" + "testing" + + . "github.com/onsi/gomega" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api/openapi" + "github.com/openshift-hyperfleet/hyperfleet-api/test" + "github.com/openshift-hyperfleet/hyperfleet-api/test/factories" +) + +// TestOrderFieldMapping verifies that order parameters correctly map to database columns +// and produce properly sorted results. +func TestOrderFieldMapping(t *testing.T) { + RegisterTestingT(t) + h, client := test.RegisterIntegration(t) + + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + // Create 5 clusters for ordering tests + for range 5 { + _, err := factories.NewClusterWithStatus(&h.Factories, h.DBFactory, h.NewID(), true, true) + Expect(err).NotTo(HaveOccurred()) + } + + t.Run("OrderName", func(t *testing.T) { + RegisterTestingT(t) + + orderAsc := openapi.QueryParamsOrder("name asc") + params := &openapi.GetClustersParams{ + Order: &orderAsc, + } + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusOK)) + list := resp.JSON200 + Expect(list).NotTo(BeNil()) + + // Verify ascending order - each name should be >= previous + items := list.Items + if len(items) >= 2 { + for i := 1; i < len(items); i++ { + prevName := items[i-1].Name + currName := items[i].Name + Expect(currName >= prevName).To(BeTrue(), + fmt.Sprintf("Names should be in ascending order: %s >= %s", currName, prevName)) + } + } + + // Order by desc + orderDesc := openapi.QueryParamsOrder("name desc") + params = &openapi.GetClustersParams{ + Order: &orderDesc, + } + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusOK)) + list = resp.JSON200 + Expect(list).NotTo(BeNil()) + items = list.Items + if len(items) >= 2 { + for i := 1; i < len(items); i++ { + prevName := items[i-1].Name + currName := items[i].Name + Expect(currName <= prevName).To(BeTrue(), + fmt.Sprintf("Names should be in descending order: %s <= %s", currName, prevName)) + } + } + }) + + // order defaults to created_time desc + t.Run("OrderDefault", func(t *testing.T) { + RegisterTestingT(t) + + // Default ordering is created_time desc + params := &openapi.GetClustersParams{} + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusOK)) + list := resp.JSON200 + Expect(list).NotTo(BeNil()) + + // Verify results are ordered by created_time descending (newest first) + items := list.Items + if len(items) >= 2 { + // Each subsequent item should have created_time <= previous + for i := 1; i < len(items); i++ { + prevTime := items[i-1].CreatedTime + currTime := items[i].CreatedTime + Expect(currTime.Before(prevTime) || currTime.Equal(prevTime)).To(BeTrue(), + fmt.Sprintf("created_time should be descending: %v should be <= %v", + currTime, prevTime)) + } + } + }) + t.Run("MultipleOrderFields", func(t *testing.T) { + RegisterTestingT(t) + + // Order by kind asc, then name desc + order := openapi.QueryParamsOrder("kind asc,name desc") + params := &openapi.GetClustersParams{ + Order: &order, + } + + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusOK)) + list := resp.JSON200 + Expect(list).NotTo(BeNil()) + + // Verify multi-level ordering + items := list.Items + if len(items) >= 2 { + for i := 1; i < len(items); i++ { + prevKind := items[i-1].Kind + currKind := items[i].Kind + prevName := items[i-1].Name + currName := items[i].Name + + // Primary key: kind should be in ascending order (equal) + Expect(currKind == prevKind).To(BeTrue(), fmt.Sprintf("Kinds should be equal: %s == %s", currKind, prevKind)) + + // Secondary key: within same kind, names should be ascending + Expect(currName <= prevName).To(BeTrue(), + fmt.Sprintf("Within same kind, names should be descending: %s <= %s", + currName, prevName)) + } + } + }) +} + +// TestOrderFieldValidation verifies that invalid order parameters return proper errors +func TestOrderFieldValidation(t *testing.T) { + RegisterTestingT(t) + h, client := test.RegisterIntegration(t) + + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + tests := []struct { + name string + order string + expectedError string + }{ + { + name: "InvalidFieldName", + order: "nonexistent_field asc", + expectedError: "not allowed for ordering", + }, + { + name: "InvalidDirection", + order: "name ascending", + expectedError: "invalid order format", + }, + { + name: "SQLInjectionAttempt1", + order: "name; DROP TABLE clusters", + expectedError: "invalid order format", + }, + { + name: "SQLInjectionAttempt2", + order: "name; (SELECT/**/tablename::int/**/FROM/**/pg_tables/**/LIMIT/**/1/**/OFFSET/**/N)", + expectedError: "invalid order format", + }, + { + name: "SQLInjectionAttempt3", + order: "pg_sleep(10)", + expectedError: "invalid order format", + }, + { + name: "SQLInjectionAttempt4", + order: "(SELECT/**/current_user::int)", + expectedError: "invalid order format", + }, + { + name: "TooManyParts", + order: "name asc extra", + expectedError: "invalid order format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + RegisterTestingT(t) + + order := openapi.QueryParamsOrder(tt.order) + params := &openapi.GetClustersParams{ + Order: &order, + } + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusBadRequest)) + Expect(string(resp.Body)).To(ContainSubstring(tt.expectedError)) + }) + } +} + +// TestOrderAllowedFields verifies that all whitelisted fields work correctly +func TestOrderAllowedFields(t *testing.T) { + RegisterTestingT(t) + h, client := test.RegisterIntegration(t) + + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + // Create a test cluster + _, err := factories.NewClusterWithStatus(&h.Factories, h.DBFactory, h.NewID(), true, true) + Expect(err).NotTo(HaveOccurred()) + + allowedFields := []string{ + "id", + "name", + "created_time", + "updated_time", + "kind", + "created_by", + "updated_by", + "generation", + "href", + "deleted_time", + "deleted_by", + } + + for _, field := range allowedFields { + t.Run(fmt.Sprintf("Order_%s", field), func(t *testing.T) { + RegisterTestingT(t) + + order := openapi.QueryParamsOrder(fmt.Sprintf("%s asc", field)) + params := &openapi.GetClustersParams{ + Order: &order, + } + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusOK), + fmt.Sprintf("Field %s should be allowed for ordering", field)) + Expect(resp.JSON200).NotTo(BeNil()) + }) + } +} + +// TestOrderEmptyStrings verifies that empty order values are handled gracefully +func TestOrderEmptyStrings(t *testing.T) { + RegisterTestingT(t) + h, client := test.RegisterIntegration(t) + + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + // Create a test cluster + _, err := factories.NewClusterWithStatus(&h.Factories, h.DBFactory, h.NewID(), true, true) + Expect(err).NotTo(HaveOccurred()) + + t.Run("EmptyOrder", func(t *testing.T) { + RegisterTestingT(t) + + // Empty string should fall back to default (created_time desc) + order := openapi.QueryParamsOrder("") + params := &openapi.GetClustersParams{ + Order: &order, + } + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusOK)) + Expect(resp.JSON200).NotTo(BeNil()) + }) + +}