diff --git a/.github/workflows/job_test_go_api_local.yaml b/.github/workflows/job_test_go_api_local.yaml index a26a82d73b..4fbd54308e 100644 --- a/.github/workflows/job_test_go_api_local.yaml +++ b/.github/workflows/job_test_go_api_local.yaml @@ -18,5 +18,5 @@ jobs: working-directory: go - name: Test - run: go test -cover -json -timeout=60m -failfast ./... | tparse -all -progress + run: task test working-directory: go diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 738f4c5dd0..fbfca27cc1 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -45,6 +45,6 @@ jobs: # AGENT_BASE_URL: "http://localhost:8080" # working-directory: apps/agent - # test_go_api_local: - # name: Test Go API Local - # uses: ./.github/workflows/job_test_go_api_local.yaml + test_go_api_local: + name: Test Go API Local + uses: ./.github/workflows/job_test_go_api_local.yaml diff --git a/apps/engineering/app/architecture/[[...slug]]/page.tsx b/apps/engineering/app/architecture/[[...slug]]/page.tsx index 3cc378d34d..97e31755b8 100644 --- a/apps/engineering/app/architecture/[[...slug]]/page.tsx +++ b/apps/engineering/app/architecture/[[...slug]]/page.tsx @@ -36,9 +36,9 @@ export default async function Page(props: { > {page.data.title} - {page.data.description} + {page.data.description} - + diff --git a/apps/engineering/app/architecture/layout.tsx b/apps/engineering/app/architecture/layout.tsx index 3ec7507d92..fd24222a50 100644 --- a/apps/engineering/app/architecture/layout.tsx +++ b/apps/engineering/app/architecture/layout.tsx @@ -5,10 +5,8 @@ import { baseOptions } from "../layout.config"; export default function Layout({ children }: { children: ReactNode }) { return ( -
- - {children} - -
+ + {children} + ); } diff --git a/apps/engineering/app/company/[[...slug]]/page.tsx b/apps/engineering/app/company/[[...slug]]/page.tsx index b09b22a7bf..0ce3ca0a74 100644 --- a/apps/engineering/app/company/[[...slug]]/page.tsx +++ b/apps/engineering/app/company/[[...slug]]/page.tsx @@ -37,7 +37,7 @@ export default async function Page(props: { > {page.data.title} - {page.data.description} + {page.data.description} diff --git a/apps/engineering/app/contributing/[[...slug]]/page.tsx b/apps/engineering/app/contributing/[[...slug]]/page.tsx index 488eed3b8e..3a31b5a295 100644 --- a/apps/engineering/app/contributing/[[...slug]]/page.tsx +++ b/apps/engineering/app/contributing/[[...slug]]/page.tsx @@ -37,9 +37,9 @@ export default async function Page(props: { > {page.data.title} - {page.data.description} + {page.data.description} - + diff --git a/apps/engineering/app/rfcs/layout.tsx b/apps/engineering/app/rfcs/layout.tsx index e84828ef2e..b7c9599bcf 100644 --- a/apps/engineering/app/rfcs/layout.tsx +++ b/apps/engineering/app/rfcs/layout.tsx @@ -5,7 +5,7 @@ import { baseOptions } from "../layout.config"; export default function Layout({ children }: { children: ReactNode }) { return ( -
+
{children} diff --git a/apps/engineering/content/architecture/services/api.mdx b/apps/engineering/content/architecture/services/api.mdx new file mode 100644 index 0000000000..16c7687d2f --- /dev/null +++ b/apps/engineering/content/architecture/services/api.mdx @@ -0,0 +1,211 @@ +--- +title: API +--- +import { Step, Steps } from 'fumadocs-ui/components/steps'; +import { TypeTable } from 'fumadocs-ui/components/type-table'; +import {Property} from "fumadocs-openapi/ui" + + + + This document only covers v2 of the Unkey API. The v1 API on Cloudflare Workers is deprecated and will be removed in the future. It was too hard to selfhost anyways. + + +Our API runs on AWS containers, in multiple regions behind a global load balancer to ensure high availability and low latency. + + +The source code is available on [GitHub](https://github.com/unkeyed/unkey/tree/main/go/cmd/api). + +## Quickstart + +To get started, you need [go1.24+](https://go.dev/dl/) installed on your machine. + + + + + ### Clone the repository: + +```bash +git clone git@github.com:unkeyed/unkey.git +cd unkey/go +``` + + + + ### Build the binary: + +```bash +go build -o unkey . +``` + + + + ### Run the binary: + +```bash +unkey api --config ./path/to/config.json +``` + +You should now be able to access the API at + +```bash +$ curl http://localhost:/v2/liveness +{"message":"we're cooking"}% +``` + + + + + + +## Configuration + +The API server requires a json configuration file to be passed as an argument to the binary. + +You can use `${SOME_NAME}` as placeholder in your config file and it will be replaced by the value of the environment variable `SOME_NAME`. 
+ + +```json title="Example" +{ + "httpPort": "${PORT}" +} +``` + +The most up to date json schema can be found here: [https://raw.githubusercontent.com/unkeyed/unkey/refs/heads/main/go/schema.json](https://raw.githubusercontent.com/unkeyed/unkey/refs/heads/main/go/schema.json) + +Most IDEs support JSON schema validation if you put it in your config file. +```json + { + "$schema": "https://raw.githubusercontent.com/unkeyed/unkey/refs/heads/main/go/schema.json", + // ... + } +``` + + + You can check out our own configuration files on [GitHub](https://github.com/unkeyed/unkey/tree/main/go). + + + + + +### General Configuration + +These settings define the fundamental properties of the server. + + + The platform this server is running on ("aws", "gcp", ...). + + Most metrics include this information to help with troubleshooting and monitoring. + + + + The container image and version identifier for this instance. + + + + The HTTP port where the server will listen for incoming connections. Defaults to 7070. + + + + + Geographic region identifier where this server is deployed (e.g., "us-west-1") + + +### Heartbeat Configuration +The API can send heartbeats to a destination to monitor its health. + + + Configuration for server health check reporting. Contains the following properties: + + + Endpoint URL where heartbeat signals will be sent (e.g., "http://monitor.example.com/heartbeat") + + + + Time between heartbeat signals in seconds (e.g., 30 for checking every half minute) + + + +### Cluster Configuration +Settings for cluster operation when running multiple server instances together. + + + Settings for multi-server cluster operations: + + + Unique identifier for this node in the cluster (e.g., "node-1", "server-east-1") + + If omitted, a random id will be generated. + + + + Network address other nodes will use to contact this node (e.g., "10.0.0.1", "node1.example.com") + + + + Port used for internal cluster communication via RPC. Defaults to "7071". 
+ + + + Port used for cluster membership and state dissemination. Defaults to "7072". + + + + Configuration for how cluster nodes discover each other on startup. + + Only one discovery method can be configured at a time. + + + Fixed list of cluster nodes for stable environments: + + + List of node addresses to connect to (e.g., ["node1:7071", "node2:7071"]) + + + + + Redis-based dynamic discovery for flexible environments. + All nodes will send heartbeats to the Redis server with their address. + + + Redis connection string (e.g., "redis://redis.example.com:6379") + + + + + +### Logging and Monitoring +Configuration for observability and debugging capabilities. + + + Logging configuration settings: + + + Enable colored output in log messages for better readability + + + + +### Database Configuration +Database connection settings for the server's data storage. + + + Database connection configuration: + + + Primary database connection string for read and write operations (e.g., "postgresql://user:pass@localhost:5432/dbname") + + + + Optional read-only database replica for scaling read operations + + + +### ClickHouse Configuration + + + ClickHouse integration for metrics and logging: + + + ClickHouse server connection string (e.g., "http://clickhouse:8123") + + diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml index a342ee71e3..953e3ab7e0 100644 --- a/deployment/docker-compose.yaml +++ b/deployment/docker-compose.yaml @@ -51,12 +51,19 @@ services: dockerfile: ./Dockerfile depends_on: - mysql + - redis - clickhouse environment: GOSSIP_PORT: 9090 RPC_PORT: 9091 - DATABASE_PRIMARY_DSN: "mysql://unkey:password@tcp(mysql:3900)/unkey" + DATABASE_PRIMARY_DSN: "mysql://unkey:password@tcp(mysql:3900)/unkey?parseTime=true" CLICKHOUSE_URL: "clickhouse://default:password@clickhouse:9000" + REDIS_URL: "redis://redis:6379" + + redis: + image: redis:latest + ports: + - 6379:6379 agent: command: ["/usr/local/bin/unkey", "agent", "--config", 
"config.docker.json"] diff --git a/go/.golangci.yaml b/go/.golangci.yaml index e3d225dd8c..2e107b1c45 100644 --- a/go/.golangci.yaml +++ b/go/.golangci.yaml @@ -12,7 +12,6 @@ run: # Default: 1m timeout: 3m - # This file contains only configs which differ from defaults. # All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml linters-settings: @@ -169,7 +168,7 @@ linters-settings: nolintlint: # Exclude following linters from requiring an explanation. # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] + allow-no-explanation: [funlen, gocognit, lll] # Enable to require an explanation of nonzero length after each nolint directive. # Default: false require-explanation: true @@ -217,7 +216,6 @@ linters-settings: # Default: false all: true - linters: disable-all: true enable: @@ -241,14 +239,14 @@ linters: - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error # - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 - exhaustive # checks exhaustiveness of enum switch statements - - fatcontext # detects nested contexts in loops + # - fatcontext # detects nested contexts in loops # - forbidigo # forbids identifiers # - funlen # tool for detection of long functions # - gocheckcompilerdirectives # validates go compiler directive comments (//go:) # - gochecknoglobals # checks that no global variables exist # # - gochecknoinits # checks that no init functions are present in Go code - gochecksumtype # checks exhaustiveness on Go "sum types" - - gocognit # computes and checks the cognitive complexity of functions + # - gocognit # computes and checks the cognitive complexity of functions - goconst # finds repeated strings that could be replaced by a constant - gocritic # provides diagnostics that check for bugs, performance and style issues # - gocyclo # computes and checks the cyclomatic complexity of 
functions @@ -332,7 +330,6 @@ linters: #- thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers #- wsl # [too strict and mostly code is not more readable] whitespace linter forces you to use empty lines - issues: # Maximum count of issues with the same text. # Set to 0 to disable. @@ -341,9 +338,9 @@ issues: exclude-rules: - source: "(noinspection|TODO)" - linters: [ godot ] + linters: [godot] - source: "//noinspection" - linters: [ gocritic ] + linters: [gocritic] - path: "_test\\.go" linters: - bodyclose diff --git a/go/Taskfile.yml b/go/Taskfile.yml index 1cbe67ca8c..e07d80827d 100644 --- a/go/Taskfile.yml +++ b/go/Taskfile.yml @@ -9,6 +9,7 @@ tasks: - task: lint test: cmds: + - docker pull mysql:latest - go test -cover -json -failfast ./... | tparse -all -progress build: diff --git a/go/api/config.yaml b/go/api/config.yaml index 6911059186..e9774775bf 100644 --- a/go/api/config.yaml +++ b/go/api/config.yaml @@ -4,6 +4,5 @@ output: ./gen.go generate: models: true - output-options: nullable-type: true diff --git a/go/api/gen.go b/go/api/gen.go index 78edd2e8a1..bca0654d49 100644 --- a/go/api/gen.go +++ b/go/api/gen.go @@ -82,6 +82,9 @@ type V2RatelimitLimitRequestBody struct { // Limit The maximum number of requests allowed. Limit int64 `json:"limit"` + + // Namespace The namespace name for the rate limit. + Namespace string `json:"namespace"` } // V2RatelimitLimitResponseBody defines model for V2RatelimitLimitResponseBody. @@ -113,7 +116,7 @@ type V2RatelimitSetOverrideRequestBody struct { // NamespaceId The id of the namespace. Either namespaceId or namespaceName must be provided NamespaceId *string `json:"namespaceId,omitempty"` - // NamespaceName xThe name of the namespace. Either namespaceId or namespaceName must be provided + // NamespaceName The name of the namespace. 
Either namespaceId or namespaceName must be provided NamespaceName *string `json:"namespaceName,omitempty"` } diff --git a/go/api/openapi.json b/go/api/openapi.json index 3b30f02267..89c87eb60c 100644 --- a/go/api/openapi.json +++ b/go/api/openapi.json @@ -15,7 +15,6 @@ "components": { "schemas": { "BaseError": { - "additionalProperties": false, "properties": { "requestId": { "description": "A unique id for this request. Please always provide this to support.", @@ -71,13 +70,14 @@ "$ref": "#/components/schemas/BaseError" }, { + "type": "object", "properties": { "errors": { "description": "Optional list of individual error details", "items": { "$ref": "#/components/schemas/ValidationError" }, - "type": ["array"] + "type": "array" } }, "required": ["errors"] @@ -124,25 +124,32 @@ "properties": { "namespaceId": { "description": "The id of the namespace. Either namespaceId or namespaceName must be provided", - "type": "string" + "type": "string", + "minLength": 1, + "maxLength": 255, + "pattern": "^rlns_.+$" }, "namespaceName": { - "description": "xThe name of the namespace. Either namespaceId or namespaceName must be provided", + "description": "The name of the namespace. Either namespaceId or namespaceName must be provided", "type": "string" }, "duration": { "description": "The duration in milliseconds for the rate limit window.", "format": "int64", - "type": "integer" + "type": "integer", + "minimum": 1000 }, "identifier": { "description": "Identifier of your user, this can be their userId, an email, an ip or anything else. 
Wildcards ( * ) can be used to match multiple identifiers, More info can be found at https://www.unkey.com/docs/ratelimiting/overrides#wildcard-rules", - "type": "string" + "type": "string", + "minLength": 1, + "maxLength": 255 }, "limit": { "description": "The maximum number of requests allowed.", "format": "int64", - "type": "integer" + "type": "integer", + "minimum": 0 } }, "required": ["identifier", "limit", "duration"], @@ -162,6 +169,11 @@ "V2RatelimitLimitRequestBody": { "additionalProperties": false, "properties": { + "namespace": { + "description": "The namespace name for the rate limit.", + "type": "string", + "example": "sms.sign_up" + }, "cost": { "description": "The cost of the request. Defaults to 1 if not provided.", "format": "int64", @@ -183,7 +195,7 @@ "type": "integer" } }, - "required": ["identifier", "limit", "duration"], + "required": ["namespace", "identifier", "limit", "duration"], "type": "object" }, "V2RatelimitLimitResponseBody": { diff --git a/go/cmd/api/config.go b/go/cmd/api/config.go index 7c933db802..83f8bf8fcc 100644 --- a/go/cmd/api/config.go +++ b/go/cmd/api/config.go @@ -20,10 +20,9 @@ type nodeConfig struct { Static *struct { Addrs []string `json:"addrs" minLength:"1" description:"List of node addresses"` } `json:"static,omitempty" description:"Static cluster discovery configuration"` - AwsCloudmap *struct { - ServiceName string `json:"serviceName" minLength:"1" description:"Cloudmap service name"` - Region string `json:"region" minLength:"1" description:"Cloudmap region"` - } `json:"awsCloudmap,omitempty" description:"Cloudmap cluster discovery configuration"` + Redis *struct { + URL string `json:"url" minLength:"1" description:"Redis URL"` + } `json:"redis,omitempty" description:"Redis cluster discovery configuration"` } `json:"discovery,omitempty" description:"Cluster discovery configuration, only one supported: static, cloudmap"` } `json:"cluster,omitempty" description:"Cluster configuration"` diff --git a/go/cmd/api/main.go 
b/go/cmd/api/main.go index 72874eacde..c825e46578 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -14,10 +14,13 @@ import ( "github.com/unkeyed/unkey/go/cmd/api/routes" "github.com/unkeyed/unkey/go/internal/services/keys" + "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/clickhouse" + "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/cluster" "github.com/unkeyed/unkey/go/pkg/config" "github.com/unkeyed/unkey/go/pkg/database" + dbCache "github.com/unkeyed/unkey/go/pkg/database/middleware/cache" "github.com/unkeyed/unkey/go/pkg/discovery" "github.com/unkeyed/unkey/go/pkg/logging" "github.com/unkeyed/unkey/go/pkg/membership" @@ -50,6 +53,8 @@ var Cmd = &cli.Command{ // nolint:gocognit func run(cliC *cli.Context) error { + shutdowns := []func(ctx context.Context) error{} + if cliC.Bool("generate-config-schema") { // nolint:exhaustruct _, err := config.GenerateJsonSchema(nodeConfig{}, "schema.json") @@ -62,6 +67,7 @@ func run(cliC *cli.Context) error { return nil } ctx := cliC.Context + clk := clock.New() configFile := cliC.String("config") // nolint:exhaustruct @@ -72,8 +78,12 @@ func run(cliC *cli.Context) error { } nodeID := uid.Node() - if cfg.Cluster != nil && cfg.Cluster.NodeID != "" { - nodeID = cfg.Cluster.NodeID + if cfg.Cluster != nil { + if cfg.Cluster.NodeID == "" { + cfg.Cluster.NodeID = nodeID + } else { + nodeID = cfg.Cluster.NodeID + } } if cfg.Region == "" { @@ -98,61 +108,23 @@ func run(cliC *cli.Context) error { logger.Info(ctx, "configration loaded", slog.String("file", configFile)) - var c cluster.Cluster = cluster.NewNoop(nodeID, net.ParseIP("127.0.0.1")) - if cfg.Cluster != nil { - var d discovery.Discoverer - - switch { - case cfg.Cluster.Discovery.Static != nil: - - d = &discovery.Static{ - Addrs: cfg.Cluster.Discovery.Static.Addrs, - } - case cfg.Cluster.Discovery.AwsCloudmap != nil: - return fmt.Errorf("NOT IMPLEMENTED") - default: - return fmt.Errorf("missing 
discovery method") - } - - gossipPort, err := strconv.ParseInt(cfg.Cluster.GossipPort, 10, 64) - if err != nil { - return fmt.Errorf("unable to parse gossip port: %w", err) - } - - m, mErr := membership.New(membership.Config{ - NodeID: nodeID, - Addr: net.ParseIP(""), - GossipPort: int(gossipPort), - Logger: logger, - }) - if mErr != nil { - return fmt.Errorf("unable to create membership: %w", err) - } + db, err := database.New(database.Config{ + PrimaryDSN: cfg.Database.Primary, + ReadOnlyDSN: cfg.Database.ReadonlyReplica, + Logger: logger, + Clock: clock.New(), + }, dbCache.WithCaching(logger)) + if err != nil { + return fmt.Errorf("unable to create db: %w", err) + } - rpcPort, err := strconv.ParseInt(cfg.Cluster.RpcPort, 10, 64) - if err != nil { - return fmt.Errorf("unable to parse rpc port: %w", err) - } - c, err = cluster.New(cluster.Config{ - Self: cluster.Node{ - - ID: nodeID, - Addr: net.ParseIP(cfg.Cluster.AdvertiseAddr), - RpcAddr: "TO DO", - }, - Logger: logger, - Membership: m, - RpcPort: int(rpcPort), - }) - if err != nil { - return fmt.Errorf("unable to create cluster: %w", err) - } + defer db.Close() - err = m.Start(d) - if err != nil { - return fmt.Errorf("unable to start membership: %w", err) - } + c, shutdownCluster, err := setupCluster(cfg, logger) + if err != nil { + return fmt.Errorf("unable to create cluster: %w", err) } + shutdowns = append(shutdowns, shutdownCluster...) 
var ch clickhouse.Bufferer = clickhouse.NewNoop() if cfg.Clickhouse != nil { @@ -173,15 +145,6 @@ func run(cliC *cli.Context) error { return fmt.Errorf("unable to create server: %w", err) } - db, err := database.New(database.Config{ - PrimaryDSN: cfg.Database.Primary, - ReadOnlyDSN: cfg.Database.ReadonlyReplica, - Logger: logger, - }) - if err != nil { - return fmt.Errorf("unable to connect to database: %w", err) - } - validator, err := validation.New() if err != nil { return fmt.Errorf("unable to create validator: %w", err) @@ -195,12 +158,22 @@ func run(cliC *cli.Context) error { return fmt.Errorf("unable to create key service: %w", err) } + rlSvc, err := ratelimit.New(ratelimit.Config{ + Logger: logger, + Cluster: c, + Clock: clk, + }) + if err != nil { + return fmt.Errorf("unable to create ratelimit service: %w", err) + } + routes.Register(srv, &routes.Services{ Logger: logger, Database: db, EventBuffer: ch, Keys: keySvc, Validator: validator, + Ratelimit: rlSvc, }) go func() { @@ -210,6 +183,10 @@ func run(cliC *cli.Context) error { } }() + return gracefulShutdown(ctx, logger, shutdowns) +} + +func gracefulShutdown(ctx context.Context, logger logging.Logger, shutdowns []func(ctx context.Context) error) error { cShutdown := make(chan os.Signal, 1) signal.Notify(cShutdown, os.Interrupt, syscall.SIGTERM) @@ -217,10 +194,99 @@ func run(cliC *cli.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() logger.Info(ctx, "shutting down") - err = c.Shutdown(ctx) - if err != nil { - return fmt.Errorf("unable to leave cluster: %w", err) + errors := []error{} + for i := len(shutdowns) - 1; i >= 0; i-- { + err := shutdowns[i](ctx) + if err != nil { + errors = append(errors, err) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred during shutdown: %v", errors) } return nil } + +func setupCluster(cfg nodeConfig, logger logging.Logger) (cluster.Cluster, []func(ctx context.Context) error, error) { + if cfg.Cluster == nil { + 
return cluster.NewNoop("", net.ParseIP("127.0.0.1")), []func(ctx context.Context) error{}, nil + } + + shutdowns := []func(ctx context.Context) error{} + + gossipPort, err := strconv.ParseInt(cfg.Cluster.GossipPort, 10, 64) + if err != nil { + return nil, shutdowns, fmt.Errorf("unable to parse gossip port: %w", err) + } + + m, err := membership.New(membership.Config{ + NodeID: cfg.Cluster.NodeID, + Addr: net.ParseIP(""), + GossipPort: int(gossipPort), + Logger: logger, + }) + if err != nil { + return nil, shutdowns, fmt.Errorf("unable to create membership: %w", err) + } + + rpcPort, err := strconv.ParseInt(cfg.Cluster.RpcPort, 10, 64) + if err != nil { + return nil, shutdowns, fmt.Errorf("unable to parse rpc port: %w", err) + } + c, err := cluster.New(cluster.Config{ + Self: cluster.Node{ + + ID: cfg.Cluster.NodeID, + Addr: net.ParseIP(cfg.Cluster.AdvertiseAddr), + RpcAddr: "TO DO", + }, + Logger: logger, + Membership: m, + RpcPort: int(rpcPort), + }) + if err != nil { + return nil, shutdowns, fmt.Errorf("unable to create cluster: %w", err) + } + shutdowns = append(shutdowns, c.Shutdown) + + var d discovery.Discoverer + + switch { + case cfg.Cluster.Discovery.Static != nil: + { + d = &discovery.Static{ + Addrs: cfg.Cluster.Discovery.Static.Addrs, + } + break + } + + case cfg.Cluster.Discovery.Redis != nil: + { + rd, rErr := discovery.NewRedis(discovery.RedisConfig{ + URL: cfg.Cluster.Discovery.Redis.URL, + NodeID: cfg.Cluster.NodeID, + Addr: cfg.Cluster.AdvertiseAddr, + Logger: logger, + }) + if rErr != nil { + return nil, shutdowns, fmt.Errorf("unable to create redis discovery: %w", rErr) + } + shutdowns = append(shutdowns, rd.Shutdown) + d = rd + break + } + default: + { + return nil, nil, fmt.Errorf("missing discovery method") + } + } + + err = m.Start(d) + if err != nil { + return nil, nil, fmt.Errorf("unable to start membership: %w", err) + } + + return c, shutdowns, nil +} diff --git a/go/cmd/api/routes/register.go b/go/cmd/api/routes/register.go index 
070216c9be..0707c26a42 100644 --- a/go/cmd/api/routes/register.go +++ b/go/cmd/api/routes/register.go @@ -4,28 +4,31 @@ import ( v2EcsMeta "github.com/unkeyed/unkey/go/cmd/api/routes/v2_ecs_meta" v2Liveness "github.com/unkeyed/unkey/go/cmd/api/routes/v2_liveness" v2RatelimitLimit "github.com/unkeyed/unkey/go/cmd/api/routes/v2_ratelimit_limit" + v2RatelimitSetOverride "github.com/unkeyed/unkey/go/cmd/api/routes/v2_ratelimit_set_override" zen "github.com/unkeyed/unkey/go/pkg/zen" ) // here we register all of the routes. // this function runs during startup. func Register(srv *zen.Server, svc *Services) { + withTracing := zen.WithTracing() withMetrics := zen.WithMetrics(svc.EventBuffer) - withRootKeyAuth := zen.WithRootKeyAuth(svc.Keys) + withLogging := zen.WithLogging(svc.Logger) withErrorHandling := zen.WithErrorHandling() withValidation := zen.WithValidation(svc.Validator) defaultMiddlewares := []zen.Middleware{ + withTracing, withMetrics, withLogging, withErrorHandling, - withRootKeyAuth, // must be before validation to capture the workspaceID withValidation, } srv.RegisterRoute( []zen.Middleware{ + withTracing, withMetrics, withLogging, withErrorHandling, @@ -33,11 +36,22 @@ func Register(srv *zen.Server, svc *Services) { }, v2Liveness.New()) + // --------------------------------------------------------------------------- + // v2/ratelimit + + // v2/ratelimit.limit + srv.RegisterRoute( + defaultMiddlewares, + v2RatelimitLimit.New(v2RatelimitLimit.Services{Logger: svc.Logger, DB: svc.Database, Keys: svc.Keys, Ratelimit: svc.Ratelimit}), + ) + // v2/ratelimit.setOverride srv.RegisterRoute( defaultMiddlewares, - v2RatelimitLimit.New(v2RatelimitLimit.Services{}), + v2RatelimitSetOverride.New(v2RatelimitSetOverride.Services{Logger: svc.Logger, DB: svc.Database, Keys: svc.Keys}), ) + // --------------------------------------------------------------------------- + // misc srv.RegisterRoute([]zen.Middleware{}, v2EcsMeta.New()) } diff --git a/go/cmd/api/routes/services.go 
b/go/cmd/api/routes/services.go index 0c337cefc0..4be8b470e2 100644 --- a/go/cmd/api/routes/services.go +++ b/go/cmd/api/routes/services.go @@ -2,6 +2,7 @@ package routes import ( "github.com/unkeyed/unkey/go/internal/services/keys" + "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/clickhouse/schema" "github.com/unkeyed/unkey/go/pkg/database" "github.com/unkeyed/unkey/go/pkg/logging" @@ -18,4 +19,5 @@ type Services struct { EventBuffer EventBuffer Keys keys.KeyService Validator *validation.Validator + Ratelimit ratelimit.Service } diff --git a/go/cmd/api/routes/v2_ratelimit_limit/handler.go b/go/cmd/api/routes/v2_ratelimit_limit/handler.go index ef74f9a3a9..aca8872ab7 100644 --- a/go/cmd/api/routes/v2_ratelimit_limit/handler.go +++ b/go/cmd/api/routes/v2_ratelimit_limit/handler.go @@ -1,17 +1,29 @@ package v2RatelimitLimit import ( + "errors" "net/http" + "time" openapi "github.com/unkeyed/unkey/go/api" + "github.com/unkeyed/unkey/go/internal/services/keys" + "github.com/unkeyed/unkey/go/internal/services/ratelimit" + "github.com/unkeyed/unkey/go/pkg/database" "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/logging" zen "github.com/unkeyed/unkey/go/pkg/zen" ) type Request = openapi.V2RatelimitLimitRequestBody type Response = openapi.V2RatelimitLimitResponseBody -type Services struct{} +type Services struct { + Logger logging.Logger + Keys keys.KeyService + DB database.Database + + Ratelimit ratelimit.Service +} func New(svc Services) zen.Route { return zen.NewRoute("POST", "/v2/ratelimit.limit", func(s *zen.Session) error { @@ -23,14 +35,69 @@ func New(svc Services) zen.Route { ) } - // do stuff + limitRequest := ratelimit.RatelimitRequest{ + Identifier: req.Identifier, + Limit: req.Limit, + Duration: time.Duration(req.Duration) * time.Millisecond, + Cost: 1, + } + if req.Cost != nil { + limitRequest.Cost = *req.Cost + } + + namespace, err := svc.DB.FindRatelimitNamespaceByName(s.Context(), 
s.AuthorizedWorkspaceID(), req.Namespace) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + + return fault.Wrap( + err, + fault.WithTag(fault.NOT_FOUND), + fault.WithDesc("namespace not found", "This namespace does not exist."), + ) + } + + return fault.Wrap( + err, + fault.WithDesc("unable to load namespace", ""), + ) + } + + overrides, err := svc.DB.FindRatelimitOverridesByIdentifier(s.Context(), s.AuthorizedWorkspaceID(), namespace.ID, req.Identifier) - res := Response{ - Limit: -1, - Remaining: -1, - Reset: -1, - Success: true, + if err != nil && !errors.Is(err, database.ErrNotFound) { + return fault.Wrap( + err, + fault.WithDesc("unable to load overrides", ""), + ) + } + + usedOverrideID := "" + for _, override := range overrides { + usedOverrideID = override.ID + limitRequest.Limit = int64(override.Limit) + limitRequest.Duration = override.Duration + + if override.Identifier == req.Identifier { + // we found an exact match, which takes presedence over wildcard matches + break + } + } + if usedOverrideID != "" { + s.AddHeader("X-Unkey-Override-Used", usedOverrideID) } - return s.JSON(http.StatusOK, res) + + res, err := svc.Ratelimit.Ratelimit(s.Context(), limitRequest) + if err != nil { + return fault.Wrap(err, + fault.WithDesc("ratelimit failed", "We're unable to ratelimit the request."), + ) + } + + return s.JSON(http.StatusOK, Response{ + Limit: req.Limit, + Remaining: res.Remaining, + Reset: res.Reset, + Success: res.Success, + }) }) } diff --git a/go/cmd/api/routes/v2_ratelimit_set_override/200_test.go b/go/cmd/api/routes/v2_ratelimit_set_override/200_test.go new file mode 100644 index 0000000000..22c954af03 --- /dev/null +++ b/go/cmd/api/routes/v2_ratelimit_set_override/200_test.go @@ -0,0 +1,71 @@ +package handler_test + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + handler "github.com/unkeyed/unkey/go/cmd/api/routes/v2_ratelimit_set_override" + 
"github.com/unkeyed/unkey/go/pkg/entities" + "github.com/unkeyed/unkey/go/pkg/testutil" + "github.com/unkeyed/unkey/go/pkg/uid" +) + +func TestCreateNewOverrideSuccessfully(t *testing.T) { + ctx := context.Background() + h := testutil.NewHarness(t) + + namespaceID := uid.New("test_ns") + namespaceName := "test_namespace" + h.DB.InsertRatelimitNamespace(ctx, entities.RatelimitNamespace{ + ID: namespaceID, + WorkspaceID: h.Resources.UserWorkspace.ID, + Name: namespaceName, + CreatedAt: time.Now(), + UpdatedAt: time.Time{}, + DeletedAt: time.Time{}, + }) + + route := handler.New(handler.Services{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + }) + + h.Register(route) + + rootKey := h.CreateRootKey() + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + req := handler.Request{ + NamespaceId: nil, + NamespaceName: &namespaceName, + Identifier: "test_identifier", + Limit: 10, + Duration: 1000, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status, "expected 200, received: %v", res.Body) + require.NotNil(t, res.Body) + require.NotEqual(t, "", res.Body.OverrideId, "Override ID should not be empty, got: %+v", res.Body) + + override, err := h.DB.FindRatelimitOverrideByID(ctx, h.Resources.UserWorkspace.ID, res.Body.OverrideId) + require.NoError(t, err) + + require.Equal(t, namespaceID, override.NamespaceID) + require.Equal(t, req.Identifier, override.Identifier) + require.Equal(t, req.Limit, int64(override.Limit)) + require.Equal(t, req.Duration, override.Duration.Milliseconds()) + require.False(t, override.CreatedAt.IsZero()) + require.True(t, override.UpdatedAt.IsZero()) + require.True(t, override.DeletedAt.IsZero()) + +} diff --git a/go/cmd/api/routes/v2_ratelimit_set_override/400_test.go b/go/cmd/api/routes/v2_ratelimit_set_override/400_test.go new file mode 100644 index 0000000000..4e516029ad --- /dev/null +++ 
b/go/cmd/api/routes/v2_ratelimit_set_override/400_test.go @@ -0,0 +1,235 @@ +//nolint:exhaustruct +package handler_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/apps/agent/pkg/util" + "github.com/unkeyed/unkey/go/api" + handler "github.com/unkeyed/unkey/go/cmd/api/routes/v2_ratelimit_set_override" + "github.com/unkeyed/unkey/go/pkg/testutil" +) + +func TestBadRequests(t *testing.T) { + + testCases := []struct { + name string + req api.V2RatelimitSetOverrideRequestBody + expectedError api.BadRequestError + }{ + { + name: "missing all required fields", + req: api.V2RatelimitSetOverrideRequestBody{}, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "missing identifier", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + Duration: 1000, + Limit: 100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "missing duration", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + Identifier: "user_123", + Limit: 100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "missing limit", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + NamespaceName: nil, + Identifier: 
"user_123", + Duration: 1000, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "negative duration", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + NamespaceName: nil, + Identifier: "user_123", + Duration: -1000, + Limit: 100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "zero duration", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + NamespaceName: nil, + Identifier: "user_123", + Duration: 0, + Limit: 100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "negative limit", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + NamespaceName: nil, + Identifier: "user_123", + Duration: 1000, + Limit: -100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "zero limit", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + NamespaceName: nil, + Identifier: "user_123", + Duration: 1000, + Limit: 0, + }, + expectedError: 
api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "empty identifier", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: util.Pointer("not_empty"), + NamespaceName: nil, + Identifier: "", + Duration: 1000, + Limit: 100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "One or more fields failed validation", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + { + name: "neither namespace ID nor name provided", + req: api.V2RatelimitSetOverrideRequestBody{ + NamespaceId: nil, + NamespaceName: nil, + Identifier: "user_123", + Duration: 1000, + Limit: 100, + }, + expectedError: api.BadRequestError{ + Title: "Bad Request", + Detail: "You must provide either a namespace ID or name.", + Status: http.StatusBadRequest, + Type: "https://unkey.com/docs/errors/bad_request", + Errors: []api.ValidationError{}, + RequestId: "test", + Instance: nil, + }, + }, + } + + for _, tc := range testCases { + + t.Run(tc.name, func(t *testing.T) { + h := testutil.NewHarness(t) + + rootKey := h.CreateRootKey() + route := handler.New(handler.Services{ + DB: h.DB, + Keys: h.Keys, + Logger: h.Logger, + }) + + h.Register(route) + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } + + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, tc.req) + require.Equal(t, 400, res.Status, "expected 400, received: %v", res.Body) + require.NotNil(t, res.ErrorBody) + require.Equal(t, tc.expectedError.Type, res.ErrorBody.Type) + require.Equal(t, tc.expectedError.Detail, res.ErrorBody.Detail) + require.Equal(t, tc.expectedError.Status, 
res.ErrorBody.Status) + require.Equal(t, tc.expectedError.Title, res.ErrorBody.Title) + require.NotEmpty(t, res.ErrorBody.RequestId) + + }) + } + +} diff --git a/go/cmd/api/routes/v2_ratelimit_set_override/handler.go b/go/cmd/api/routes/v2_ratelimit_set_override/handler.go index 59fabec031..7708c7d535 100644 --- a/go/cmd/api/routes/v2_ratelimit_set_override/handler.go +++ b/go/cmd/api/routes/v2_ratelimit_set_override/handler.go @@ -1,6 +1,8 @@ package handler import ( + "context" + "log/slog" "net/http" "time" @@ -9,7 +11,7 @@ import ( "github.com/unkeyed/unkey/go/pkg/database" "github.com/unkeyed/unkey/go/pkg/entities" "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/hash" + "github.com/unkeyed/unkey/go/pkg/logging" "github.com/unkeyed/unkey/go/pkg/uid" "github.com/unkeyed/unkey/go/pkg/zen" ) @@ -18,20 +20,17 @@ type Request = api.V2RatelimitSetOverrideRequestBody type Response = api.V2RatelimitSetOverrideResponseBody type Services struct { - DB database.Database - Keys keys.KeyService + Logger logging.Logger + DB database.Database + Keys keys.KeyService } func New(svc Services) zen.Route { return zen.NewRoute("POST", "/v2/ratelimit.setOverride", func(s *zen.Session) error { - rootKey, err := zen.Bearer(s) - if err != nil { - return err - } - - auth, err := svc.Keys.Verify(s.Context(), hash.Sha256(rootKey)) + auth, err := svc.Keys.VerifyRootKey(s.Context(), s) if err != nil { + svc.Logger.Warn(s.Context(), "failed to verify root key", slog.String("error", err.Error())) return err } @@ -45,14 +44,20 @@ func New(svc Services) zen.Route { ) } + namespace, err := getNamespace(s.Context(), svc, auth.AuthorizedWorkspaceID, req) + if err != nil { + svc.Logger.Warn(s.Context(), "failed to get namespace", slog.String("error", err.Error())) + return err + } + overrideID := uid.New(uid.RatelimitOverridePrefix) err = svc.DB.InsertRatelimitOverride(s.Context(), entities.RatelimitOverride{ ID: overrideID, WorkspaceID: auth.AuthorizedWorkspaceID, - 
NamespaceID: "", - Identifier: "", - Limit: 0, - Duration: 0, + NamespaceID: namespace.ID, + Identifier: req.Identifier, + Limit: int32(req.Limit), // nolint:gosec + Duration: time.Duration(req.Duration) * time.Millisecond, CreatedAt: time.Now(), UpdatedAt: time.Time{}, DeletedAt: time.Time{}, @@ -64,8 +69,29 @@ func New(svc Services) zen.Route { fault.WithDesc("database failed", "The database is unavailable."), ) } + return s.JSON(http.StatusOK, Response{ OverrideId: overrideID, }) }) } + +func getNamespace(ctx context.Context, svc Services, workspaceID string, req Request) (entities.RatelimitNamespace, error) { + + switch { + case req.NamespaceId != nil: + { + return svc.DB.FindRatelimitNamespaceByID(ctx, *req.NamespaceId) + } + case req.NamespaceName != nil: + { + return svc.DB.FindRatelimitNamespaceByName(ctx, workspaceID, *req.NamespaceName) + } + } + + return entities.RatelimitNamespace{}, fault.New("missing namespace id or name", + fault.WithTag(fault.BAD_REQUEST), + fault.WithDesc("missing namespace id or name", "You must provide either a namespace ID or name."), + ) + +} diff --git a/go/cmd/api/routes/v2_ratelimit_set_override/happy_test.go b/go/cmd/api/routes/v2_ratelimit_set_override/happy_test.go deleted file mode 100644 index 5be38abd11..0000000000 --- a/go/cmd/api/routes/v2_ratelimit_set_override/happy_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package handler_test - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/unkeyed/unkey/apps/agent/pkg/util" - handler "github.com/unkeyed/unkey/go/cmd/api/routes/v2_ratelimit_set_override" - "github.com/unkeyed/unkey/go/pkg/database" - "github.com/unkeyed/unkey/go/pkg/entities" - "github.com/unkeyed/unkey/go/pkg/logging" - "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/uid" -) - -func TestCreateNewOverride(t *testing.T) { - ctx := context.Background() - h := testutil.NewHarness(t) - - c := testutil.NewContainers(t) - - mysqlAddr := 
c.RunMySQL() - - db, err := database.New(database.Config{ - PrimaryDSN: mysqlAddr, - ReadOnlyDSN: "", - Logger: logging.NewNoop(), - }) - require.NoError(t, err) - - db.InsertRatelimitOverride(ctx, entities.RatelimitOverride{ - ID: uid.Test(), - WorkspaceID: uid.Test(), - NamespaceID: uid.Test(), - Identifier: "test", - Limit: 10, - Duration: 0, - Async: false, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - DeletedAt: time.Time{}, - }) - - route := handler.New(handler.Services{ - DB: db, - Keys: nil, - }) - - h.Register(route) - - req := handler.Request{ - NamespaceId: util.Pointer(""), - NamespaceName: nil, - Identifier: "", - Limit: 10, - Duration: 1000, - } - res := testutil.CallRoute[handler.Request, handler.Response](h, route, nil, req) - - require.Equal(t, 200, res.Status) - require.NotEqual(t, "", res.Body.OverrideId) -} diff --git a/go/config.docker.json b/go/config.docker.json index 4ed45bd71d..8b565e0ab7 100644 --- a/go/config.docker.json +++ b/go/config.docker.json @@ -10,9 +10,10 @@ "cluster": { "rpcPort": "${RPC_PORT}", "gossipPort": "${GOSSIP_PORT}", + "advertiseAddr": "${HOSTNAME}", "discovery": { - "static": { - "addrs": ["unkey-apiv2-1:${GOSSIP_PORT}"] + "redis": { + "url": "${REDIS_URL}" } } } diff --git a/go/go.mod b/go/go.mod index fe4a74f4bb..86618d8c54 100644 --- a/go/go.mod +++ b/go/go.mod @@ -18,8 +18,10 @@ require ( github.com/maypok86/otter v1.2.4 github.com/oapi-codegen/oapi-codegen/v2 v2.4.1 github.com/ory/dockertest/v3 v3.11.0 + github.com/panjf2000/ants v1.3.0 github.com/pb33f/libopenapi v0.21.2 github.com/pb33f/libopenapi-validator v0.3.0 + github.com/redis/go-redis/v9 v9.6.1 github.com/segmentio/ksuid v1.0.4 github.com/sqlc-dev/sqlc v1.28.0 github.com/stretchr/testify v1.10.0 @@ -45,11 +47,13 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect 
github.com/containerd/continuity v0.4.3 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect github.com/creack/pty v1.1.23 // indirect github.com/cubicdaiya/gonp v1.0.4 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/docker/cli v27.2.0+incompatible // indirect github.com/docker/docker v27.4.1+incompatible // indirect diff --git a/go/go.sum b/go/go.sum index c43d86d6d5..b37c3ceaea 100644 --- a/go/go.sum +++ b/go/go.sum @@ -104,6 +104,10 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= @@ -123,6 +127,8 @@ github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -160,6 +166,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/cli v27.2.0+incompatible h1:yHD1QEB1/0vr5eBNpu8tncu8gWxg8EydFPOSKHzXSMM= @@ -548,6 +556,8 @@ github.com/opencontainers/runc v1.1.13/go.mod h1:R016aXacfp/gwQBYw2FDGa9m+n6atbL github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/ory/dockertest/v3 v3.11.0 h1:OiHcxKAvSDUwsEVh2BjxQQc/5EHz9n0va9awCtNGuyA= github.com/ory/dockertest/v3 v3.11.0/go.mod h1:VIPxS1gwT9NpPOrfD3rACs8Y9Z7yhzO4SB194iUDnUI= +github.com/panjf2000/ants v1.3.0 h1:8pQ+8leaLc9lys2viEEr8md0U4RN6uOSUCE9bOYjQ9M= +github.com/panjf2000/ants v1.3.0/go.mod 
h1:AaACblRPzq35m1g3enqYcxspbbiOJJYaxU2wMpm1cXY= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -603,6 +613,8 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4= +github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/riza-io/grpc-go v0.2.0 h1:2HxQKFVE7VuYstcJ8zqpN84VnAoJ4dCL6YFhJewNcHQ= diff --git a/go/internal/services/keys/interface.go b/go/internal/services/keys/interface.go index 0f58346acb..631e592424 100644 --- a/go/internal/services/keys/interface.go +++ b/go/internal/services/keys/interface.go @@ -2,10 +2,13 @@ package keys import ( "context" + + "github.com/unkeyed/unkey/go/pkg/zen" ) type KeyService interface { Verify(ctx context.Context, hash string) (VerifyResponse, error) + VerifyRootKey(ctx context.Context, sess *zen.Session) (VerifyResponse, error) } type VerifyResponse struct { diff --git a/go/internal/services/keys/service.go b/go/internal/services/keys/service.go index cef6ecd2be..aea9decc35 100644 --- a/go/internal/services/keys/service.go +++ b/go/internal/services/keys/service.go @@ -1,13 +1,7 @@ package keys import ( - "context" - "time" 
- - "github.com/unkeyed/unkey/go/pkg/cache" - "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/database" - "github.com/unkeyed/unkey/go/pkg/entities" "github.com/unkeyed/unkey/go/pkg/logging" ) @@ -19,36 +13,12 @@ type Config struct { type service struct { logger logging.Logger db database.Database - cache cache.Cache[entities.Key] } func New(config Config) (*service, error) { - keyCache, err := cache.New[entities.Key](cache.Config[entities.Key]{ - Fresh: 10 * time.Second, - Stale: 1 * time.Minute, - RefreshFromOrigin: func(ctx context.Context, hash string) (entities.Key, bool) { - key, err := config.DB.FindKeyByHash(ctx, hash) - if err != nil { - config.Logger.Error(ctx, "failed to fetch key by hash") - // nolint:exhaustruct - return entities.Key{}, false - } - return key, true - }, - - Logger: config.Logger, - MaxSize: 1_000_000, - Resource: "keys", - Clock: clock.New(), - }) - if err != nil { - return nil, err - } - return &service{ logger: config.Logger, db: config.DB, - cache: keyCache, }, nil } diff --git a/go/internal/services/keys/verify.go b/go/internal/services/keys/verify.go index 806d6cdd18..7d545212ec 100644 --- a/go/internal/services/keys/verify.go +++ b/go/internal/services/keys/verify.go @@ -2,24 +2,34 @@ package keys import ( "context" + "errors" "github.com/unkeyed/unkey/go/pkg/assert" + "github.com/unkeyed/unkey/go/pkg/database" "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/hash" ) -func (s *service) Verify(ctx context.Context, hash string) (VerifyResponse, error) { +func (s *service) Verify(ctx context.Context, rawKey string) (VerifyResponse, error) { - err := assert.NotEmpty(hash) + err := assert.NotEmpty(rawKey) if err != nil { - return VerifyResponse{}, fault.Wrap(err, fault.WithDesc("hash is empty", "")) + return VerifyResponse{}, fault.Wrap(err, fault.WithDesc("rawKey is empty", "")) } + h := hash.Sha256(rawKey) - key, found := s.cache.SWR(ctx, hash) - if !found { - return 
VerifyResponse{}, fault.New( - "key does not exist", - fault.WithTag(fault.NOT_FOUND), - fault.WithDesc("key does not exist", "We could not find the requested key."), + key, err := s.db.FindKeyByHash(ctx, h) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return VerifyResponse{}, fault.Wrap( + err, + fault.WithTag(fault.NOT_FOUND), + fault.WithDesc("key does not exist", "We could not find the requested key."), + ) + } + return VerifyResponse{}, fault.Wrap( + err, + fault.WithDesc("unable to load key", "We could not load the requested key."), ) } @@ -28,7 +38,7 @@ func (s *service) Verify(ctx context.Context, hash string) (VerifyResponse, erro // - Is it expired? // - Is it ratelimited? - if key.DeletedAt.IsZero() { + if !key.DeletedAt.IsZero() { return VerifyResponse{}, fault.New( "key is deleted", fault.WithDesc("deleted_at is non-zero", "The key has been deleted."), diff --git a/go/internal/services/keys/verify_root_key.go b/go/internal/services/keys/verify_root_key.go new file mode 100644 index 0000000000..e73bf022a1 --- /dev/null +++ b/go/internal/services/keys/verify_root_key.go @@ -0,0 +1,32 @@ +package keys + +import ( + "context" + + "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/zen" +) + +func (s *service) VerifyRootKey(ctx context.Context, sess *zen.Session) (VerifyResponse, error) { + + rootKey, err := zen.Bearer(sess) + if err != nil { + return VerifyResponse{}, fault.Wrap(err, + fault.WithTag(fault.UNAUTHORIZED), + fault.WithDesc( + "no bearer", + "You must provide a valid root key in the Authorization header in the format 'Bearer '", + ), + ) + } + + res, err := s.Verify(ctx, rootKey) + if err != nil { + return VerifyResponse{}, fault.Wrap(err, + fault.WithTag(fault.UNAUTHORIZED), + fault.WithDesc("invalid root key", "The provided root key is invalid")) + } + + return res, nil + +} diff --git a/go/pkg/cache/cache.go b/go/pkg/cache/cache.go index 838513ffe8..7a7a05288a 100644 --- a/go/pkg/cache/cache.go +++ 
b/go/pkg/cache/cache.go @@ -5,29 +5,33 @@ import ( "encoding/json" "fmt" "log/slog" + "sync" "time" "github.com/maypok86/otter" + "github.com/panjf2000/ants" "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/fault" "github.com/unkeyed/unkey/go/pkg/logging" - "github.com/unkeyed/unkey/go/pkg/tracing" - "go.opentelemetry.io/otel/attribute" ) -type cache[T any] struct { - otter otter.Cache[string, swrEntry[T]] - fresh time.Duration - stale time.Duration - refreshFromOrigin func(ctx context.Context, identifier string) (data T, ok bool) - // If a key is stale, its identifier will be put into this channel and a goroutine refreshes it in the background - refreshC chan string +type cache[K comparable, V any] struct { + otter otter.Cache[K, swrEntry[V]] + fresh time.Duration + stale time.Duration + // If a key is stale, its key will be put into this channel and a goroutine refreshes it in the background + refreshC chan K logger logging.Logger resource string clock clock.Clock + + inflightMu sync.Mutex + inflightRefreshes map[K]bool + + pool *ants.Pool } -type Config[T any] struct { +type Config[K comparable, V any] struct { // How long the data is considered fresh // Subsequent requests in this time will try to use the cache Fresh time.Duration @@ -36,9 +40,6 @@ type Config[T any] struct { // fetching from the origin server Stale time.Duration - // A handler that will be called to refetch data from the origin when necessary - RefreshFromOrigin func(ctx context.Context, identifier string) (data T, ok bool) - Logger logging.Logger // Start evicting the least recently used entry when the cache grows to MaxSize @@ -49,77 +50,78 @@ type Config[T any] struct { Clock clock.Clock } -func New[T any](config Config[T]) (*cache[T], error) { +var _ Cache[any, any] = (*cache[any, any])(nil) + +// New creates a new cache instance +func New[K comparable, V any](config Config[K, V]) *cache[K, V] { - builder, err := otter.NewBuilder[string, swrEntry[T]](config.MaxSize) 
+ builder, err := otter.NewBuilder[K, swrEntry[V]](config.MaxSize) if err != nil { - return nil, fault.Wrap(err, fault.WithDesc("failed to create otter builder", "")) + panic(err) } - otter, err := builder.CollectStats().Cost(func(key string, value swrEntry[T]) uint32 { + otter, err := builder.CollectStats().Cost(func(key K, value swrEntry[V]) uint32 { return 1 }).WithTTL(time.Hour).Build() if err != nil { - return nil, fault.Wrap(err, fault.WithDesc("failed to create otter cache", "")) + panic(err) + } + + pool, err := ants.NewPool(10) + if err != nil { + panic(err) } - c := &cache[T]{ + c := &cache[K, V]{ otter: otter, fresh: config.Fresh, stale: config.Stale, - refreshFromOrigin: config.RefreshFromOrigin, - refreshC: make(chan string, 1000), + refreshC: make(chan K, 1000), logger: config.Logger, resource: config.Resource, clock: config.Clock, + pool: pool, + inflightMu: sync.Mutex{}, + inflightRefreshes: make(map[K]bool), } - go c.runRefreshing() - - return c, nil + return c } -func (c cache[T]) Get(ctx context.Context, key string) (value T, hit CacheHit) { +func (c *cache[K, V]) Get(ctx context.Context, key K) (value V, hit CacheHit) { e, ok := c.otter.Get(key) if !ok { - // This hack is necessary because you can not return nil as T - var t T - return t, Miss + // This hack is necessary because you can not return nil as V + var v V + return v, Miss } now := c.clock.Now() - if now.Before(e.Fresh) { - - return e.Value, e.Hit - - } if now.Before(e.Stale) { - c.refreshC <- key - return e.Value, e.Hit } c.otter.Delete(key) - var t T - return t, Miss + var v V + return v, Miss } -func (c cache[T]) SetNull(ctx context.Context, key string) { +func (c *cache[K, V]) SetNull(ctx context.Context, key K) { c.set(ctx, key) } -func (c cache[T]) Set(ctx context.Context, key string, value T) { +func (c *cache[K, V]) Set(ctx context.Context, key K, value V) { c.set(ctx, key, value) } -func (c cache[T]) set(_ context.Context, key string, value ...T) { +func (c *cache[K, V]) set(_ 
context.Context, key K, value ...V) { now := c.clock.Now() - e := swrEntry[T]{ + e := swrEntry[V]{ Value: value[0], Fresh: now.Add(c.fresh), Stale: now.Add(c.stale), @@ -135,16 +137,16 @@ func (c cache[T]) set(_ context.Context, key string, value ...T) { } -func (c cache[T]) Remove(ctx context.Context, key string) { +func (c *cache[K, V]) Remove(ctx context.Context, key K) { c.otter.Delete(key) } -func (c cache[T]) Dump(ctx context.Context) ([]byte, error) { - data := make(map[string]swrEntry[T]) +func (c *cache[K, V]) Dump(ctx context.Context) ([]byte, error) { + data := make(map[K]swrEntry[V]) - c.otter.Range(func(key string, entry swrEntry[T]) bool { + c.otter.Range(func(key K, entry swrEntry[V]) bool { data[key] = entry return true }) @@ -158,9 +160,9 @@ func (c cache[T]) Dump(ctx context.Context) ([]byte, error) { } -func (c cache[T]) Restore(ctx context.Context, b []byte) error { +func (c *cache[K, V]) Restore(ctx context.Context, b []byte) error { - data := make(map[string]swrEntry[T]) + data := make(map[K]swrEntry[V]) err := json.Unmarshal(b, &data) if err != nil { return fmt.Errorf("failed to unmarshal cache data: %w", err) @@ -177,47 +179,93 @@ func (c cache[T]) Restore(ctx context.Context, b []byte) error { return nil } -func (c cache[T]) Clear(ctx context.Context) { +func (c *cache[K, V]) Clear(ctx context.Context) { c.otter.Clear() } -func (c cache[T]) runRefreshing() { - for { - ctx := context.Background() - identifier := <-c.refreshC - - ctx, span := tracing.Start(ctx, tracing.NewSpanName(fmt.Sprintf("cache.%s", c.resource), "refresh")) - span.SetAttributes(attribute.String("identifier", identifier)) - t, ok := c.refreshFromOrigin(ctx, identifier) - if !ok { - span.AddEvent("identifier not found in origin") - c.logger.Warn(ctx, "origin couldn't find data", slog.String("identifier", identifier)) - span.End() - continue - } - c.Set(ctx, identifier, t) - span.End() +func (c *cache[K, V]) refresh( + ctx context.Context, + key K, refreshFromOrigin 
func(context.Context) (V, error), + translateError func(error) CacheHit, +) { + c.inflightMu.Lock() + _, ok := c.inflightRefreshes[key] + if ok { + c.inflightMu.Unlock() + return + } + c.inflightRefreshes[key] = true + c.inflightMu.Unlock() + + defer func() { + c.inflightMu.Lock() + delete(c.inflightRefreshes, key) + c.inflightMu.Unlock() + }() + + v, err := refreshFromOrigin(ctx) + + switch translateError(err) { + case Hit: + c.set(ctx, key, v) + case Miss: + c.set(ctx, key) + case Null: + c.set(ctx, key) } } -func (c cache[T]) SWR(ctx context.Context, identifier string) (T, bool) { +func (c *cache[K, V]) SWR( + ctx context.Context, + key K, + refreshFromOrigin func(context.Context) (V, error), + translateError func(error) CacheHit, +) (V, error) { + now := c.clock.Now() + e, ok := c.otter.Get(key) + if ok { + // Cache Hit - value, hit := c.Get(ctx, identifier) + if now.Before(e.Fresh) { + // We have data and it's fresh, so we return it + + return e.Value, nil + } + + if now.Before(e.Stale) { + // We have data, but it's stale, so we refresh it in the background + // but return the current value + + err := c.pool.Submit(func() { + c.refresh(ctx, key, refreshFromOrigin, translateError) + }) + if err != nil { + c.logger.Error(ctx, "failed to submit refresh task", slog.String("error", err.Error())) + } + + return e.Value, nil + } + + // We have old data, that we should not serve anymore + c.otter.Delete(key) - if hit == Hit { - return value, true - } - if hit == Null { - return value, false } + // Cache Miss - value, found := c.refreshFromOrigin(ctx, identifier) - if found { - c.Set(ctx, identifier, value) - return value, true + // We have no data and need to go to the origin + + v, err := refreshFromOrigin(ctx) + + switch translateError(err) { + case Hit: + c.set(ctx, key, v) + case Miss: + c.set(ctx, key) + case Null: + c.set(ctx, key) } - c.SetNull(ctx, identifier) - return value, false + + return v, err } diff --git a/go/pkg/cache/cache_test.go 
b/go/pkg/cache/cache_test.go index 342bf8752c..38d9579eb8 100644 --- a/go/pkg/cache/cache_test.go +++ b/go/pkg/cache/cache_test.go @@ -15,18 +15,14 @@ import ( func TestWriteRead(t *testing.T) { - c, err := cache.New[string](cache.Config[string]{ + c := cache.New[string, string](cache.Config[string, string]{ MaxSize: 10_000, - Fresh: time.Minute, - Stale: time.Minute * 5, - RefreshFromOrigin: func(ctx context.Context, id string) (string, bool) { - return "hello", true - }, + Fresh: time.Minute, + Stale: time.Minute * 5, Logger: logging.NewNoop(), Resource: "test", Clock: clock.New(), }) - require.NoError(t, err) c.Set(context.Background(), "key", "value") value, hit := c.Get(context.Background(), "key") require.Equal(t, cache.Hit, hit) @@ -36,19 +32,15 @@ func TestWriteRead(t *testing.T) { func TestEviction(t *testing.T) { clk := clock.NewTestClock() - c, err := cache.New[string](cache.Config[string]{ + c := cache.New[string, string](cache.Config[string, string]{ MaxSize: 10_000, - Fresh: time.Second, - Stale: time.Second, - RefreshFromOrigin: func(ctx context.Context, id string) (string, bool) { - return "hello", true - }, + Fresh: time.Second, + Stale: time.Second, Logger: logging.NewNoop(), Resource: "test", Clock: clk, }) - require.NoError(t, err) c.Set(context.Background(), "key", "value") clk.Tick(2 * time.Second) @@ -63,22 +55,15 @@ func TestRefresh(t *testing.T) { // count how many times we refreshed from origin refreshedFromOrigin := atomic.Int32{} - c, err := cache.New[string](cache.Config[string]{ + c := cache.New[string, string](cache.Config[string, string]{ MaxSize: 10_000, - Fresh: time.Second * 2, - Stale: time.Minute * 5, - RefreshFromOrigin: func(ctx context.Context, id string) (string, bool) { - refreshedFromOrigin.Add(1) - - t.Log("called", id, clk.Now()) - return "hello", true - }, + Fresh: time.Second * 2, + Stale: time.Minute * 5, Logger: logging.NewNoop(), Resource: "test", Clock: clk, }) - require.NoError(t, err) c.Set(context.Background(), 
"key", "value") clk.Tick(time.Second) @@ -95,16 +80,14 @@ func TestRefresh(t *testing.T) { func TestNull(t *testing.T) { t.Skip() - c, err := cache.New[string](cache.Config[string]{ - MaxSize: 10_000, - Fresh: time.Second * 1, - Stale: time.Minute * 5, - Logger: logging.NewNoop(), - RefreshFromOrigin: nil, - Resource: "test", - Clock: clock.New(), + c := cache.New[string, string](cache.Config[string, string]{ + MaxSize: 10_000, + Fresh: time.Second * 1, + Stale: time.Minute * 5, + Logger: logging.NewNoop(), + Resource: "test", + Clock: clock.New(), }) - require.NoError(t, err) c.SetNull(context.Background(), "key") diff --git a/go/pkg/cache/interface.go b/go/pkg/cache/interface.go index 14474732d1..d3193da40a 100644 --- a/go/pkg/cache/interface.go +++ b/go/pkg/cache/interface.go @@ -4,21 +4,21 @@ import ( "context" ) -type Cache[T any] interface { +type Cache[K comparable, V any] interface { // Get returns the value for the given key. // If the key is not found, found will be false. - Get(ctx context.Context, key string) (value T, hit CacheHit) + Get(ctx context.Context, key K) (value V, hit CacheHit) // Sets the value for the given key. - Set(ctx context.Context, key string, value T) + Set(ctx context.Context, key K, value V) // Sets the given key to null, indicating that the value does not exist in the origin. - SetNull(ctx context.Context, key string) + SetNull(ctx context.Context, key K) // Removes the key from the cache. - Remove(ctx context.Context, key string) + Remove(ctx context.Context, key K) - SWR(ctx context.Context, key string) (value T, found bool) + SWR(ctx context.Context, key K, refreshFromOrigin func(ctx context.Context) (V, error), translateError func(error) CacheHit) (value V, err error) // Dump returns a serialized representation of the cache. 
Dump(ctx context.Context) ([]byte, error) @@ -30,6 +30,10 @@ type Cache[T any] interface { Clear(ctx context.Context) } +type Key interface { + ToString() string +} + type CacheHit int const ( diff --git a/go/pkg/cache/middleware.go b/go/pkg/cache/middleware.go index 971d7b1c1d..2152a44518 100644 --- a/go/pkg/cache/middleware.go +++ b/go/pkg/cache/middleware.go @@ -1,3 +1,3 @@ package cache -type Middleware[T any] func(Cache[T]) Cache[T] +type Middleware[K comparable, V any] func(Cache[K, V]) Cache[K, V] diff --git a/go/pkg/cache/middleware/tracing.go b/go/pkg/cache/middleware/tracing.go index d6cbe6be73..3ba79b9e89 100644 --- a/go/pkg/cache/middleware/tracing.go +++ b/go/pkg/cache/middleware/tracing.go @@ -2,24 +2,25 @@ package middleware import ( "context" + "fmt" "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/tracing" "go.opentelemetry.io/otel/attribute" ) -type tracingMiddleware[T any] struct { - next cache.Cache[T] +type tracingMiddleware[K comparable, V any] struct { + next cache.Cache[K, V] } -func WithTracing[T any](c cache.Cache[T]) cache.Cache[T] { - return &tracingMiddleware[T]{next: c} +func WithTracing[K comparable, V any](c cache.Cache[K, V]) cache.Cache[K, V] { + return &tracingMiddleware[K, V]{next: c} } -func (mw *tracingMiddleware[T]) Get(ctx context.Context, key string) (T, cache.CacheHit) { +func (mw *tracingMiddleware[K, V]) Get(ctx context.Context, key K) (V, cache.CacheHit) { ctx, span := tracing.Start(ctx, "cache.Get") defer span.End() - span.SetAttributes(attribute.String("key", key)) + span.SetAttributes(attribute.String("key", fmt.Sprintf("%+v", key))) value, hit := mw.next.Get(ctx, key) span.SetAttributes( @@ -27,32 +28,32 @@ func (mw *tracingMiddleware[T]) Get(ctx context.Context, key string) (T, cache.C ) return value, hit } -func (mw *tracingMiddleware[T]) Set(ctx context.Context, key string, value T) { +func (mw *tracingMiddleware[K, V]) Set(ctx context.Context, key K, value V) { ctx, span := 
tracing.Start(ctx, "cache.Set") defer span.End() - span.SetAttributes(attribute.String("key", key)) + span.SetAttributes(attribute.String("key", fmt.Sprintf("%+v", key))) mw.next.Set(ctx, key, value) } -func (mw *tracingMiddleware[T]) SetNull(ctx context.Context, key string) { +func (mw *tracingMiddleware[K, V]) SetNull(ctx context.Context, key K) { ctx, span := tracing.Start(ctx, "cache.SetNull") defer span.End() - span.SetAttributes(attribute.String("key", key)) + span.SetAttributes(attribute.String("key", fmt.Sprintf("%+v", key))) mw.next.SetNull(ctx, key) } -func (mw *tracingMiddleware[T]) Remove(ctx context.Context, key string) { +func (mw *tracingMiddleware[K, V]) Remove(ctx context.Context, key K) { ctx, span := tracing.Start(ctx, "cache.Remove") defer span.End() - span.SetAttributes(attribute.String("key", key)) + span.SetAttributes(attribute.String("key", fmt.Sprintf("%+v", key))) mw.next.Remove(ctx, key) } -func (mw *tracingMiddleware[T]) Dump(ctx context.Context) ([]byte, error) { +func (mw *tracingMiddleware[K, V]) Dump(ctx context.Context) ([]byte, error) { ctx, span := tracing.Start(ctx, "cache.Dump") defer span.End() @@ -64,7 +65,7 @@ func (mw *tracingMiddleware[T]) Dump(ctx context.Context) ([]byte, error) { return b, err } -func (mw *tracingMiddleware[T]) Restore(ctx context.Context, data []byte) error { +func (mw *tracingMiddleware[K, V]) Restore(ctx context.Context, data []byte) error { ctx, span := tracing.Start(ctx, "cache.Restore") defer span.End() @@ -76,20 +77,22 @@ func (mw *tracingMiddleware[T]) Restore(ctx context.Context, data []byte) error return err } -func (mw *tracingMiddleware[T]) Clear(ctx context.Context) { +func (mw *tracingMiddleware[K, V]) Clear(ctx context.Context) { ctx, span := tracing.Start(ctx, "cache.Clear") defer span.End() mw.next.Clear(ctx) } -func (mw *tracingMiddleware[T]) SWR(ctx context.Context, key string) (T, bool) { +func (mw *tracingMiddleware[K, V]) SWR(ctx context.Context, key K, refreshFromOrigin func(ctx 
context.Context) (V, error), translateError func(err error) cache.CacheHit) (V, error) { ctx, span := tracing.Start(ctx, "cache.SWR") defer span.End() - span.SetAttributes(attribute.String("key", key)) + span.SetAttributes(attribute.String("key", fmt.Sprintf("%+v", key))) - value, found := mw.next.SWR(ctx, key) - span.SetAttributes(attribute.Bool("found", found)) - return value, found + value, err := mw.next.SWR(ctx, key, refreshFromOrigin, translateError) + if err != nil { + tracing.RecordError(span, err) + } + return value, err } diff --git a/go/pkg/cache/noop.go b/go/pkg/cache/noop.go index e5b8d18cb2..af0909a1ab 100644 --- a/go/pkg/cache/noop.go +++ b/go/pkg/cache/noop.go @@ -4,29 +4,29 @@ import ( "context" ) -type noopCache[T any] struct{} +type noopCache[K comparable, V any] struct{} -func (c *noopCache[T]) Get(ctx context.Context, key string) (value T, hit CacheHit) { - var t T - return t, Miss +func (c *noopCache[K, V]) Get(ctx context.Context, key K) (value V, hit CacheHit) { + var v V + return v, Miss } -func (c *noopCache[T]) Set(ctx context.Context, key string, value T) {} -func (c *noopCache[T]) SetNull(ctx context.Context, key string) {} +func (c *noopCache[K, V]) Set(ctx context.Context, key K, value V) {} +func (c *noopCache[K, V]) SetNull(ctx context.Context, key K) {} -func (c *noopCache[T]) Remove(ctx context.Context, key string) {} +func (c *noopCache[K, V]) Remove(ctx context.Context, key K) {} -func (c *noopCache[T]) Dump(ctx context.Context) ([]byte, error) { +func (c *noopCache[K, V]) Dump(ctx context.Context) ([]byte, error) { return []byte{}, nil } -func (c *noopCache[T]) Restore(ctx context.Context, data []byte) error { +func (c *noopCache[K, V]) Restore(ctx context.Context, data []byte) error { return nil } -func (c *noopCache[T]) Clear(ctx context.Context) {} -func (c *noopCache[T]) SWR(ctx context.Context, key string) (T, bool) { - var t T - return t, false +func (c *noopCache[K, V]) Clear(ctx context.Context) {} +func (c 
*noopCache[K, V]) SWR(ctx context.Context, key K, refreshFromOrigin func(context.Context) (V, error), translateError func(err error) CacheHit) (V, error) { + var v V + return v, nil } -func NewNoopCache[T any]() Cache[T] { - return &noopCache[T]{} +func NewNoopCache[K comparable, V any]() Cache[K, V] { + return &noopCache[K, V]{} } diff --git a/go/pkg/cache/simulation_test.go b/go/pkg/cache/simulation_test.go new file mode 100644 index 0000000000..8ee03fcf76 --- /dev/null +++ b/go/pkg/cache/simulation_test.go @@ -0,0 +1,146 @@ +package cache_test + +import ( + "context" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/pkg/cache" + "github.com/unkeyed/unkey/go/pkg/clock" + "github.com/unkeyed/unkey/go/pkg/logging" + "github.com/unkeyed/unkey/go/pkg/sim" +) + +type state struct { + keys []uint64 + cache cache.Cache[uint64, uint64] + clk *clock.TestClock +} + +type setEvent struct{} + +func (e *setEvent) Name() string { + return "set" +} + +func (e *setEvent) Run(rng *rand.Rand, s *state) error { + + key := rng.Uint64() + val := rng.Uint64() + + s.keys = append(s.keys, key) + s.cache.Set(context.Background(), key, val) + + return nil +} + +type getEvent struct{} + +func (e *getEvent) Name() string { + return "get" +} + +func (e *getEvent) Run(rng *rand.Rand, s *state) error { + + stored := len(s.keys) + index := rng.Intn(stored + 1) + + key := rng.Uint64() + + if index < stored { + key = s.keys[index] + } + + s.cache.Get(context.Background(), key) + + return nil +} + +type removeEvent struct{} + +func (e *removeEvent) Name() string { + return "remove" +} + +func (e *removeEvent) Run(rng *rand.Rand, s *state) error { + + stored := len(s.keys) + index := rng.Intn(stored + 1) + + key := rng.Uint64() + + if index < stored { + key = s.keys[index] + } + + s.cache.Remove(context.Background(), key) + + return nil +} + +type advanceTimeEvent struct { + clk *clock.TestClock +} + +func (e *advanceTimeEvent) 
Name() string { + return "advanceTime" +} + +func (e *advanceTimeEvent) Run(rng *rand.Rand, s *state) error { + nanoseconds := rng.Int63n(60 * 60 * 1000 * 1000) // up to ~3.6s — NOTE(review): value is in ns; for "up to 1h" use rng.Int63n(int64(time.Hour)) + + e.clk.Tick(time.Duration(nanoseconds) * time.Nanosecond) + + return nil +} + +func TestSimulation(t *testing.T) { + + for i := range 100 { + seed := time.Now().UnixNano() + rand.Int63() + t.Run(fmt.Sprintf("run=%d,seed=%d", i, seed), func(t *testing.T) { + + clk := clock.NewTestClock() + + s := sim.New[state](t, + sim.WithSeed[state](seed), + sim.WithSteps[state](1000000), + sim.WithState(func(rng *rand.Rand) *state { + minTime := 1738364400000 // ~2025-02-01 UTC — TODO confirm (comment previously said 2025-01-01) + maxTime := 2527282800000 // ~2050-02-01 UTC — TODO confirm (comment previously said 2050-01-01) + unixMilli := minTime + rng.Intn(maxTime-minTime) + clk.Set(time.UnixMilli(int64(unixMilli))) + + fresh := time.Second + time.Duration(rng.Intn(60*60*1000)) // NOTE(review): Duration is ns, so this adds <=3.6ms; multiply by time.Millisecond if ms intended + stale := fresh + time.Duration(rng.Intn(24*60*60*1000)) // NOTE(review): same ns-vs-ms concern as fresh above + + c := cache.New[uint64, uint64](cache.Config[uint64, uint64]{ + Clock: clk, + Fresh: fresh, + Stale: stale, + Logger: logging.NewNoop(), + MaxSize: rng.Intn(1_000_000), + Resource: "test", + }) + + return &state{ + keys: []uint64{}, + cache: c, + clk: clk, + } + })) + + s.Run([]sim.Event[state]{ + &setEvent{}, + &getEvent{}, + &removeEvent{}, + &advanceTimeEvent{clk}, + }) + + require.Len(t, s.Errors, 0, "expected no errors") + }) + } +} diff --git a/go/pkg/cache/util.go b/go/pkg/cache/util.go deleted file mode 100644 index 3f703d6755..0000000000 --- a/go/pkg/cache/util.go +++ /dev/null @@ -1,33 +0,0 @@ -package cache - -import ( - "context" -) - -// withCache builds a pullthrough cache function to wrap a database call. 
-// Example: -// api, found, err := withCache(s.apiCache, s.db.FindApiByKeyAuthId)(ctx, key.KeyAuthId) -func WithCache[T any](c Cache[T], loadFromDatabase func(ctx context.Context, identifier string) (T, bool, error)) func(ctx context.Context, identifier string) (T, bool, error) { - return func(ctx context.Context, identifier string) (T, bool, error) { - value, hit := c.Get(ctx, identifier) - - if hit == Hit { - return value, true, nil - } - if hit == Null { - return value, false, nil - } - - value, found, err := loadFromDatabase(ctx, identifier) - if err != nil { - return value, false, err - } - if found { - c.Set(ctx, identifier, value) - return value, true, nil - } else { - c.SetNull(ctx, identifier) - return value, false, nil - } - } -} diff --git a/go/pkg/database/database.go b/go/pkg/database/database.go index 208e1a2755..96ce6ca3b9 100644 --- a/go/pkg/database/database.go +++ b/go/pkg/database/database.go @@ -2,9 +2,11 @@ package database import ( "database/sql" + "strings" _ "github.com/go-sql-driver/mysql" + "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/database/gen" "github.com/unkeyed/unkey/go/pkg/fault" "github.com/unkeyed/unkey/go/pkg/logging" @@ -19,6 +21,8 @@ type Config struct { ReadOnlyDSN string Logger logging.Logger + + Clock clock.Clock } type replica struct { @@ -30,10 +34,19 @@ type database struct { writeReplica *replica readReplica *replica logger logging.Logger + clock clock.Clock } func New(config Config, middlewares ...Middleware) (Database, error) { + if config.Clock == nil { + config.Clock = clock.New() + } + + if !strings.Contains(config.PrimaryDSN, "parseTime=true") { + return nil, fault.New("PrimaryDSN must contain parseTime=true, see https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013") + } + write, err := sql.Open("mysql", config.PrimaryDSN) if err != nil { return nil, fault.Wrap(err, fault.WithDesc("cannot open primary replica", "")) @@ -48,6 +61,9 @@ func 
New(config Config, middlewares ...Middleware) (Database, error) { query: gen.New(write), } if config.ReadOnlyDSN != "" { + if !strings.Contains(config.ReadOnlyDSN, "parseTime=true") { + return nil, fault.New("ReadOnlyDSN must contain parseTime=true, see https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013") + } read, err := sql.Open("mysql", config.ReadOnlyDSN) if err != nil { return nil, fault.Wrap(err, fault.WithDesc("cannot open read replica", "")) @@ -63,6 +79,7 @@ func New(config Config, middlewares ...Middleware) (Database, error) { writeReplica: writeReplica, readReplica: readReplica, logger: config.Logger, + clock: config.Clock, } for _, mw := range middlewares { diff --git a/go/pkg/database/gen/key_find_by_hash.sql.go b/go/pkg/database/gen/key_find_by_hash.sql.go index caaf243587..12fd0eb53a 100644 --- a/go/pkg/database/gen/key_find_by_hash.sql.go +++ b/go/pkg/database/gen/key_find_by_hash.sql.go @@ -9,17 +9,21 @@ import ( "context" ) -const findKeyByID = `-- name: FindKeyByID :one -SELECT id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment FROM ` + "`" + `keys` + "`" + ` -WHERE id = ? +const findKeyByHash = `-- name: FindKeyByHash :one +SELECT + id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment +FROM ` + "`" + `keys` + "`" + ` +WHERE hash = ? 
` -// FindKeyByID +// FindKeyByHash // -// SELECT id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment FROM `keys` -// WHERE id = ? -func (q *Queries) FindKeyByID(ctx context.Context, id string) (Key, error) { - row := q.db.QueryRowContext(ctx, findKeyByID, id) +// SELECT +// id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment +// FROM `keys` +// WHERE hash = ? +func (q *Queries) FindKeyByHash(ctx context.Context, hash string) (Key, error) { + row := q.db.QueryRowContext(ctx, findKeyByHash, hash) var i Key err := row.Scan( &i.ID, diff --git a/go/pkg/database/gen/key_find_by_id.sql.go b/go/pkg/database/gen/key_find_by_id.sql.go index 918c8a8bc5..5e978f0630 100644 --- a/go/pkg/database/gen/key_find_by_id.sql.go +++ b/go/pkg/database/gen/key_find_by_id.sql.go @@ -9,30 +9,64 @@ import ( "context" ) -const findRatelimitOverrideByIdentifier = `-- name: FindRatelimitOverrideByIdentifier :one -SELECT id, workspace_id, namespace_id, identifier, ` + "`" + `limit` + "`" + `, duration, async, sharding, created_at, updated_at, deleted_at FROM ` + "`" + `ratelimit_overrides` + "`" + ` -WHERE identifier = ? 
+const findKeyByID = `-- name: FindKeyByID :one +SELECT + k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.created_at, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.deleted_at, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, + i.id, i.external_id, i.workspace_id, i.environment, i.created_at, i.updated_at, i.meta +FROM ` + "`" + `keys` + "`" + ` k +LEFT JOIN identities i ON k.identity_id = i.id +WHERE k.id = ? ` -// FindRatelimitOverrideByIdentifier +type FindKeyByIDRow struct { + Key Key `db:"key"` + Identity Identity `db:"identity"` +} + +// FindKeyByID // -// SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at, updated_at, deleted_at FROM `ratelimit_overrides` -// WHERE identifier = ? -func (q *Queries) FindRatelimitOverrideByIdentifier(ctx context.Context, identifier string) (RatelimitOverride, error) { - row := q.db.QueryRowContext(ctx, findRatelimitOverrideByIdentifier, identifier) - var i RatelimitOverride +// SELECT +// k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.created_at, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.deleted_at, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, +// i.id, i.external_id, i.workspace_id, i.environment, i.created_at, i.updated_at, i.meta +// FROM `keys` k +// LEFT JOIN identities i ON k.identity_id = i.id +// WHERE k.id = ? 
+func (q *Queries) FindKeyByID(ctx context.Context, id string) (FindKeyByIDRow, error) { + row := q.db.QueryRowContext(ctx, findKeyByID, id) + var i FindKeyByIDRow err := row.Scan( - &i.ID, - &i.WorkspaceID, - &i.NamespaceID, - &i.Identifier, - &i.Limit, - &i.Duration, - &i.Async, - &i.Sharding, - &i.CreatedAt, - &i.UpdatedAt, - &i.DeletedAt, + &i.Key.ID, + &i.Key.KeyAuthID, + &i.Key.Hash, + &i.Key.Start, + &i.Key.WorkspaceID, + &i.Key.ForWorkspaceID, + &i.Key.Name, + &i.Key.OwnerID, + &i.Key.IdentityID, + &i.Key.Meta, + &i.Key.CreatedAt, + &i.Key.Expires, + &i.Key.CreatedAtM, + &i.Key.UpdatedAtM, + &i.Key.DeletedAtM, + &i.Key.DeletedAt, + &i.Key.RefillDay, + &i.Key.RefillAmount, + &i.Key.LastRefillAt, + &i.Key.Enabled, + &i.Key.RemainingRequests, + &i.Key.RatelimitAsync, + &i.Key.RatelimitLimit, + &i.Key.RatelimitDuration, + &i.Key.Environment, + &i.Identity.ID, + &i.Identity.ExternalID, + &i.Identity.WorkspaceID, + &i.Identity.Environment, + &i.Identity.CreatedAt, + &i.Identity.UpdatedAt, + &i.Identity.Meta, ) return i, err } diff --git a/go/pkg/database/gen/key_insert.sql.go b/go/pkg/database/gen/key_insert.sql.go new file mode 100644 index 0000000000..20ebc1d032 --- /dev/null +++ b/go/pkg/database/gen/key_insert.sql.go @@ -0,0 +1,142 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: key_insert.sql + +package gen + +import ( + "context" + "database/sql" + "time" +) + +const insertKey = `-- name: InsertKey :exec +INSERT INTO ` + "`" + `keys` + "`" + ` ( + id, + key_auth_id, + hash, + start, + workspace_id, + for_workspace_id, + name, + owner_id, + identity_id, + meta, + created_at, + expires, + created_at_m, + enabled, + remaining_requests, + ratelimit_async, + ratelimit_limit, + ratelimit_duration, + environment +) VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + null, + ?, + ?, + ?, + ?, + UNIX_TIMESTAMP() * 1000, + ?, + ?, + ?, + ?, + ?, + ? 
+) +` + +type InsertKeyParams struct { + ID string `db:"id"` + KeyringID string `db:"keyring_id"` + Hash string `db:"hash"` + Start string `db:"start"` + WorkspaceID string `db:"workspace_id"` + ForWorkspaceID sql.NullString `db:"for_workspace_id"` + Name sql.NullString `db:"name"` + IdentityID sql.NullString `db:"identity_id"` + Meta sql.NullString `db:"meta"` + CreatedAt time.Time `db:"created_at"` + Expires sql.NullTime `db:"expires"` + Enabled bool `db:"enabled"` + RemainingRequests sql.NullInt32 `db:"remaining_requests"` + RatelimitAsync sql.NullBool `db:"ratelimit_async"` + RatelimitLimit sql.NullInt32 `db:"ratelimit_limit"` + RatelimitDuration sql.NullInt64 `db:"ratelimit_duration"` + Environment sql.NullString `db:"environment"` +} + +// InsertKey +// +// INSERT INTO `keys` ( +// id, +// key_auth_id, +// hash, +// start, +// workspace_id, +// for_workspace_id, +// name, +// owner_id, +// identity_id, +// meta, +// created_at, +// expires, +// created_at_m, +// enabled, +// remaining_requests, +// ratelimit_async, +// ratelimit_limit, +// ratelimit_duration, +// environment +// ) VALUES ( +// ?, +// ?, +// ?, +// ?, +// ?, +// ?, +// ?, +// null, +// ?, +// ?, +// ?, +// ?, +// UNIX_TIMESTAMP() * 1000, +// ?, +// ?, +// ?, +// ?, +// ?, +// ? +// ) +func (q *Queries) InsertKey(ctx context.Context, arg InsertKeyParams) error { + _, err := q.db.ExecContext(ctx, insertKey, + arg.ID, + arg.KeyringID, + arg.Hash, + arg.Start, + arg.WorkspaceID, + arg.ForWorkspaceID, + arg.Name, + arg.IdentityID, + arg.Meta, + arg.CreatedAt, + arg.Expires, + arg.Enabled, + arg.RemainingRequests, + arg.RatelimitAsync, + arg.RatelimitLimit, + arg.RatelimitDuration, + arg.Environment, + ) + return err +} diff --git a/go/pkg/database/gen/keyring_insert.sql.go b/go/pkg/database/gen/keyring_insert.sql.go new file mode 100644 index 0000000000..84a151befd --- /dev/null +++ b/go/pkg/database/gen/keyring_insert.sql.go @@ -0,0 +1,81 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// versions: +// sqlc v1.27.0 +// source: keyring_insert.sql + +package gen + +import ( + "context" + "database/sql" +) + +const insertKeyring = `-- name: InsertKeyring :exec +INSERT INTO ` + "`" + `key_auth` + "`" + ` ( + id, + workspace_id, + created_at, + created_at_m, + store_encrypted_keys, + default_prefix, + default_bytes, + size_approx, + size_last_updated_at +) VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ?, + 0, + 0 +) +` + +type InsertKeyringParams struct { + ID string `db:"id"` + WorkspaceID string `db:"workspace_id"` + CreatedAt sql.NullTime `db:"created_at"` + CreatedAtM int64 `db:"created_at_m"` + StoreEncryptedKeys bool `db:"store_encrypted_keys"` + DefaultPrefix sql.NullString `db:"default_prefix"` + DefaultBytes sql.NullInt32 `db:"default_bytes"` +} + +// InsertKeyring +// +// INSERT INTO `key_auth` ( +// id, +// workspace_id, +// created_at, +// created_at_m, +// store_encrypted_keys, +// default_prefix, +// default_bytes, +// size_approx, +// size_last_updated_at +// ) VALUES ( +// ?, +// ?, +// ?, +// ?, +// ?, +// ?, +// ?, +// 0, +// 0 +// ) +func (q *Queries) InsertKeyring(ctx context.Context, arg InsertKeyringParams) error { + _, err := q.db.ExecContext(ctx, insertKeyring, + arg.ID, + arg.WorkspaceID, + arg.CreatedAt, + arg.CreatedAtM, + arg.StoreEncryptedKeys, + arg.DefaultPrefix, + arg.DefaultBytes, + ) + return err +} diff --git a/go/pkg/database/gen/permissions_by_key_id.sql.go b/go/pkg/database/gen/permissions_by_key_id.sql.go new file mode 100644 index 0000000000..964a33db72 --- /dev/null +++ b/go/pkg/database/gen/permissions_by_key_id.sql.go @@ -0,0 +1,80 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: permissions_by_key_id.sql + +package gen + +import ( + "context" +) + +const findPermissionsForKey = `-- name: FindPermissionsForKey :many +WITH direct_permissions AS ( + SELECT p.name as permission_name + FROM keys_permissions kp + JOIN permissions p ON kp.permission_id = p.id + WHERE kp.key_id = ? 
+), +role_permissions AS ( + SELECT p.name as permission_name + FROM keys_roles kr + JOIN roles_permissions rp ON kr.role_id = rp.role_id + JOIN permissions p ON rp.permission_id = p.id + WHERE kr.key_id = ? +) +SELECT DISTINCT permission_name +FROM ( + SELECT permission_name FROM direct_permissions + UNION ALL + SELECT permission_name FROM role_permissions +) all_permissions +` + +type FindPermissionsForKeyParams struct { + KeyID string `db:"key_id"` +} + +// FindPermissionsForKey +// +// WITH direct_permissions AS ( +// SELECT p.name as permission_name +// FROM keys_permissions kp +// JOIN permissions p ON kp.permission_id = p.id +// WHERE kp.key_id = ? +// ), +// role_permissions AS ( +// SELECT p.name as permission_name +// FROM keys_roles kr +// JOIN roles_permissions rp ON kr.role_id = rp.role_id +// JOIN permissions p ON rp.permission_id = p.id +// WHERE kr.key_id = ? +// ) +// SELECT DISTINCT permission_name +// FROM ( +// SELECT permission_name FROM direct_permissions +// UNION ALL +// SELECT permission_name FROM role_permissions +// ) all_permissions +func (q *Queries) FindPermissionsForKey(ctx context.Context, arg FindPermissionsForKeyParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, findPermissionsForKey, arg.KeyID, arg.KeyID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var permission_name string + if err := rows.Scan(&permission_name); err != nil { + return nil, err + } + items = append(items, permission_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/go/pkg/database/gen/querier.go b/go/pkg/database/gen/querier.go index e415db54e8..5d2c9b3ee1 100644 --- a/go/pkg/database/gen/querier.go +++ b/go/pkg/database/gen/querier.go @@ -13,26 +13,32 @@ type Querier interface { //DeleteRatelimitNamespace // // UPDATE `ratelimit_namespaces` - // SET deleted_at = NOW() + // 
SET deleted_at = ? // WHERE id = ? - DeleteRatelimitNamespace(ctx context.Context, id string) (sql.Result, error) + DeleteRatelimitNamespace(ctx context.Context, arg DeleteRatelimitNamespaceParams) (sql.Result, error) //DeleteRatelimitOverride // // UPDATE `ratelimit_overrides` // SET - // deleted_at = NOW() + // deleted_at = ? // WHERE id = ? - DeleteRatelimitOverride(ctx context.Context, id string) (sql.Result, error) + DeleteRatelimitOverride(ctx context.Context, arg DeleteRatelimitOverrideParams) (sql.Result, error) //FindKeyByHash // - // SELECT id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment FROM `keys` + // SELECT + // id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment + // FROM `keys` // WHERE hash = ? FindKeyByHash(ctx context.Context, hash string) (Key, error) //FindKeyByID // - // SELECT id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment FROM `keys` - // WHERE id = ? 
- FindKeyByID(ctx context.Context, id string) (Key, error) + // SELECT + // k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.created_at, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.deleted_at, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, + // i.id, i.external_id, i.workspace_id, i.environment, i.created_at, i.updated_at, i.meta + // FROM `keys` k + // LEFT JOIN identities i ON k.identity_id = i.id + // WHERE k.id = ? + FindKeyByID(ctx context.Context, id string) (FindKeyByIDRow, error) //FindKeyForVerification // // WITH direct_permissions AS ( @@ -92,6 +98,28 @@ type Querier interface { // WHERE k.hash = ? // GROUP BY k.id FindKeyForVerification(ctx context.Context, hash string) (FindKeyForVerificationRow, error) + //FindPermissionsForKey + // + // WITH direct_permissions AS ( + // SELECT p.name as permission_name + // FROM keys_permissions kp + // JOIN permissions p ON kp.permission_id = p.id + // WHERE kp.key_id = ? + // ), + // role_permissions AS ( + // SELECT p.name as permission_name + // FROM keys_roles kr + // JOIN roles_permissions rp ON kr.role_id = rp.role_id + // JOIN permissions p ON rp.permission_id = p.id + // WHERE kr.key_id = ? + // ) + // SELECT DISTINCT permission_name + // FROM ( + // SELECT permission_name FROM direct_permissions + // UNION ALL + // SELECT permission_name FROM role_permissions + // ) all_permissions + FindPermissionsForKey(ctx context.Context, arg FindPermissionsForKeyParams) ([]string, error) //FindRatelimitNamespaceByID // // SELECT id, workspace_id, name, created_at, updated_at, deleted_at FROM `ratelimit_namespaces` @@ -103,11 +131,21 @@ type Querier interface { // WHERE name = ? // AND workspace_id = ? 
FindRatelimitNamespaceByName(ctx context.Context, arg FindRatelimitNamespaceByNameParams) (RatelimitNamespace, error) - //FindRatelimitOverrideByIdentifier + //FindRatelimitOverridesById + // + // SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at, updated_at, deleted_at FROM ratelimit_overrides + // WHERE + // workspace_id = ? + // AND id = ? + FindRatelimitOverridesById(ctx context.Context, arg FindRatelimitOverridesByIdParams) (RatelimitOverride, error) + //FindRatelimitOverridesByIdentifier // - // SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at, updated_at, deleted_at FROM `ratelimit_overrides` - // WHERE identifier = ? - FindRatelimitOverrideByIdentifier(ctx context.Context, identifier string) (RatelimitOverride, error) + // SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at, updated_at, deleted_at FROM ratelimit_overrides + // WHERE + // workspace_id = ? + // AND namespace_id = ? + // AND identifier LIKE ? + FindRatelimitOverridesByIdentifier(ctx context.Context, arg FindRatelimitOverridesByIdentifierParams) ([]RatelimitOverride, error) //FindWorkspaceByID // // SELECT id, tenant_id, name, created_at, deleted_at, plan, stripe_customer_id, stripe_subscription_id, trial_ends, beta_features, features, plan_locked_until, plan_downgrade_request, plan_changed, subscriptions, enabled, delete_protection FROM `workspaces` @@ -119,7 +157,96 @@ type Querier interface { // WHERE id = ? 
// AND delete_protection = false HardDeleteWorkspace(ctx context.Context, id string) (sql.Result, error) - //InsertOverride + //InsertKey + // + // INSERT INTO `keys` ( + // id, + // key_auth_id, + // hash, + // start, + // workspace_id, + // for_workspace_id, + // name, + // owner_id, + // identity_id, + // meta, + // created_at, + // expires, + // created_at_m, + // enabled, + // remaining_requests, + // ratelimit_async, + // ratelimit_limit, + // ratelimit_duration, + // environment + // ) VALUES ( + // ?, + // ?, + // ?, + // ?, + // ?, + // ?, + // ?, + // null, + // ?, + // ?, + // ?, + // ?, + // UNIX_TIMESTAMP() * 1000, + // ?, + // ?, + // ?, + // ?, + // ?, + // ? + // ) + InsertKey(ctx context.Context, arg InsertKeyParams) error + //InsertKeyring + // + // INSERT INTO `key_auth` ( + // id, + // workspace_id, + // created_at, + // created_at_m, + // store_encrypted_keys, + // default_prefix, + // default_bytes, + // size_approx, + // size_last_updated_at + // ) VALUES ( + // ?, + // ?, + // ?, + // ?, + // ?, + // ?, + // ?, + // 0, + // 0 + // ) + InsertKeyring(ctx context.Context, arg InsertKeyringParams) error + //InsertRatelimitNamespace + // + // INSERT INTO + // `ratelimit_namespaces` ( + // id, + // workspace_id, + // name, + // created_at, + // updated_at, + // deleted_at + // ) + // VALUES + // ( + // ?, + // ?, + // ?, + // ?, + // NULL, + // NULL + // ) + InsertRatelimitNamespace(ctx context.Context, arg InsertRatelimitNamespaceParams) error + //InsertRatelimitOverride // // INSERT INTO // `ratelimit_overrides` ( @@ -141,9 +268,9 @@ type Querier interface { // ?, // ?, // false, - // now() + // ? 
// ) - InsertOverride(ctx context.Context, arg InsertOverrideParams) error + InsertRatelimitOverride(ctx context.Context, arg InsertRatelimitOverrideParams) error //InsertWorkspace // // INSERT INTO `workspaces` ( @@ -161,7 +288,7 @@ type Querier interface { // ?, // ?, // ?, - // NOW(), + // ?, // 'free', // '{}', // '{}', @@ -172,10 +299,10 @@ type Querier interface { //SoftDeleteWorkspace // // UPDATE `workspaces` - // SET deleted_at = NOW() + // SET deleted_at = ? // WHERE id = ? // AND delete_protection = false - SoftDeleteWorkspace(ctx context.Context, id string) (sql.Result, error) + SoftDeleteWorkspace(ctx context.Context, arg SoftDeleteWorkspaceParams) (sql.Result, error) //UpdateRatelimitOverride // // UPDATE `ratelimit_overrides` @@ -183,7 +310,7 @@ type Querier interface { // `limit` = ?, // duration = ?, // async = ?, - // updated_at = NOW() + // updated_at = ? // WHERE id = ? UpdateRatelimitOverride(ctx context.Context, arg UpdateRatelimitOverrideParams) (sql.Result, error) //UpdateWorkspaceEnabled diff --git a/go/pkg/database/gen/ratelimit_namespace_delete.sql.go b/go/pkg/database/gen/ratelimit_namespace_delete.sql.go index 0da1e61184..aa48d46516 100644 --- a/go/pkg/database/gen/ratelimit_namespace_delete.sql.go +++ b/go/pkg/database/gen/ratelimit_namespace_delete.sql.go @@ -12,15 +12,20 @@ import ( const deleteRatelimitNamespace = `-- name: DeleteRatelimitNamespace :execresult UPDATE ` + "`" + `ratelimit_namespaces` + "`" + ` -SET deleted_at = NOW() +SET deleted_at = ? WHERE id = ? ` +type DeleteRatelimitNamespaceParams struct { + Now sql.NullTime `db:"now"` + ID string `db:"id"` +} + // DeleteRatelimitNamespace // // UPDATE `ratelimit_namespaces` -// SET deleted_at = NOW() +// SET deleted_at = ? // WHERE id = ? 
-func (q *Queries) DeleteRatelimitNamespace(ctx context.Context, id string) (sql.Result, error) { - return q.db.ExecContext(ctx, deleteRatelimitNamespace, id) +func (q *Queries) DeleteRatelimitNamespace(ctx context.Context, arg DeleteRatelimitNamespaceParams) (sql.Result, error) { + return q.db.ExecContext(ctx, deleteRatelimitNamespace, arg.Now, arg.ID) } diff --git a/go/pkg/database/gen/ratelimit_namespace_insert.sql.go b/go/pkg/database/gen/ratelimit_namespace_insert.sql.go index a053a29763..71b1bf19ab 100644 --- a/go/pkg/database/gen/ratelimit_namespace_insert.sql.go +++ b/go/pkg/database/gen/ratelimit_namespace_insert.sql.go @@ -7,44 +7,63 @@ package gen import ( "context" + "time" ) const insertRatelimitNamespace = `-- name: InsertRatelimitNamespace :exec -INSERT INTO ` + "`" + `ratelimit_namespaces` + "`" + ` ( - id, - workspace_id, - name, - created_at -) -VALUES ( - ?, - ?, - ?, - NOW() -) +INSERT INTO + ` + "`" + `ratelimit_namespaces` + "`" + ` ( + id, + workspace_id, + name, + created_at, + updated_at, + deleted_at + ) +VALUES + ( + ?, + ?, + ?, + ?, + NULL, + NULL + ) ` type InsertRatelimitNamespaceParams struct { - ID string `db:"id"` - WorkspaceID string `db:"workspace_id"` - Name string `db:"name"` + ID string `db:"id"` + WorkspaceID string `db:"workspace_id"` + Name string `db:"name"` + CreatedAt time.Time `db:"created_at"` } // InsertRatelimitNamespace // -// INSERT INTO `ratelimit_namespaces` ( -// id, -// workspace_id, -// name, -// created_at -// ) -// VALUES ( -// ?, -// ?, -// ?, -// NOW() -// ) +// INSERT INTO +// `ratelimit_namespaces` ( +// id, +// workspace_id, +// name, +// created_at, +// updated_at, +// deleted_at +// ) +// VALUES +// ( +// ?, +// ?, +// ?, +// ?, +// NULL, +// NULL +// ) func (q *Queries) InsertRatelimitNamespace(ctx context.Context, arg InsertRatelimitNamespaceParams) error { - _, err := q.db.ExecContext(ctx, insertRatelimitNamespace, arg.ID, arg.WorkspaceID, arg.Name) + _, err := q.db.ExecContext(ctx, 
insertRatelimitNamespace, + arg.ID, + arg.WorkspaceID, + arg.Name, + arg.CreatedAt, + ) return err } diff --git a/go/pkg/database/gen/ratelimit_override_delete.sql.go b/go/pkg/database/gen/ratelimit_override_delete.sql.go index d8f002389d..73cc4173cd 100644 --- a/go/pkg/database/gen/ratelimit_override_delete.sql.go +++ b/go/pkg/database/gen/ratelimit_override_delete.sql.go @@ -13,16 +13,21 @@ import ( const deleteRatelimitOverride = `-- name: DeleteRatelimitOverride :execresult UPDATE ` + "`" + `ratelimit_overrides` + "`" + ` SET - deleted_at = NOW() + deleted_at = ? WHERE id = ? ` +type DeleteRatelimitOverrideParams struct { + Now sql.NullTime `db:"now"` + ID string `db:"id"` +} + // DeleteRatelimitOverride // // UPDATE `ratelimit_overrides` // SET -// deleted_at = NOW() +// deleted_at = ? // WHERE id = ? -func (q *Queries) DeleteRatelimitOverride(ctx context.Context, id string) (sql.Result, error) { - return q.db.ExecContext(ctx, deleteRatelimitOverride, id) +func (q *Queries) DeleteRatelimitOverride(ctx context.Context, arg DeleteRatelimitOverrideParams) (sql.Result, error) { + return q.db.ExecContext(ctx, deleteRatelimitOverride, arg.Now, arg.ID) } diff --git a/go/pkg/database/gen/ratelimit_override_find_by_id.sql.go b/go/pkg/database/gen/ratelimit_override_find_by_id.sql.go new file mode 100644 index 0000000000..14bf2b5e3a --- /dev/null +++ b/go/pkg/database/gen/ratelimit_override_find_by_id.sql.go @@ -0,0 +1,47 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: ratelimit_override_find_by_id.sql + +package gen + +import ( + "context" +) + +const findRatelimitOverridesById = `-- name: FindRatelimitOverridesById :one +SELECT id, workspace_id, namespace_id, identifier, ` + "`" + `limit` + "`" + `, duration, async, sharding, created_at, updated_at, deleted_at FROM ratelimit_overrides +WHERE + workspace_id = ? + AND id = ? 
+` + +type FindRatelimitOverridesByIdParams struct { + WorkspaceID string `db:"workspace_id"` + OverrideID string `db:"override_id"` +} + +// FindRatelimitOverridesById +// +// SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at, updated_at, deleted_at FROM ratelimit_overrides +// WHERE +// workspace_id = ? +// AND id = ? +func (q *Queries) FindRatelimitOverridesById(ctx context.Context, arg FindRatelimitOverridesByIdParams) (RatelimitOverride, error) { + row := q.db.QueryRowContext(ctx, findRatelimitOverridesById, arg.WorkspaceID, arg.OverrideID) + var i RatelimitOverride + err := row.Scan( + &i.ID, + &i.WorkspaceID, + &i.NamespaceID, + &i.Identifier, + &i.Limit, + &i.Duration, + &i.Async, + &i.Sharding, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + ) + return i, err +} diff --git a/go/pkg/database/gen/ratelimit_override_find_by_identifier.sql.go b/go/pkg/database/gen/ratelimit_override_find_by_identifier.sql.go index 650bf73382..dda42b86da 100644 --- a/go/pkg/database/gen/ratelimit_override_find_by_identifier.sql.go +++ b/go/pkg/database/gen/ratelimit_override_find_by_identifier.sql.go @@ -9,44 +9,58 @@ import ( "context" ) -const findKeyByHash = `-- name: FindKeyByHash :one -SELECT id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment FROM ` + "`" + `keys` + "`" + ` -WHERE hash = ? +const findRatelimitOverridesByIdentifier = `-- name: FindRatelimitOverridesByIdentifier :many +SELECT id, workspace_id, namespace_id, identifier, ` + "`" + `limit` + "`" + `, duration, async, sharding, created_at, updated_at, deleted_at FROM ratelimit_overrides +WHERE + workspace_id = ? + AND namespace_id = ? + AND identifier LIKE ? 
` -// FindKeyByHash +type FindRatelimitOverridesByIdentifierParams struct { + WorkspaceID string `db:"workspace_id"` + NamespaceID string `db:"namespace_id"` + Identifier string `db:"identifier"` +} + +// FindRatelimitOverridesByIdentifier // -// SELECT id, key_auth_id, hash, start, workspace_id, for_workspace_id, name, owner_id, identity_id, meta, created_at, expires, created_at_m, updated_at_m, deleted_at_m, deleted_at, refill_day, refill_amount, last_refill_at, enabled, remaining_requests, ratelimit_async, ratelimit_limit, ratelimit_duration, environment FROM `keys` -// WHERE hash = ? -func (q *Queries) FindKeyByHash(ctx context.Context, hash string) (Key, error) { - row := q.db.QueryRowContext(ctx, findKeyByHash, hash) - var i Key - err := row.Scan( - &i.ID, - &i.KeyAuthID, - &i.Hash, - &i.Start, - &i.WorkspaceID, - &i.ForWorkspaceID, - &i.Name, - &i.OwnerID, - &i.IdentityID, - &i.Meta, - &i.CreatedAt, - &i.Expires, - &i.CreatedAtM, - &i.UpdatedAtM, - &i.DeletedAtM, - &i.DeletedAt, - &i.RefillDay, - &i.RefillAmount, - &i.LastRefillAt, - &i.Enabled, - &i.RemainingRequests, - &i.RatelimitAsync, - &i.RatelimitLimit, - &i.RatelimitDuration, - &i.Environment, - ) - return i, err +// SELECT id, workspace_id, namespace_id, identifier, `limit`, duration, async, sharding, created_at, updated_at, deleted_at FROM ratelimit_overrides +// WHERE +// workspace_id = ? +// AND namespace_id = ? +// AND identifier LIKE ? 
+func (q *Queries) FindRatelimitOverridesByIdentifier(ctx context.Context, arg FindRatelimitOverridesByIdentifierParams) ([]RatelimitOverride, error) { + rows, err := q.db.QueryContext(ctx, findRatelimitOverridesByIdentifier, arg.WorkspaceID, arg.NamespaceID, arg.Identifier) + if err != nil { + return nil, err + } + defer rows.Close() + var items []RatelimitOverride + for rows.Next() { + var i RatelimitOverride + if err := rows.Scan( + &i.ID, + &i.WorkspaceID, + &i.NamespaceID, + &i.Identifier, + &i.Limit, + &i.Duration, + &i.Async, + &i.Sharding, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } diff --git a/go/pkg/database/gen/ratelimit_override_insert.sql.go b/go/pkg/database/gen/ratelimit_override_insert.sql.go index 49d9277711..9be525d782 100644 --- a/go/pkg/database/gen/ratelimit_override_insert.sql.go +++ b/go/pkg/database/gen/ratelimit_override_insert.sql.go @@ -7,9 +7,10 @@ package gen import ( "context" + "time" ) -const insertOverride = `-- name: InsertOverride :exec +const insertRatelimitOverride = `-- name: InsertRatelimitOverride :exec INSERT INTO ` + "`" + `ratelimit_overrides` + "`" + ` ( id, @@ -30,20 +31,21 @@ VALUES ?, ?, false, - now() + ? 
) ` -type InsertOverrideParams struct { - ID string `db:"id"` - WorkspaceID string `db:"workspace_id"` - NamespaceID string `db:"namespace_id"` - Identifier string `db:"identifier"` - Limit int32 `db:"limit"` - Duration int32 `db:"duration"` +type InsertRatelimitOverrideParams struct { + ID string `db:"id"` + WorkspaceID string `db:"workspace_id"` + NamespaceID string `db:"namespace_id"` + Identifier string `db:"identifier"` + Limit int32 `db:"limit"` + Duration int32 `db:"duration"` + CreatedAt time.Time `db:"created_at"` } -// InsertOverride +// InsertRatelimitOverride // // INSERT INTO // `ratelimit_overrides` ( @@ -65,16 +67,17 @@ type InsertOverrideParams struct { // ?, // ?, // false, -// now() +// ? // ) -func (q *Queries) InsertOverride(ctx context.Context, arg InsertOverrideParams) error { - _, err := q.db.ExecContext(ctx, insertOverride, +func (q *Queries) InsertRatelimitOverride(ctx context.Context, arg InsertRatelimitOverrideParams) error { + _, err := q.db.ExecContext(ctx, insertRatelimitOverride, arg.ID, arg.WorkspaceID, arg.NamespaceID, arg.Identifier, arg.Limit, arg.Duration, + arg.CreatedAt, ) return err } diff --git a/go/pkg/database/gen/ratelimit_override_update.sql.go b/go/pkg/database/gen/ratelimit_override_update.sql.go index 381eff64be..9de8f2cf13 100644 --- a/go/pkg/database/gen/ratelimit_override_update.sql.go +++ b/go/pkg/database/gen/ratelimit_override_update.sql.go @@ -16,7 +16,7 @@ SET ` + "`" + `limit` + "`" + ` = ?, duration = ?, async = ?, - updated_at = NOW() + updated_at = ? WHERE id = ? ` @@ -24,6 +24,7 @@ type UpdateRatelimitOverrideParams struct { Windowlimit int32 `db:"windowlimit"` Duration int32 `db:"duration"` Async sql.NullBool `db:"async"` + Now sql.NullTime `db:"now"` ID string `db:"id"` } @@ -34,13 +35,14 @@ type UpdateRatelimitOverrideParams struct { // `limit` = ?, // duration = ?, // async = ?, -// updated_at = NOW() +// updated_at = ? // WHERE id = ? 
func (q *Queries) UpdateRatelimitOverride(ctx context.Context, arg UpdateRatelimitOverrideParams) (sql.Result, error) { return q.db.ExecContext(ctx, updateRatelimitOverride, arg.Windowlimit, arg.Duration, arg.Async, + arg.Now, arg.ID, ) } diff --git a/go/pkg/database/gen/verify_key.sql.go b/go/pkg/database/gen/verify_key.sql.go deleted file mode 100644 index 62a43a526e..0000000000 --- a/go/pkg/database/gen/verify_key.sql.go +++ /dev/null @@ -1,175 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.27.0 -// source: verify_key.sql - -package gen - -import ( - "context" - "database/sql" -) - -const verifyKey = `-- name: VerifyKey :one -WITH direct_permissions AS ( - SELECT kp.key_id, p.name as permission_name - FROM keys_permissions kp - JOIN permissions p ON kp.permission_id = p.id -), -role_permissions AS ( - SELECT kr.key_id, p.name as permission_name - FROM keys_roles kr - JOIN roles_permissions rp ON kr.role_id = rp.role_id - JOIN permissions p ON rp.permission_id = p.id -), -all_permissions AS ( - SELECT key_id, permission_name FROM direct_permissions - UNION - SELECT key_id, permission_name FROM role_permissions -), -all_ratelimits AS ( - SELECT - key_id as target_id, - 'key' as target_type, - name, - ` + "`" + `limit` + "`" + `, - duration - FROM ratelimits - WHERE key_id IS NOT NULL - UNION - SELECT - identity_id as target_id, - 'identity' as target_type, - name, - ` + "`" + `limit` + "`" + `, - duration - FROM ratelimits - WHERE identity_id IS NOT NULL -) -SELECT - k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, k.meta, k.created_at, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.deleted_at, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, - i.id, i.external_id, i.workspace_id, i.environment, i.created_at, i.updated_at, i.meta, - GROUP_CONCAT(DISTINCT 
rl.target_type) as ratelimit_types, - GROUP_CONCAT(DISTINCT rl.name) as ratelimit_names, - GROUP_CONCAT(DISTINCT rl.limit) as ratelimit_limits, - GROUP_CONCAT(DISTINCT rl.duration) as ratelimit_durations, - GROUP_CONCAT(DISTINCT perms.permission_name) as permissions -FROM ` + "`" + `keys` + "`" + ` k -LEFT JOIN identities i ON k.identity_id = i.id -LEFT JOIN all_permissions perms ON k.id = perms.key_id -LEFT JOIN all_ratelimits rl ON ( - (rl.target_type = 'key' AND rl.target_id = k.id) OR - (rl.target_type = 'identity' AND rl.target_id = k.identity_id) -) -WHERE k.hash = ? -GROUP BY k.id -` - -type VerifyKeyRow struct { - Key Key `db:"key"` - Identity Identity `db:"identity"` - RatelimitTypes sql.NullString `db:"ratelimit_types"` - RatelimitNames sql.NullString `db:"ratelimit_names"` - RatelimitLimits sql.NullString `db:"ratelimit_limits"` - RatelimitDurations sql.NullString `db:"ratelimit_durations"` - Permissions sql.NullString `db:"permissions"` -} - -// VerifyKey -// -// WITH direct_permissions AS ( -// SELECT kp.key_id, p.name as permission_name -// FROM keys_permissions kp -// JOIN permissions p ON kp.permission_id = p.id -// ), -// role_permissions AS ( -// SELECT kr.key_id, p.name as permission_name -// FROM keys_roles kr -// JOIN roles_permissions rp ON kr.role_id = rp.role_id -// JOIN permissions p ON rp.permission_id = p.id -// ), -// all_permissions AS ( -// SELECT key_id, permission_name FROM direct_permissions -// UNION -// SELECT key_id, permission_name FROM role_permissions -// ), -// all_ratelimits AS ( -// SELECT -// key_id as target_id, -// 'key' as target_type, -// name, -// `limit`, -// duration -// FROM ratelimits -// WHERE key_id IS NOT NULL -// UNION -// SELECT -// identity_id as target_id, -// 'identity' as target_type, -// name, -// `limit`, -// duration -// FROM ratelimits -// WHERE identity_id IS NOT NULL -// ) -// SELECT -// k.id, k.key_auth_id, k.hash, k.start, k.workspace_id, k.for_workspace_id, k.name, k.owner_id, k.identity_id, 
k.meta, k.created_at, k.expires, k.created_at_m, k.updated_at_m, k.deleted_at_m, k.deleted_at, k.refill_day, k.refill_amount, k.last_refill_at, k.enabled, k.remaining_requests, k.ratelimit_async, k.ratelimit_limit, k.ratelimit_duration, k.environment, -// i.id, i.external_id, i.workspace_id, i.environment, i.created_at, i.updated_at, i.meta, -// GROUP_CONCAT(DISTINCT rl.target_type) as ratelimit_types, -// GROUP_CONCAT(DISTINCT rl.name) as ratelimit_names, -// GROUP_CONCAT(DISTINCT rl.limit) as ratelimit_limits, -// GROUP_CONCAT(DISTINCT rl.duration) as ratelimit_durations, -// GROUP_CONCAT(DISTINCT perms.permission_name) as permissions -// FROM `keys` k -// LEFT JOIN identities i ON k.identity_id = i.id -// LEFT JOIN all_permissions perms ON k.id = perms.key_id -// LEFT JOIN all_ratelimits rl ON ( -// (rl.target_type = 'key' AND rl.target_id = k.id) OR -// (rl.target_type = 'identity' AND rl.target_id = k.identity_id) -// ) -// WHERE k.hash = ? -// GROUP BY k.id -func (q *Queries) VerifyKey(ctx context.Context, hash string) (VerifyKeyRow, error) { - row := q.db.QueryRowContext(ctx, verifyKey, hash) - var i VerifyKeyRow - err := row.Scan( - &i.Key.ID, - &i.Key.KeyAuthID, - &i.Key.Hash, - &i.Key.Start, - &i.Key.WorkspaceID, - &i.Key.ForWorkspaceID, - &i.Key.Name, - &i.Key.OwnerID, - &i.Key.IdentityID, - &i.Key.Meta, - &i.Key.CreatedAt, - &i.Key.Expires, - &i.Key.CreatedAtM, - &i.Key.UpdatedAtM, - &i.Key.DeletedAtM, - &i.Key.DeletedAt, - &i.Key.RefillDay, - &i.Key.RefillAmount, - &i.Key.LastRefillAt, - &i.Key.Enabled, - &i.Key.RemainingRequests, - &i.Key.RatelimitAsync, - &i.Key.RatelimitLimit, - &i.Key.RatelimitDuration, - &i.Key.Environment, - &i.Identity.ID, - &i.Identity.ExternalID, - &i.Identity.WorkspaceID, - &i.Identity.Environment, - &i.Identity.CreatedAt, - &i.Identity.UpdatedAt, - &i.Identity.Meta, - &i.RatelimitTypes, - &i.RatelimitNames, - &i.RatelimitLimits, - &i.RatelimitDurations, - &i.Permissions, - ) - return i, err -} diff --git 
a/go/pkg/database/gen/workspace_insert.sql.go b/go/pkg/database/gen/workspace_insert.sql.go index 35e70c8422..877facc649 100644 --- a/go/pkg/database/gen/workspace_insert.sql.go +++ b/go/pkg/database/gen/workspace_insert.sql.go @@ -7,6 +7,7 @@ package gen import ( "context" + "database/sql" ) const insertWorkspace = `-- name: InsertWorkspace :exec @@ -25,7 +26,7 @@ VALUES ( ?, ?, ?, - NOW(), + ?, 'free', '{}', '{}', @@ -35,9 +36,10 @@ VALUES ( ` type InsertWorkspaceParams struct { - ID string `db:"id"` - TenantID string `db:"tenant_id"` - Name string `db:"name"` + ID string `db:"id"` + TenantID string `db:"tenant_id"` + Name string `db:"name"` + CreatedAt sql.NullTime `db:"created_at"` } // InsertWorkspace @@ -57,7 +59,7 @@ type InsertWorkspaceParams struct { // ?, // ?, // ?, -// NOW(), +// ?, // 'free', // '{}', // '{}', @@ -65,6 +67,11 @@ type InsertWorkspaceParams struct { // true // ) func (q *Queries) InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) error { - _, err := q.db.ExecContext(ctx, insertWorkspace, arg.ID, arg.TenantID, arg.Name) + _, err := q.db.ExecContext(ctx, insertWorkspace, + arg.ID, + arg.TenantID, + arg.Name, + arg.CreatedAt, + ) return err } diff --git a/go/pkg/database/gen/workspace_soft_delete.sql.go b/go/pkg/database/gen/workspace_soft_delete.sql.go index 3ad79f6235..6dfb8da647 100644 --- a/go/pkg/database/gen/workspace_soft_delete.sql.go +++ b/go/pkg/database/gen/workspace_soft_delete.sql.go @@ -12,17 +12,22 @@ import ( const softDeleteWorkspace = `-- name: SoftDeleteWorkspace :execresult UPDATE ` + "`" + `workspaces` + "`" + ` -SET deleted_at = NOW() +SET deleted_at = ? WHERE id = ? AND delete_protection = false ` +type SoftDeleteWorkspaceParams struct { + Now sql.NullTime `db:"now"` + ID string `db:"id"` +} + // SoftDeleteWorkspace // // UPDATE `workspaces` -// SET deleted_at = NOW() +// SET deleted_at = ? // WHERE id = ? 
// AND delete_protection = false -func (q *Queries) SoftDeleteWorkspace(ctx context.Context, id string) (sql.Result, error) { - return q.db.ExecContext(ctx, softDeleteWorkspace, id) +func (q *Queries) SoftDeleteWorkspace(ctx context.Context, arg SoftDeleteWorkspaceParams) (sql.Result, error) { + return q.db.ExecContext(ctx, softDeleteWorkspace, arg.Now, arg.ID) } diff --git a/go/pkg/database/interface.go b/go/pkg/database/interface.go index 264c96d293..76f2ea5598 100644 --- a/go/pkg/database/interface.go +++ b/go/pkg/database/interface.go @@ -2,10 +2,15 @@ package database import ( "context" + "errors" "github.com/unkeyed/unkey/go/pkg/entities" ) +var ( + ErrNotFound = errors.New("not found") +) + type Database interface { // Workspace @@ -17,7 +22,9 @@ type Database interface { // FindWorkspace(ctx context.Context, workspaceId string) (entities.Workspace, bool, error) DeleteWorkspace(ctx context.Context, id string, hardDelete bool) error - // KeyAuth + // KeyRing + InsertKeyring(ctx context.Context, keyring entities.Keyring) error + // InsertKeyAuth(ctx context.Context, newKeyAuth entities.KeyAuth) error // DeleteKeyAuth(ctx context.Context, keyAuthId string) error // FindKeyAuth(ctx context.Context, keyAuthId string) (keyauth entities.KeyAuth, found bool, err error) @@ -30,7 +37,7 @@ type Database interface { // ListAllApis(ctx context.Context, limit int, offset int) ([]entities.Api, error) // Key - // InsertKey(ctx context.Context, newKey entities.Key) error + InsertKey(ctx context.Context, newKey entities.Key) error FindKeyByID(ctx context.Context, keyId string) (key entities.Key, err error) FindKeyByHash(ctx context.Context, hash string) (key entities.Key, err error) FindKeyForVerification(ctx context.Context, hash string) (key entities.Key, err error) @@ -40,6 +47,9 @@ type Database interface { // CountKeys(ctx context.Context, keyAuthId string) (int64, error) // ListKeys(ctx context.Context, keyAuthId string, ownerId string, limit int, offset int) 
([]entities.Key, error) + // Permissions + FindPermissionsByKeyID(ctx context.Context, keyID string) ([]string, error) + // Ratelimit Namespace InsertRatelimitNamespace(ctx context.Context, namespace entities.RatelimitNamespace) error FindRatelimitNamespaceByID(ctx context.Context, id string) (entities.RatelimitNamespace, error) @@ -48,7 +58,8 @@ type Database interface { // Ratelimit Override InsertRatelimitOverride(ctx context.Context, ratelimitOverride entities.RatelimitOverride) error - FindRatelimitOverrideByIdentifier(ctx context.Context, identifier string) (ratelimitOverride entities.RatelimitOverride, err error) + FindRatelimitOverridesByIdentifier(ctx context.Context, workspaceId, namespaceId, identifier string) (ratelimitOverrides []entities.RatelimitOverride, err error) + FindRatelimitOverrideByID(ctx context.Context, workspaceID, identifier string) (ratelimitOverride entities.RatelimitOverride, err error) UpdateRatelimitOverride(ctx context.Context, override entities.RatelimitOverride) error DeleteRatelimitOverride(ctx context.Context, id string) error diff --git a/go/pkg/database/key_find_by_hash.go b/go/pkg/database/key_find_by_hash.go index 5dd146af8e..4d8edb617e 100644 --- a/go/pkg/database/key_find_by_hash.go +++ b/go/pkg/database/key_find_by_hash.go @@ -3,6 +3,8 @@ package database import ( "context" "database/sql" + "fmt" + "log/slog" "errors" @@ -15,10 +17,12 @@ func (db *database) FindKeyByHash(ctx context.Context, hash string) (entities.Ke model, err := db.read().FindKeyByHash(ctx, hash) if err != nil { + db.logger.Error(ctx, "found key by hash", slog.Any("model", model), slog.Any("error", err)) + if errors.Is(err, sql.ErrNoRows) { return entities.Key{}, fault.Wrap(err, fault.WithTag(fault.NOT_FOUND), - fault.WithDesc("not found", "The key does not exist."), + fault.WithDesc("not found", fmt.Sprintf("The key %s does not exist.", hash)), ) } return entities.Key{}, fault.Wrap(err, fault.WithTag(fault.DATABASE_ERROR)) @@ -28,5 +32,6 @@ func (db 
*database) FindKeyByHash(ctx context.Context, hash string) (entities.Ke if err != nil { return entities.Key{}, fault.Wrap(err, fault.WithDesc("cannot transform key model to entity", "")) } + return key, nil } diff --git a/go/pkg/database/key_find_by_id.go b/go/pkg/database/key_find_by_id.go index deaa8ce429..f6e9a7d940 100644 --- a/go/pkg/database/key_find_by_id.go +++ b/go/pkg/database/key_find_by_id.go @@ -25,9 +25,15 @@ func (db *database) FindKeyByID(ctx context.Context, keyID string) (entities.Key return entities.Key{}, fault.Wrap(err, fault.WithTag(fault.DATABASE_ERROR)) } - key, err := transform.KeyModelToEntity(model) + key, err := transform.KeyModelToEntity(model.Key) if err != nil { return entities.Key{}, fault.Wrap(err, fault.WithDesc("cannot transform key model to entity", "")) } + + identiy, err := transform.IdentityModelToEntity(model.Identity) + if err != nil { + return entities.Key{}, fault.Wrap(err, fault.WithDesc("cannot transform identity model to entity", "")) + } + key.Identity = &identiy return key, nil } diff --git a/go/pkg/database/key_insert.go b/go/pkg/database/key_insert.go new file mode 100644 index 0000000000..60127b87e3 --- /dev/null +++ b/go/pkg/database/key_insert.go @@ -0,0 +1,87 @@ +package database + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/unkeyed/unkey/go/pkg/database/gen" + "github.com/unkeyed/unkey/go/pkg/entities" + "github.com/unkeyed/unkey/go/pkg/fault" +) + +func (db *database) InsertKey(ctx context.Context, key entities.Key) error { + meta, err := json.Marshal(key.Meta) + if err != nil { + return fault.Wrap(err, fault.WithDesc("failed to marshal key meta", "")) + } + + remaining := int32(0) + if key.RemainingRequests != nil { + // nolint:gosec + remaining = int32(*key.RemainingRequests) + } + + identityID := "" + if key.Identity != nil { + // nolint:gosec + identityID = key.Identity.ID + } + + params := gen.InsertKeyParams{ + ID: key.ID, + KeyringID: key.KeyringID, + Hash: key.Hash, + 
Start: key.Start, + WorkspaceID: key.WorkspaceID, + ForWorkspaceID: sql.NullString{ + String: key.ForWorkspaceID, + Valid: key.ForWorkspaceID != "", + }, + Name: sql.NullString{ + String: key.Name, + Valid: key.Name != "", + }, + IdentityID: sql.NullString{ + String: identityID, + Valid: identityID != "", + }, + Meta: sql.NullString{ + String: string(meta), + Valid: true, + }, + Expires: sql.NullTime{ + Time: key.Expires, + Valid: !key.Expires.IsZero(), + }, + Enabled: key.Enabled, + RemainingRequests: sql.NullInt32{ + Int32: remaining, + Valid: key.RemainingRequests != nil, + }, + RatelimitAsync: sql.NullBool{ + Bool: false, + Valid: false, + }, + RatelimitLimit: sql.NullInt32{ + Int32: 0, + Valid: false, + }, + RatelimitDuration: sql.NullInt64{ + Int64: 0, + Valid: false, + }, + Environment: sql.NullString{ + String: key.Environment, + Valid: key.Environment != "", + }, + CreatedAt: db.clock.Now(), + } + + err = db.write().InsertKey(ctx, params) + if err != nil { + return fault.Wrap(err, fault.WithDesc("failed to insert key", "")) + } + + return nil +} diff --git a/go/pkg/database/keyring_insert.go b/go/pkg/database/keyring_insert.go new file mode 100644 index 0000000000..542ce0e42c --- /dev/null +++ b/go/pkg/database/keyring_insert.go @@ -0,0 +1,39 @@ +package database + +import ( + "context" + "database/sql" + + "github.com/unkeyed/unkey/go/pkg/database/gen" + "github.com/unkeyed/unkey/go/pkg/entities" + "github.com/unkeyed/unkey/go/pkg/fault" +) + +func (db *database) InsertKeyring(ctx context.Context, keyring entities.Keyring) error { + params := gen.InsertKeyringParams{ + ID: keyring.ID, + WorkspaceID: keyring.WorkspaceID, + StoreEncryptedKeys: keyring.StoreEncryptedKeys, + DefaultPrefix: sql.NullString{ + String: keyring.DefaultPrefix, + Valid: keyring.DefaultPrefix != "", + }, + DefaultBytes: sql.NullInt32{ + // nolint:gosec + Int32: int32(keyring.DefaultBytes), + Valid: keyring.DefaultBytes != 0, + }, + CreatedAt: sql.NullTime{ + Time: db.clock.Now(), + 
Valid: true, + }, + CreatedAtM: db.clock.Now().UnixMilli(), + } + + err := db.write().InsertKeyring(ctx, params) + if err != nil { + return fault.Wrap(err, fault.WithDesc("failed to insert key ring", "")) + } + + return nil +} diff --git a/go/pkg/database/middleware/cache/cache.go b/go/pkg/database/middleware/cache/cache.go new file mode 100644 index 0000000000..559f4247a3 --- /dev/null +++ b/go/pkg/database/middleware/cache/cache.go @@ -0,0 +1,200 @@ +package cache + +import ( + "context" + "errors" + "time" + + "github.com/unkeyed/unkey/go/pkg/cache" + "github.com/unkeyed/unkey/go/pkg/clock" + "github.com/unkeyed/unkey/go/pkg/database" + "github.com/unkeyed/unkey/go/pkg/entities" + "github.com/unkeyed/unkey/go/pkg/logging" +) + +type cacheMiddleware struct { + db database.Database + + translateError func(error) cache.CacheHit + + keyByHash cache.Cache[string, entities.Key] + workspaceByID cache.Cache[string, entities.Workspace] + keyByID cache.Cache[string, entities.Key] + ratelimitNamespaceByID cache.Cache[string, entities.RatelimitNamespace] + ratelimitNamespaceByName cache.Cache[KeyRatelimitNamespaceByName, entities.RatelimitNamespace] + ratelimitOverridesByIdentifier cache.Cache[KeyRatelimitOverridesByIdentifier, []entities.RatelimitOverride] + ratelimitOverrideByID cache.Cache[KeyRatelimitOverrideByID, entities.RatelimitOverride] + permissionsByKeyID cache.Cache[string, []string] +} + +var _ database.Database = (*cacheMiddleware)(nil) + +func WithCaching(logger logging.Logger) database.Middleware { + + clk := clock.New() + return func(db database.Database) database.Database { + + return &cacheMiddleware{ + db: db, + translateError: func(err error) cache.CacheHit { + if err == nil { + return cache.Hit + } + if errors.Is(err, database.ErrNotFound) { + // if no data was found, we store a special NullEntry + return cache.Null + } + // some other error, which we don't want to cache + return cache.Miss + + }, + keyByHash: cache.New[string, 
entities.Key](cache.Config[string, entities.Key]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "key_by_hash", + Clock: clk, + }), + keyByID: cache.New[string, entities.Key](cache.Config[string, entities.Key]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "key_by_id", + Clock: clk, + }), + workspaceByID: cache.New[string, entities.Workspace](cache.Config[string, entities.Workspace]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "workspace_by_id", + Clock: clk, + }), + ratelimitNamespaceByID: cache.New[string, entities.RatelimitNamespace](cache.Config[string, entities.RatelimitNamespace]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "ratelimit_namespace_by_id", + Clock: clk, + }), + ratelimitNamespaceByName: cache.New[KeyRatelimitNamespaceByName, entities.RatelimitNamespace](cache.Config[KeyRatelimitNamespaceByName, entities.RatelimitNamespace]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "ratelimit_namespace_by_name", + Clock: clk, + }), + ratelimitOverridesByIdentifier: cache.New[KeyRatelimitOverridesByIdentifier, []entities.RatelimitOverride](cache.Config[KeyRatelimitOverridesByIdentifier, []entities.RatelimitOverride]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "ratelimit_overrides_by_identifier", + Clock: clk, + }), + ratelimitOverrideByID: cache.New[KeyRatelimitOverrideByID, entities.RatelimitOverride](cache.Config[KeyRatelimitOverrideByID, entities.RatelimitOverride]{ + Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "ratelimit_override_by_id", + Clock: clk, + }), + permissionsByKeyID: cache.New[string, []string](cache.Config[string, []string]{ + 
Fresh: 10 * time.Second, + Stale: 1 * time.Minute, + Logger: logger, + MaxSize: 1_000_000, + Resource: "permissions_by_key_id", + Clock: clk, + }), + } + } +} + +func (c *cacheMiddleware) InsertWorkspace(ctx context.Context, workspace entities.Workspace) error { + return c.db.InsertWorkspace(ctx, workspace) +} +func (c *cacheMiddleware) FindWorkspaceByID(ctx context.Context, id string) (entities.Workspace, error) { + return c.workspaceByID.SWR(ctx, id, func(refreshCtx context.Context) (entities.Workspace, error) { + return c.db.FindWorkspaceByID(refreshCtx, id) + }, c.translateError) + +} +func (c *cacheMiddleware) UpdateWorkspacePlan(ctx context.Context, workspaceID string, plan entities.WorkspacePlan) error { + return c.db.UpdateWorkspacePlan(ctx, workspaceID, plan) +} +func (c *cacheMiddleware) UpdateWorkspaceEnabled(ctx context.Context, id string, enabled bool) error { + return c.db.UpdateWorkspaceEnabled(ctx, id, enabled) +} +func (c *cacheMiddleware) DeleteWorkspace(ctx context.Context, id string, hardDelete bool) error { + return c.db.DeleteWorkspace(ctx, id, hardDelete) +} + +func (c *cacheMiddleware) InsertKeyring(ctx context.Context, keyring entities.Keyring) error { + return c.db.InsertKeyring(ctx, keyring) +} + +func (c *cacheMiddleware) InsertKey(ctx context.Context, key entities.Key) error { + return c.db.InsertKey(ctx, key) +} +func (c *cacheMiddleware) FindKeyByID(ctx context.Context, keyID string) (key entities.Key, err error) { + return c.keyByID.SWR(ctx, keyID, func(refreshCtx context.Context) (entities.Key, error) { + return c.db.FindKeyByID(refreshCtx, keyID) + }, c.translateError) +} +func (c *cacheMiddleware) FindKeyByHash(ctx context.Context, hash string) (key entities.Key, err error) { + return c.keyByHash.SWR(ctx, hash, func(refreshCtx context.Context) (entities.Key, error) { + return c.db.FindKeyByHash(refreshCtx, hash) + }, c.translateError) +} +func (c *cacheMiddleware) FindPermissionsByKeyID(ctx context.Context, keyID string) 
(permissions []string, err error) { + return c.permissionsByKeyID.SWR(ctx, keyID, func(refreshCtx context.Context) ([]string, error) { + return c.db.FindPermissionsByKeyID(refreshCtx, keyID) + }, c.translateError) +} +func (c *cacheMiddleware) FindKeyForVerification(ctx context.Context, hash string) (key entities.Key, err error) { + panic("IMPLEMENT ME") +} +func (c *cacheMiddleware) InsertRatelimitNamespace(ctx context.Context, namespace entities.RatelimitNamespace) error { + return c.db.InsertRatelimitNamespace(ctx, namespace) +} +func (c *cacheMiddleware) FindRatelimitNamespaceByID(ctx context.Context, namespaceID string) (entities.RatelimitNamespace, error) { + return c.ratelimitNamespaceByID.SWR(ctx, namespaceID, func(refreshCtx context.Context) (entities.RatelimitNamespace, error) { + return c.db.FindRatelimitNamespaceByID(refreshCtx, namespaceID) + }, c.translateError) +} +func (c *cacheMiddleware) FindRatelimitNamespaceByName(ctx context.Context, workspaceID string, name string) (entities.RatelimitNamespace, error) { + return c.ratelimitNamespaceByName.SWR(ctx, KeyRatelimitNamespaceByName{WorkspaceID: workspaceID, NamespaceName: name}, func(refreshCtx context.Context) (entities.RatelimitNamespace, error) { + return c.db.FindRatelimitNamespaceByName(refreshCtx, workspaceID, name) + }, c.translateError) +} +func (c *cacheMiddleware) DeleteRatelimitNamespace(ctx context.Context, id string) error { + return c.db.DeleteRatelimitNamespace(ctx, id) +} +func (c *cacheMiddleware) InsertRatelimitOverride(ctx context.Context, ratelimitOverride entities.RatelimitOverride) error { + return c.db.InsertRatelimitOverride(ctx, ratelimitOverride) +} +func (c *cacheMiddleware) FindRatelimitOverridesByIdentifier(ctx context.Context, workspaceID, namespaceID, identifier string) (ratelimitOverrides []entities.RatelimitOverride, err error) { + return c.ratelimitOverridesByIdentifier.SWR(ctx, KeyRatelimitOverridesByIdentifier{WorkspaceID: workspaceID, NamespaceID: namespaceID, 
Identifier: identifier}, func(refreshCtx context.Context) ([]entities.RatelimitOverride, error) { + return c.db.FindRatelimitOverridesByIdentifier(refreshCtx, workspaceID, namespaceID, identifier) + }, c.translateError) +} +func (c *cacheMiddleware) FindRatelimitOverrideByID(ctx context.Context, workspaceID, overrideID string) (ratelimitOverrides entities.RatelimitOverride, err error) { + return c.ratelimitOverrideByID.SWR(ctx, KeyRatelimitOverrideByID{WorkspaceID: workspaceID, OverrideID: overrideID}, func(refreshCtx context.Context) (entities.RatelimitOverride, error) { + return c.db.FindRatelimitOverrideByID(refreshCtx, workspaceID, overrideID) + }, c.translateError) +} +func (c *cacheMiddleware) UpdateRatelimitOverride(ctx context.Context, override entities.RatelimitOverride) error { + return c.db.UpdateRatelimitOverride(ctx, override) +} +func (c *cacheMiddleware) DeleteRatelimitOverride(ctx context.Context, id string) error { + return c.db.DeleteRatelimitOverride(ctx, id) +} +func (c *cacheMiddleware) Close() error { + return c.db.Close() +} diff --git a/go/pkg/database/middleware/cache/keys.go b/go/pkg/database/middleware/cache/keys.go new file mode 100644 index 0000000000..53226b7410 --- /dev/null +++ b/go/pkg/database/middleware/cache/keys.go @@ -0,0 +1,17 @@ +package cache + +type KeyRatelimitNamespaceByName struct { + WorkspaceID string + NamespaceName string +} + +type KeyRatelimitOverridesByIdentifier struct { + WorkspaceID string + Identifier string + NamespaceID string +} + +type KeyRatelimitOverrideByID struct { + WorkspaceID string + OverrideID string +} diff --git a/go/pkg/database/permissions_find_by_key_id.go b/go/pkg/database/permissions_find_by_key_id.go new file mode 100644 index 0000000000..e6fe92d150 --- /dev/null +++ b/go/pkg/database/permissions_find_by_key_id.go @@ -0,0 +1,22 @@ +package database + +import ( + "context" + "database/sql" + "errors" + + "github.com/unkeyed/unkey/go/pkg/database/gen" + 
"github.com/unkeyed/unkey/go/pkg/fault" +) + +func (db *database) FindPermissionsByKeyID(ctx context.Context, keyID string) ([]string, error) { + permissions, err := db.read().FindPermissionsForKey(ctx, gen.FindPermissionsForKeyParams{KeyID: keyID}) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return []string{}, nil + } + return nil, fault.Wrap(err, fault.WithTag(fault.DATABASE_ERROR)) + } + + return permissions, nil +} diff --git a/go/pkg/database/queries/key_find_by_hash.sql b/go/pkg/database/queries/key_find_by_hash.sql index d97713f6da..bb03710055 100644 --- a/go/pkg/database/queries/key_find_by_hash.sql +++ b/go/pkg/database/queries/key_find_by_hash.sql @@ -1,3 +1,6 @@ --- name: FindKeyByID :one -SELECT * FROM `keys` -WHERE id = sqlc.arg(id); + +-- name: FindKeyByHash :one +SELECT + * +FROM `keys` +WHERE hash = sqlc.arg(hash); diff --git a/go/pkg/database/queries/key_find_by_id.sql b/go/pkg/database/queries/key_find_by_id.sql index f4b3c1cbc0..884bc632db 100644 --- a/go/pkg/database/queries/key_find_by_id.sql +++ b/go/pkg/database/queries/key_find_by_id.sql @@ -1,3 +1,7 @@ --- name: FindRatelimitOverrideByIdentifier :one -SELECT * FROM `ratelimit_overrides` -WHERE identifier = sqlc.arg(identifier); +-- name: FindKeyByID :one +SELECT + sqlc.embed(k), + sqlc.embed(i) +FROM `keys` k +LEFT JOIN identities i ON k.identity_id = i.id +WHERE k.id = sqlc.arg(id); diff --git a/go/pkg/database/queries/key_insert.sql b/go/pkg/database/queries/key_insert.sql new file mode 100644 index 0000000000..59625128eb --- /dev/null +++ b/go/pkg/database/queries/key_insert.sql @@ -0,0 +1,42 @@ +-- name: InsertKey :exec +INSERT INTO `keys` ( + id, + key_auth_id, + hash, + start, + workspace_id, + for_workspace_id, + name, + owner_id, + identity_id, + meta, + created_at, + expires, + created_at_m, + enabled, + remaining_requests, + ratelimit_async, + ratelimit_limit, + ratelimit_duration, + environment +) VALUES ( + sqlc.arg(id), + sqlc.arg(keyring_id), + sqlc.arg(hash), + 
sqlc.arg(start), + sqlc.arg(workspace_id), + sqlc.arg(for_workspace_id), + sqlc.arg(name), + null, + sqlc.arg(identity_id), + sqlc.arg(meta), + sqlc.arg(created_at), + sqlc.arg(expires), + UNIX_TIMESTAMP() * 1000, + sqlc.arg(enabled), + sqlc.arg(remaining_requests), + sqlc.arg(ratelimit_async), + sqlc.arg(ratelimit_limit), + sqlc.arg(ratelimit_duration), + sqlc.arg(environment) +); diff --git a/go/pkg/database/queries/keyring_insert.sql b/go/pkg/database/queries/keyring_insert.sql new file mode 100644 index 0000000000..478362a2cb --- /dev/null +++ b/go/pkg/database/queries/keyring_insert.sql @@ -0,0 +1,22 @@ +-- name: InsertKeyring :exec +INSERT INTO `key_auth` ( + id, + workspace_id, + created_at, + created_at_m, + store_encrypted_keys, + default_prefix, + default_bytes, + size_approx, + size_last_updated_at +) VALUES ( + sqlc.arg(id), + sqlc.arg(workspace_id), + sqlc.arg(created_at), + sqlc.arg(created_at_m), + sqlc.arg(store_encrypted_keys), + sqlc.arg(default_prefix), + sqlc.arg(default_bytes), + 0, + 0 +); diff --git a/go/pkg/database/queries/permissions_by_key_id.sql b/go/pkg/database/queries/permissions_by_key_id.sql new file mode 100644 index 0000000000..9b370e3c49 --- /dev/null +++ b/go/pkg/database/queries/permissions_by_key_id.sql @@ -0,0 +1,20 @@ +-- name: FindPermissionsForKey :many +WITH direct_permissions AS ( + SELECT p.name as permission_name + FROM keys_permissions kp + JOIN permissions p ON kp.permission_id = p.id + WHERE kp.key_id = sqlc.arg(key_id) +), +role_permissions AS ( + SELECT p.name as permission_name + FROM keys_roles kr + JOIN roles_permissions rp ON kr.role_id = rp.role_id + JOIN permissions p ON rp.permission_id = p.id + WHERE kr.key_id = sqlc.arg(key_id) +) +SELECT DISTINCT permission_name +FROM ( + SELECT permission_name FROM direct_permissions + UNION ALL + SELECT permission_name FROM role_permissions +) all_permissions; diff --git a/go/pkg/database/queries/ratelimit_namespace_delete.sql 
b/go/pkg/database/queries/ratelimit_namespace_delete.sql index 405058d5d6..41b07d7e06 100644 --- a/go/pkg/database/queries/ratelimit_namespace_delete.sql +++ b/go/pkg/database/queries/ratelimit_namespace_delete.sql @@ -1,4 +1,4 @@ -- name: DeleteRatelimitNamespace :execresult UPDATE `ratelimit_namespaces` -SET deleted_at = NOW() +SET deleted_at = sqlc.arg(now) WHERE id = sqlc.arg(id); diff --git a/go/pkg/database/queries/ratelimit_namespace_insert.sql b/go/pkg/database/queries/ratelimit_namespace_insert.sql new file mode 100644 index 0000000000..a6033363c3 --- /dev/null +++ b/go/pkg/database/queries/ratelimit_namespace_insert.sql @@ -0,0 +1,20 @@ +-- name: InsertRatelimitNamespace :exec +INSERT INTO + `ratelimit_namespaces` ( + id, + workspace_id, + name, + created_at, + updated_at, + deleted_at + ) +VALUES + ( + sqlc.arg("id"), + sqlc.arg("workspace_id"), + sqlc.arg("name"), + sqlc.arg(created_at), + NULL, + NULL + ) +; diff --git a/go/pkg/database/queries/ratelimit_override_delete.sql b/go/pkg/database/queries/ratelimit_override_delete.sql index 58ef9dec7d..b10e171e5c 100644 --- a/go/pkg/database/queries/ratelimit_override_delete.sql +++ b/go/pkg/database/queries/ratelimit_override_delete.sql @@ -1,5 +1,5 @@ -- name: DeleteRatelimitOverride :execresult UPDATE `ratelimit_overrides` SET - deleted_at = NOW() + deleted_at = sqlc.arg(now) WHERE id = sqlc.arg(id); diff --git a/go/pkg/database/queries/ratelimit_override_find_by_id.sql b/go/pkg/database/queries/ratelimit_override_find_by_id.sql new file mode 100644 index 0000000000..d474b88d67 --- /dev/null +++ b/go/pkg/database/queries/ratelimit_override_find_by_id.sql @@ -0,0 +1,5 @@ +-- name: FindRatelimitOverridesById :one +SELECT * FROM ratelimit_overrides +WHERE + workspace_id = sqlc.arg(workspace_id) + AND id = sqlc.arg(override_id); diff --git a/go/pkg/database/queries/ratelimit_override_find_by_identifier.sql b/go/pkg/database/queries/ratelimit_override_find_by_identifier.sql index 8b395d5955..1d494f1401 100644 
--- a/go/pkg/database/queries/ratelimit_override_find_by_identifier.sql +++ b/go/pkg/database/queries/ratelimit_override_find_by_identifier.sql @@ -1,3 +1,6 @@ --- name: FindKeyByHash :one -SELECT * FROM `keys` -WHERE hash = sqlc.arg(hash); +-- name: FindRatelimitOverridesByIdentifier :many +SELECT * FROM ratelimit_overrides +WHERE + workspace_id = sqlc.arg(workspace_id) + AND namespace_id = sqlc.arg(namespace_id) + AND identifier LIKE sqlc.arg(identifier); diff --git a/go/pkg/database/queries/ratelimit_override_insert.sql b/go/pkg/database/queries/ratelimit_override_insert.sql index b19d86b5cc..7395811c38 100644 --- a/go/pkg/database/queries/ratelimit_override_insert.sql +++ b/go/pkg/database/queries/ratelimit_override_insert.sql @@ -1,4 +1,4 @@ --- name: InsertOverride :exec +-- name: InsertRatelimitOverride :exec INSERT INTO `ratelimit_overrides` ( id, @@ -19,5 +19,5 @@ VALUES sqlc.arg("limit"), sqlc.arg("duration"), false, - now() + sqlc.arg("created_at") ) diff --git a/go/pkg/database/queries/ratelimit_override_update.sql b/go/pkg/database/queries/ratelimit_override_update.sql index ab27b131cc..a9d38e072e 100644 --- a/go/pkg/database/queries/ratelimit_override_update.sql +++ b/go/pkg/database/queries/ratelimit_override_update.sql @@ -4,5 +4,5 @@ SET `limit` = sqlc.arg(windowLimit), duration = sqlc.arg(duration), async = sqlc.arg(async), - updated_at = NOW() + updated_at = sqlc.arg(now) WHERE id = sqlc.arg(id); diff --git a/go/pkg/database/queries/workspace_insert.sql b/go/pkg/database/queries/workspace_insert.sql index 6a4c380f42..a6d41e3f62 100644 --- a/go/pkg/database/queries/workspace_insert.sql +++ b/go/pkg/database/queries/workspace_insert.sql @@ -14,7 +14,7 @@ VALUES ( sqlc.arg(id), sqlc.arg(tenant_id), sqlc.arg(name), - NOW(), + sqlc.arg(created_at), 'free', '{}', '{}', diff --git a/go/pkg/database/queries/workspace_soft_delete.sql b/go/pkg/database/queries/workspace_soft_delete.sql index c08510a21e..d7cc208813 100644 --- 
a/go/pkg/database/queries/workspace_soft_delete.sql +++ b/go/pkg/database/queries/workspace_soft_delete.sql @@ -1,5 +1,5 @@ -- name: SoftDeleteWorkspace :execresult UPDATE `workspaces` -SET deleted_at = NOW() +SET deleted_at = sqlc.arg(now) WHERE id = sqlc.arg(id) AND delete_protection = false; diff --git a/go/pkg/database/ratelimit_namespace_delete.go b/go/pkg/database/ratelimit_namespace_delete.go index 07e488cf80..272d12d314 100644 --- a/go/pkg/database/ratelimit_namespace_delete.go +++ b/go/pkg/database/ratelimit_namespace_delete.go @@ -4,11 +4,18 @@ import ( "context" "database/sql" + "github.com/unkeyed/unkey/go/pkg/database/gen" "github.com/unkeyed/unkey/go/pkg/fault" ) func (db *database) DeleteRatelimitNamespace(ctx context.Context, id string) error { - result, err := db.write().DeleteRatelimitNamespace(ctx, id) + result, err := db.write().DeleteRatelimitNamespace(ctx, gen.DeleteRatelimitNamespaceParams{ + ID: id, + Now: sql.NullTime{ + Time: db.clock.Now(), + Valid: true, + }, + }) if err != nil { return fault.Wrap(err, fault.WithDesc("failed to delete ratelimit namespace", "")) } diff --git a/go/pkg/database/ratelimit_override_delete.go b/go/pkg/database/ratelimit_override_delete.go index 351d537f49..99043b1356 100644 --- a/go/pkg/database/ratelimit_override_delete.go +++ b/go/pkg/database/ratelimit_override_delete.go @@ -4,11 +4,18 @@ import ( "context" "database/sql" + "github.com/unkeyed/unkey/go/pkg/database/gen" "github.com/unkeyed/unkey/go/pkg/fault" ) func (db *database) DeleteRatelimitOverride(ctx context.Context, id string) error { - result, err := db.write().DeleteRatelimitOverride(ctx, id) + result, err := db.write().DeleteRatelimitOverride(ctx, gen.DeleteRatelimitOverrideParams{ + ID: id, + Now: sql.NullTime{ + Time: db.clock.Now(), + Valid: true, + }, + }) if err != nil { return fault.Wrap(err, fault.WithDesc("failed to delete ratelimit override", "")) } diff --git a/go/pkg/database/ratelimit_override_find_by_id.go 
b/go/pkg/database/ratelimit_override_find_by_id.go new file mode 100644 index 0000000000..d437292c8d --- /dev/null +++ b/go/pkg/database/ratelimit_override_find_by_id.go @@ -0,0 +1,38 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + + "errors" + + "github.com/unkeyed/unkey/go/pkg/database/gen" + "github.com/unkeyed/unkey/go/pkg/database/transform" + "github.com/unkeyed/unkey/go/pkg/entities" + "github.com/unkeyed/unkey/go/pkg/fault" +) + +func (db *database) FindRatelimitOverrideByID(ctx context.Context, workspaceId, overrideID string) (entities.RatelimitOverride, error) { + + model, err := db.read().FindRatelimitOverridesById(ctx, gen.FindRatelimitOverridesByIdParams{ + WorkspaceID: workspaceId, + OverrideID: overrideID, + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return entities.RatelimitOverride{}, fault.Wrap(err, + fault.WithTag(fault.NOT_FOUND), + fault.WithDesc("not found", fmt.Sprintf("Ratelimit override '%s' does not exist.", overrideID)), + ) + } + return entities.RatelimitOverride{}, fault.Wrap(err, fault.WithTag(fault.DATABASE_ERROR)) + } + + override, err := transform.RatelimitOverrideModelToEntity(model) + if err != nil { + return entities.RatelimitOverride{}, fault.Wrap(err, + fault.WithDesc("cannot transform override model to entity", "")) + } + return override, nil +} diff --git a/go/pkg/database/ratelimit_override_find_by_identifier.go b/go/pkg/database/ratelimit_override_find_by_identifier.go index 2cd9ab7dbb..3e7958f477 100644 --- a/go/pkg/database/ratelimit_override_find_by_identifier.go +++ b/go/pkg/database/ratelimit_override_find_by_identifier.go @@ -3,31 +3,36 @@ package database import ( "context" "database/sql" - "fmt" "errors" + "github.com/unkeyed/unkey/go/pkg/database/gen" "github.com/unkeyed/unkey/go/pkg/database/transform" "github.com/unkeyed/unkey/go/pkg/entities" "github.com/unkeyed/unkey/go/pkg/fault" ) -func (db *database) FindRatelimitOverrideByIdentifier(ctx context.Context, identifier 
string) (entities.RatelimitOverride, error) { +func (db *database) FindRatelimitOverridesByIdentifier(ctx context.Context, workspaceId, namespaceId, identifier string) ([]entities.RatelimitOverride, error) { - model, err := db.read().FindRatelimitOverrideByIdentifier(ctx, identifier) + models, err := db.read().FindRatelimitOverridesByIdentifier(ctx, gen.FindRatelimitOverridesByIdentifierParams{ + WorkspaceID: workspaceId, + NamespaceID: namespaceId, + Identifier: identifier, + }) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return entities.RatelimitOverride{}, fault.Wrap(err, - fault.WithTag(fault.NOT_FOUND), - fault.WithDesc("not found", fmt.Sprintf("An override for %s does not exist.", identifier)), - ) + return []entities.RatelimitOverride{}, nil } - return entities.RatelimitOverride{}, fault.Wrap(err, fault.WithTag(fault.DATABASE_ERROR)) + return []entities.RatelimitOverride{}, fault.Wrap(err, fault.WithTag(fault.DATABASE_ERROR)) } - e, err := transform.RatelimitOverrideModelToEntity(model) - if err != nil { - return entities.RatelimitOverride{}, fault.Wrap(err, fault.WithDesc("cannot transform model to entity", "")) + es := make([]entities.RatelimitOverride, len(models)) + for i := 0; i < len(models); i++ { + + es[i], err = transform.RatelimitOverrideModelToEntity(models[i]) + if err != nil { + return []entities.RatelimitOverride{}, fault.Wrap(err, fault.WithDesc("cannot transform model to entity", "")) + } } - return e, nil + return es, nil } diff --git a/go/pkg/database/ratelimit_override_insert.go b/go/pkg/database/ratelimit_override_insert.go index e46bd2a7c6..08a69badce 100644 --- a/go/pkg/database/ratelimit_override_insert.go +++ b/go/pkg/database/ratelimit_override_insert.go @@ -10,13 +10,14 @@ import ( func (db *database) InsertRatelimitOverride(ctx context.Context, override entities.RatelimitOverride) error { - err := db.write().InsertOverride(ctx, gen.InsertOverrideParams{ + err := db.write().InsertRatelimitOverride(ctx, 
gen.InsertRatelimitOverrideParams{ ID: override.ID, WorkspaceID: override.WorkspaceID, NamespaceID: override.NamespaceID, Identifier: override.Identifier, Limit: override.Limit, Duration: int32(override.Duration.Milliseconds()), // nolint:gosec + CreatedAt: db.clock.Now(), }) if err != nil { diff --git a/go/pkg/database/ratelimit_override_update.go b/go/pkg/database/ratelimit_override_update.go index 7d1cc0fd71..138d1f2f78 100644 --- a/go/pkg/database/ratelimit_override_update.go +++ b/go/pkg/database/ratelimit_override_update.go @@ -13,6 +13,10 @@ func (db *database) UpdateRatelimitOverride(ctx context.Context, e entities.Rate params := gen.UpdateRatelimitOverrideParams{ ID: e.ID, Windowlimit: e.Limit, + Now: sql.NullTime{ + Time: db.clock.Now(), + Valid: true, + }, Duration: int32(e.Duration.Milliseconds()), // nolint:gosec Async: sql.NullBool{ diff --git a/go/pkg/database/transform/identity.go b/go/pkg/database/transform/identity.go index 8fdbd340a3..eb6ba414f0 100644 --- a/go/pkg/database/transform/identity.go +++ b/go/pkg/database/transform/identity.go @@ -9,59 +9,27 @@ import ( "github.com/unkeyed/unkey/go/pkg/entities" ) -func KeyModelToEntity(m gen.Key) (entities.Key, error) { +func IdentityModelToEntity(m gen.Identity) (entities.Identity, error) { - key := entities.Key{ - ID: m.ID, - KeySpaceID: m.KeyAuthID, - WorkspaceID: m.WorkspaceID, - Hash: m.Hash, - Start: m.Start, - CreatedAt: m.CreatedAt, - ForWorkspaceID: "", - Name: "", - Enabled: m.Enabled, - IdentityID: "", - Meta: map[string]any{}, - UpdatedAt: time.Time{}, - DeletedAt: time.Time{}, - Environment: "", - Expires: time.Time{}, - Identity: nil, - Permissions: []string{}, + identity := entities.Identity{ + ID: m.ID, + ExternalID: m.ExternalID, + WorkspaceID: m.WorkspaceID, + CreatedAt: time.UnixMilli(m.CreatedAt), + Meta: map[string]any{}, + UpdatedAt: time.Time{}, + DeletedAt: time.Time{}, + Environment: m.Environment, } - if m.Name.Valid { - key.Name = m.Name.String + err := 
json.Unmarshal([]byte(m.Meta), &identity.Meta) + if err != nil { + return entities.Identity{}, fmt.Errorf("unable to unmarshal meta: %w", err) } - if m.Meta.Valid { - err := json.Unmarshal([]byte(m.Meta.String), &key.Meta) - if err != nil { - return entities.Key{}, fmt.Errorf("uanble to unmarshal meta: %w", err) - } - } - if m.Expires.Valid { - key.Expires = m.Expires.Time - } - - if m.ForWorkspaceID.Valid { - key.ForWorkspaceID = m.ForWorkspaceID.String - } - if m.IdentityID.Valid { - key.IdentityID = m.IdentityID.String - } - - if m.UpdatedAtM.Valid { - key.UpdatedAt = time.UnixMilli(m.UpdatedAtM.Int64) - } - - if m.DeletedAtM.Valid { - key.DeletedAt = time.UnixMilli(m.DeletedAtM.Int64) - } - if m.Environment.Valid { - key.Environment = m.Environment.String + if m.UpdatedAt.Valid { + identity.UpdatedAt = time.UnixMilli(m.UpdatedAt.Int64) } - return key, nil + return identity, nil } diff --git a/go/pkg/database/transform/key.go b/go/pkg/database/transform/key.go index 558bc30c22..7177c416ee 100644 --- a/go/pkg/database/transform/key.go +++ b/go/pkg/database/transform/key.go @@ -9,27 +9,56 @@ import ( "github.com/unkeyed/unkey/go/pkg/entities" ) -func IdentityModelToEntity(m gen.Identity) (entities.Identity, error) { +func KeyModelToEntity(m gen.Key) (entities.Key, error) { - identity := entities.Identity{ - ID: m.ID, - ExternalID: m.ExternalID, - WorkspaceID: m.WorkspaceID, - CreatedAt: time.UnixMilli(m.CreatedAt), - Meta: map[string]any{}, - UpdatedAt: time.Time{}, - DeletedAt: time.Time{}, - Environment: m.Environment, + key := entities.Key{ + ID: m.ID, + KeyringID: m.KeyAuthID, + WorkspaceID: m.WorkspaceID, + Hash: m.Hash, + Start: m.Start, + CreatedAt: m.CreatedAt, + ForWorkspaceID: "", + Name: "", + Enabled: m.Enabled, + Meta: map[string]any{}, + UpdatedAt: time.Time{}, + DeletedAt: time.Time{}, + Environment: "", + Expires: time.Time{}, + Identity: nil, + Permissions: []string{}, + RemainingRequests: nil, } - err := json.Unmarshal([]byte(m.Meta), 
&identity.Meta) - if err != nil { - return entities.Identity{}, fmt.Errorf("uanble to unmarshal meta: %w", err) + if m.Name.Valid { + key.Name = m.Name.String } - if m.UpdatedAt.Valid { - identity.UpdatedAt = time.UnixMilli(m.UpdatedAt.Int64) + if m.Meta.Valid { + err := json.Unmarshal([]byte(m.Meta.String), &key.Meta) + if err != nil { + return entities.Key{}, fmt.Errorf("unable to unmarshal meta: %w", err) + } + } + if m.Expires.Valid { + key.Expires = m.Expires.Time + } + + if m.ForWorkspaceID.Valid { + key.ForWorkspaceID = m.ForWorkspaceID.String + } + + if m.UpdatedAtM.Valid { + key.UpdatedAt = time.UnixMilli(m.UpdatedAtM.Int64) + } + + if m.DeletedAtM.Valid { + key.DeletedAt = time.UnixMilli(m.DeletedAtM.Int64) + } + if m.Environment.Valid { + key.Environment = m.Environment.String } - return identity, nil + return key, nil } diff --git a/go/pkg/database/transform/ratelimit_namespace.go b/go/pkg/database/transform/ratelimit_namespace.go index a76d93c366..3d5bcef67c 100644 --- a/go/pkg/database/transform/ratelimit_namespace.go +++ b/go/pkg/database/transform/ratelimit_namespace.go @@ -33,5 +33,6 @@ func RatelimitNamespaceEntityToInsertParams(e entities.RatelimitNamespace) gen.I ID: e.ID, WorkspaceID: e.WorkspaceID, Name: e.Name, + CreatedAt: e.CreatedAt, } } diff --git a/go/pkg/database/workspace_delete.go b/go/pkg/database/workspace_delete.go index a807100e1e..8b29d5717c 100644 --- a/go/pkg/database/workspace_delete.go +++ b/go/pkg/database/workspace_delete.go @@ -2,8 +2,10 @@ package database import ( "context" + "database/sql" "log/slog" + "github.com/unkeyed/unkey/go/pkg/database/gen" "github.com/unkeyed/unkey/go/pkg/fault" ) @@ -41,7 +43,13 @@ func (db *database) DeleteWorkspace(ctx context.Context, id string, hardDelete b return fault.Wrap(err, fault.WithDesc("failed to hard delete workspace", "")) } } else { - _, err = qtx.SoftDeleteWorkspace(ctx, id) + _, err = qtx.SoftDeleteWorkspace(ctx, gen.SoftDeleteWorkspaceParams{ + ID: id, + Now: sql.NullTime{ + 
Time: db.clock.Now(), + Valid: true, + }, + }) if err != nil { return fault.Wrap(err, fault.WithDesc("failed to soft delete workspace", "")) } diff --git a/go/pkg/database/workspace_insert.go b/go/pkg/database/workspace_insert.go index b0b7df3efe..918df33a35 100644 --- a/go/pkg/database/workspace_insert.go +++ b/go/pkg/database/workspace_insert.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "github.com/unkeyed/unkey/go/pkg/database/gen" "github.com/unkeyed/unkey/go/pkg/entities" @@ -14,6 +15,10 @@ func (db *database) InsertWorkspace(ctx context.Context, workspace entities.Work ID: workspace.ID, TenantID: workspace.TenantID, Name: workspace.Name, + CreatedAt: sql.NullTime{ + Time: db.clock.Now(), + Valid: true, + }, } err := db.write().InsertWorkspace(ctx, params) diff --git a/go/pkg/discovery/redis.go b/go/pkg/discovery/redis.go new file mode 100644 index 0000000000..62fc07ec6f --- /dev/null +++ b/go/pkg/discovery/redis.go @@ -0,0 +1,124 @@ +package discovery + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/redis/go-redis/v9" + "github.com/unkeyed/unkey/go/pkg/logging" +) + +type Redis struct { + rdb *redis.Client + logger logging.Logger + + addr string + nodeID string + + ttl time.Duration + heartbeatInterval time.Duration + shutdownC chan struct{} +} + +type RedisConfig struct { + URL string + NodeID string + Addr string + Logger logging.Logger +} + +func NewRedis(config RedisConfig) (*Redis, error) { + opts, err := redis.ParseURL(config.URL) + if err != nil { + return nil, fmt.Errorf("failed to parse redis url: %w", err) + } + + rdb := redis.NewClient(opts) + + _, err = rdb.Ping(context.Background()).Result() + if err != nil { + return nil, fmt.Errorf("failed to ping redis: %w", err) + } + + r := &Redis{ + logger: config.Logger, + rdb: rdb, + nodeID: config.NodeID, + addr: config.Addr, + heartbeatInterval: time.Second * 60, + ttl: time.Second * 90, + shutdownC: make(chan struct{}), + } + + err = 
r.advertise(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to advertise state to redis: %w", err) + } + + go r.heartbeat() + + return r, nil +} + +func (r *Redis) heartbeat() { + + t := time.NewTicker(r.heartbeatInterval) + defer t.Stop() + + for { + select { + case <-r.shutdownC: + return + case <-t.C: + ctx := context.Background() + err := r.advertise(ctx) + if err != nil { + r.logger.Error(ctx, "failed to advertise state to redis", slog.String("err", err.Error())) + } + } + } +} + +func (r *Redis) key() string { + return fmt.Sprintf("discovery::nodes::%s", r.nodeID) +} +func (r *Redis) advertise(ctx context.Context) error { + return r.rdb.Set(ctx, r.key(), r.addr, r.ttl).Err() +} + +func (r *Redis) Discover() ([]string, error) { + pattern := r.key() + pattern = strings.ReplaceAll(pattern, r.nodeID, "*") + keys, err := r.rdb.Keys(context.Background(), pattern).Result() + if err != nil { + return nil, fmt.Errorf("failed to get keys: %w", err) + } + + if len(keys) == 0 { + return []string{}, nil + } + + results, err := r.rdb.MGet(context.Background(), keys...).Result() + if err != nil { + return nil, fmt.Errorf("failed to get addresses: %w", err) + } + + addrs := make([]string, len(results)) + var ok bool + for i, addr := range results { + addrs[i], ok = addr.(string) + if !ok { + return nil, fmt.Errorf("invalid address type") + } + } + + return addrs, nil +} + +func (r *Redis) Shutdown(ctx context.Context) error { + r.shutdownC <- struct{}{} + return r.rdb.Del(ctx, r.key()).Err() +} diff --git a/go/pkg/entities/key.go b/go/pkg/entities/key.go index 506e3f10af..11e834a158 100644 --- a/go/pkg/entities/key.go +++ b/go/pkg/entities/key.go @@ -7,8 +7,8 @@ type Key struct { // ID is the unique identifier for the key ID string - // KeySpaceID represents the key authorization space this key belongs to - KeySpaceID string + // KeyringID represents the key authorization space this key belongs to + KeyringID string // WorkspaceID is the ID of the 
workspace that owns this key WorkspaceID string @@ -26,9 +26,6 @@ type Key struct { // Name is an optional human-readable identifier for the key Name string - // IdentityID links this key to a specific identity in the system - IdentityID string - // Meta contains arbitrary metadata associated with the key as key-value pairs Meta map[string]any @@ -56,4 +53,6 @@ type Key struct { // All transient permissions, directly attached or via roles Permissions []string + + RemainingRequests *int64 } diff --git a/go/pkg/entities/keyring.go b/go/pkg/entities/keyring.go new file mode 100644 index 0000000000..98261e3586 --- /dev/null +++ b/go/pkg/entities/keyring.go @@ -0,0 +1,14 @@ +package entities + +import "time" + +type Keyring struct { + ID string + WorkspaceID string + StoreEncryptedKeys bool + DefaultPrefix string + DefaultBytes int + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt time.Time +} diff --git a/go/pkg/fault/tag.go b/go/pkg/fault/tag.go index 3c99c8a4a1..33cf452737 100644 --- a/go/pkg/fault/tag.go +++ b/go/pkg/fault/tag.go @@ -13,6 +13,10 @@ const ( // handling more predictable. UNTAGGED Tag = "UNTAGGED" + // BAD_REQUEST indicates that the client's request was malformed or invalid. + // This is typically used when request validation fails or when the request + // cannot be processed due to client-side errors. + BAD_REQUEST Tag = "BAD_REQUEST" // An object was not found in the system. 
NOT_FOUND Tag = "NOT_FOUND" diff --git a/go/pkg/membership/logger.go b/go/pkg/membership/logger.go index 4ed0124a2e..faa4f35f69 100644 --- a/go/pkg/membership/logger.go +++ b/go/pkg/membership/logger.go @@ -39,13 +39,10 @@ func (l logger) Write(p []byte) (n int, err error) { break case "INFO": l.logger.Info(context.Background(), string(p), slog.String("pkg", "memberlist")) - break case "WARN": l.logger.Warn(context.Background(), string(p), slog.String("pkg", "memberlist")) - break case "ERROR": l.logger.Error(context.Background(), string(p), slog.String("pkg", "memberlist")) - break } return len(p), nil } diff --git a/go/pkg/rbac/permissions.go b/go/pkg/rbac/permissions.go new file mode 100644 index 0000000000..9a2a446596 --- /dev/null +++ b/go/pkg/rbac/permissions.go @@ -0,0 +1,79 @@ +package rbac + +import ( + "errors" + "fmt" + "strings" +) + +type ActionType string + +const ( + // API Actions + ReadAPI ActionType = "read_api" + CreateAPI ActionType = "create_api" + DeleteAPI ActionType = "delete_api" + UpdateAPI ActionType = "update_api" + CreateKey ActionType = "create_key" + UpdateKey ActionType = "update_key" + DeleteKey ActionType = "delete_key" + EncryptKey ActionType = "encrypt_key" + DecryptKey ActionType = "decrypt_key" + ReadKey ActionType = "read_key" + + // Ratelimit Actions + Limit ActionType = "limit" + CreateNamespace ActionType = "create_namespace" + ReadNamespace ActionType = "read_namespace" + UpdateNamespace ActionType = "update_namespace" + DeleteNamespace ActionType = "delete_namespace" + SetOverride ActionType = "set_override" + ReadOverride ActionType = "read_override" + DeleteOverride ActionType = "delete_override" + + // RBAC Actions + CreatePermission ActionType = "create_permission" + UpdatePermission ActionType = "update_permission" + DeletePermission ActionType = "delete_permission" + ReadPermission ActionType = "read_permission" + CreateRole ActionType = "create_role" + UpdateRole ActionType = "update_role" + DeleteRole ActionType = 
"delete_role" + ReadRole ActionType = "read_role" + AddPermissionToKey ActionType = "add_permission_to_key" + RemovePermissionFromKey ActionType = "remove_permission_from_key" + AddRoleToKey ActionType = "add_role_to_key" + RemoveRoleFromKey ActionType = "remove_role_from_key" + AddPermissionToRole ActionType = "add_permission_to_role" + RemovePermissionFromRole ActionType = "remove_permission_from_role" + + // Identity Actions + CreateIdentity ActionType = "create_identity" + ReadIdentity ActionType = "read_identity" + UpdateIdentity ActionType = "update_identity" + DeleteIdentity ActionType = "delete_identity" +) + +type Tuple struct { + ResourceType string + ResourceID string + Action string +} + +func (t Tuple) String() string { + return fmt.Sprintf("%s:%s:%s", t.ResourceType, t.ResourceID, t.Action) +} + +func TupleFromString(s string) (Tuple, error) { + parts := strings.Split(s, ":") + if len(parts) != 3 { + return Tuple{}, errors.New("invalid tuple format") + + } + tuple := Tuple{ + ResourceType: parts[0], + ResourceID: parts[1], + Action: parts[2], + } + return tuple, nil +} diff --git a/go/pkg/rbac/query.go b/go/pkg/rbac/query.go new file mode 100644 index 0000000000..42f1475379 --- /dev/null +++ b/go/pkg/rbac/query.go @@ -0,0 +1,40 @@ +// query.go +package rbac + +type QueryOperator string + +const ( + OperatorNil QueryOperator = "" + OperatorAnd QueryOperator = "and" + OperatorOr QueryOperator = "or" +) + +type PermissionQuery struct { + Operation QueryOperator `json:"operation,omitempty"` + Value string `json:"value,omitempty"` + Children []PermissionQuery `json:"children,omitempty"` +} + +func And(queries ...PermissionQuery) PermissionQuery { + return PermissionQuery{ + Operation: OperatorAnd, + Value: "", + Children: queries, + } +} + +func Or(queries ...PermissionQuery) PermissionQuery { + return PermissionQuery{ + Operation: OperatorOr, + Value: "", + Children: queries, + } +} + +func P(permission string) PermissionQuery { + return PermissionQuery{ 
+ Operation: OperatorNil, + Value: permission, + Children: []PermissionQuery{}, + } +} diff --git a/go/pkg/rbac/rbac.go b/go/pkg/rbac/rbac.go new file mode 100644 index 0000000000..67476b73fe --- /dev/null +++ b/go/pkg/rbac/rbac.go @@ -0,0 +1,81 @@ +package rbac + +import ( + "fmt" + "strings" +) + +type RBAC struct{} + +func New() *RBAC { + return &RBAC{} +} + +type EvaluationResult struct { + Valid bool + Message string +} + +func (r *RBAC) EvaluatePermissions(query PermissionQuery, permissions []string) (*EvaluationResult, error) { + return r.evaluateQueryV1(query, permissions) +} + +func (r *RBAC) evaluateQueryV1(query PermissionQuery, permissions []string) (*EvaluationResult, error) { + // Handle simple permission check + if query.Value != "" { + for _, p := range permissions { + if p == query.Value { + return &EvaluationResult{Valid: true, Message: ""}, nil + } + } + return &EvaluationResult{ + Valid: false, + Message: fmt.Sprintf("Missing permission: '%s'", query.Value), + }, nil + } + + // Handle AND operation + if query.Operation == OperatorAnd { + for _, child := range query.Children { + result, err := r.evaluateQueryV1(child, permissions) + if err != nil { + return nil, err + } + if !result.Valid { + return result, nil + } + } + return &EvaluationResult{Valid: true, Message: ""}, nil + } + + // Handle OR operation + if query.Operation == OperatorOr { + missingPerms := make([]string, 0) + for _, child := range query.Children { + result, err := r.evaluateQueryV1(child, permissions) + if err != nil { + return nil, err + } + if result.Valid { + return result, nil + } + missingPerms = append(missingPerms, fmt.Sprintf("'%v'", child)) + } + return &EvaluationResult{ + Valid: false, + Message: fmt.Sprintf("Missing one of these permissions: [%s], have: [%s]", + strings.Join(missingPerms, ", "), + strings.Join(formatPermissions(permissions), ", ")), + }, nil + } + + return nil, fmt.Errorf("invalid query structure") +} + +func formatPermissions(permissions 
[]string) []string { + formatted := make([]string, len(permissions)) + for i, p := range permissions { + formatted[i] = fmt.Sprintf("'%s'", p) + } + return formatted +} diff --git a/go/pkg/rbac/rbac_test.go b/go/pkg/rbac/rbac_test.go new file mode 100644 index 0000000000..fd5b388e40 --- /dev/null +++ b/go/pkg/rbac/rbac_test.go @@ -0,0 +1,72 @@ +package rbac + +import ( + "testing" +) + +func TestRBAC_EvaluatePermissions(t *testing.T) { + tests := []struct { + name string + query PermissionQuery + permissions []string + wantValid bool + }{ + { + name: "Simple role check (Pass)", + query: P("admin"), + permissions: []string{"admin", "user", "guest"}, + wantValid: true, + }, + { + name: "Simple role check (Fail)", + query: P("developer"), + permissions: []string{"admin", "user", "guest"}, + wantValid: false, + }, + { + name: "AND of two permissions (Pass)", + query: And( + P("admin"), + P("user"), + ), + permissions: []string{"admin", "user", "guest"}, + wantValid: true, + }, + { + name: "OR of two permissions (Pass)", + query: Or( + P("admin"), + P("developer"), + ), + permissions: []string{"admin", "user", "guest"}, + wantValid: true, + }, + { + name: "Complex combination (Pass)", + query: And( + P("admin"), + Or( + P("user"), + P("guest"), + ), + ), + permissions: []string{"admin", "user", "guest"}, + wantValid: true, + }, + } + + rbac := New() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := rbac.EvaluatePermissions(tt.query, tt.permissions) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if result.Valid != tt.wantValid { + t.Errorf("want valid=%v, got valid=%v, message=%s", + tt.wantValid, result.Valid, result.Message) + } + }) + } +} diff --git a/go/pkg/ring/ring_test.go b/go/pkg/ring/ring_test.go index 09706c43c2..70e4cf1228 100644 --- a/go/pkg/ring/ring_test.go +++ b/go/pkg/ring/ring_test.go @@ -72,7 +72,6 @@ func TestRing(t *testing.T) { m, s := stat.MeanStdDev(cs, nil) relStddev := s / m - 
fmt.Printf("min: %d, max: %d, mean: %f, stddev: %f, relstddev: %f\n", minimum, maximum, m, s, relStddev) require.LessOrEqual(t, relStddev, 0.1, "relative std should be less than 0.1, got: %f", relStddev) } diff --git a/go/pkg/sim/events.go b/go/pkg/sim/events.go new file mode 100644 index 0000000000..7dce572b1d --- /dev/null +++ b/go/pkg/sim/events.go @@ -0,0 +1,33 @@ +package sim + +import ( + "errors" + "math/rand" +) + +type Idle struct { +} + +var _ Event[any] = (*Idle)(nil) + +func (i Idle) Name() string { + return "Idle" +} + +func (i Idle) Run(rng *rand.Rand, state *any) error { + return nil +} + +type Fail struct { + Message string +} + +var _ Event[any] = (*Fail)(nil) + +func (f Fail) Name() string { + return "Fail" +} + +func (f Fail) Run(rng *rand.Rand, state *any) error { + return errors.New(f.Message) +} diff --git a/go/pkg/sim/rng.go b/go/pkg/sim/rng.go new file mode 100644 index 0000000000..4ff70e8f74 --- /dev/null +++ b/go/pkg/sim/rng.go @@ -0,0 +1,13 @@ +package sim + +import "math/rand/v2" + +type source struct { + seed uint64 +} + +var _ rand.Source = (*source)(nil) + +func (s source) Uint64() uint64 { + return s.seed +} diff --git a/go/pkg/sim/seed.go b/go/pkg/sim/seed.go new file mode 100644 index 0000000000..121be45c2a --- /dev/null +++ b/go/pkg/sim/seed.go @@ -0,0 +1,11 @@ +package sim + +import ( + "math/rand" +) + +func NewSeed() int64 { + + // nolint:gosec + return rand.Int63() +} diff --git a/go/pkg/sim/simulation.go b/go/pkg/sim/simulation.go new file mode 100644 index 0000000000..1b9760596e --- /dev/null +++ b/go/pkg/sim/simulation.go @@ -0,0 +1,126 @@ +package sim + +import ( + "fmt" + "math/rand" + "testing" +) + +type Event[State any] interface { + // Run executes the event logic. + // State must allow parallel manipulation from multiple goroutines. + Run(rng *rand.Rand, state *State) error + + // Name returns the name of the event for logging and debugging purposes. 
+ Name() string +} + +type Simulation[State any] struct { + t *testing.T + seed int64 + rng *rand.Rand + steps int + + Errors []error + + state *State + + // Tracks how many configurations have been applied + applied int +} + +type apply[S any] func(*Simulation[S]) *Simulation[S] + +func New[State any](t *testing.T, fns ...apply[State]) *Simulation[State] { + + seed := NewSeed() + + s := &Simulation[State]{ + t: t, + seed: seed, + // nolint:gosec + rng: rand.New(rand.NewSource(seed)), + steps: 1_000_000_000, + state: nil, + Errors: []error{}, + applied: 0, + } + + for _, fn := range fns { + s = fn(s) + s.applied++ + } + return s +} + +func WithSeed[S any](seed int64) apply[S] { + return func(s *Simulation[S]) *Simulation[S] { + if s.applied > 0 { + s.t.Fatalf("WithSeed called too late. If you need a custom seed, call WithSeed before any other configuration.") + } + + s.seed = seed + // nolint:gosec + s.rng = rand.New(rand.NewSource(seed)) + return s + } +} + +func WithSteps[S any](steps int) apply[S] { + return func(s *Simulation[S]) *Simulation[S] { + s.steps = steps + return s + } +} + +func WithState[S any](fn func(rng *rand.Rand) *S) apply[S] { + return func(s *Simulation[S]) *Simulation[S] { + s.state = fn(s.rng) + return s + } +} + +// Run must not be called concurrently +func (s *Simulation[State]) Run(events []Event[State]) { + s.t.Helper() + + if len(events) == 0 { + return + } + + fmt.Printf("Simulation [seed=%d], steps=%d\n", s.seed, s.steps) + + total := 0.0 + weights := make([]float64, len(events)) + + for i := range weights { + weights[i] = s.rng.Float64() + total += weights[i] + } + + for i := 0; i < s.steps; i++ { + if i%(s.steps/10) == 0 { + s.t.Logf("progress: %d%%\n", i*100/s.steps) + } + + r := s.rng.Float64() * total + + // Find which bucket it falls into + sum := 0.0 + var index int + for j, w := range weights { + sum += w + if r <= sum { + index = j + break + } + } + + event := events[index] + + err := event.Run(s.rng, s.state) + if err != 
nil { + s.Errors = append(s.Errors, err) + } + } +} diff --git a/go/pkg/testutil/containers.go b/go/pkg/testutil/containers.go index 7e0e68ca49..75f894eea3 100644 --- a/go/pkg/testutil/containers.go +++ b/go/pkg/testutil/containers.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - _ "github.com/go-sql-driver/mysql" + mysql "github.com/go-sql-driver/mysql" "github.com/unkeyed/unkey/go/pkg/database" "github.com/ory/dockertest/v3" @@ -48,22 +48,35 @@ func (c *Containers) RunMySQL() string { require.NoError(c.t, c.pool.Purge(resource)) }) - addr := fmt.Sprintf("unkey:password@(localhost:%s)/unkey", resource.GetPort("3306/tcp")) + cfg := mysql.NewConfig() + cfg.User = "unkey" + cfg.Passwd = "password" + cfg.Net = "tcp" + cfg.Addr = fmt.Sprintf("localhost:%s", resource.GetPort("3306/tcp")) + cfg.DBName = "unkey" + cfg.ParseTime = true + cfg.Logger = &mysql.NopLogger{} var db *sql.DB require.NoError(c.t, c.pool.Retry(func() error { - db, err = sql.Open("mysql", addr) - if err != nil { - return fmt.Errorf("unable to open mysql conenction: %w", err) + + connector, err2 := mysql.NewConnector(cfg) + if err2 != nil { + return fmt.Errorf("unable to create mysql connector: %w", err2) } - err = db.Ping() - if err != nil { - return fmt.Errorf("unable to ping mysql: %w", err) + + db = sql.OpenDB(connector) + err3 := db.Ping() + if err3 != nil { + return fmt.Errorf("unable to ping mysql: %w", err3) } return nil })) + c.t.Cleanup(func() { + require.NoError(c.t, db.Close()) + }) // Creating the database tables queries := strings.Split(string(database.Schema), ";") for _, query := range queries { @@ -79,5 +92,6 @@ func (c *Containers) RunMySQL() string { } - return addr + return cfg.FormatDSN() + } diff --git a/go/pkg/testutil/http.go b/go/pkg/testutil/http.go index c46df9410a..2d209df494 100644 --- a/go/pkg/testutil/http.go +++ b/go/pkg/testutil/http.go @@ -2,27 +2,65 @@ package testutil import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" "testing" + "time" 
"github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/api" + "github.com/unkeyed/unkey/go/internal/services/keys" + "github.com/unkeyed/unkey/go/pkg/clock" + "github.com/unkeyed/unkey/go/pkg/database" + "github.com/unkeyed/unkey/go/pkg/entities" + "github.com/unkeyed/unkey/go/pkg/hash" "github.com/unkeyed/unkey/go/pkg/logging" + "github.com/unkeyed/unkey/go/pkg/uid" "github.com/unkeyed/unkey/go/pkg/zen" + "github.com/unkeyed/unkey/go/pkg/zen/validation" ) +type Resources struct { + RootWorkspace entities.Workspace + RootKeyring entities.Keyring + UserWorkspace entities.Workspace +} + type Harness struct { t *testing.T - logger logging.Logger + Clock clock.Clock + + srv *zen.Server + containers *Containers + validator *validation.Validator - srv *zen.Server + middleware []zen.Middleware + + DB database.Database + Logger logging.Logger + Keys keys.KeyService + Resources Resources } func NewHarness(t *testing.T) *Harness { + clk := clock.NewTestClock() + + logger := logging.New(logging.Config{Development: true, NoColor: false}) + + containers := NewContainers(t) - logger := logging.NewNoop() + dsn := containers.RunMySQL() + + db, err := database.New(database.Config{ + Logger: logger, + PrimaryDSN: dsn, + ReadOnlyDSN: "", + Clock: clk, + }) + require.NoError(t, err) srv, err := zen.New(zen.Config{ NodeID: "test", @@ -30,18 +68,146 @@ func NewHarness(t *testing.T) *Harness { }) require.NoError(t, err) + keyService, err := keys.New(keys.Config{ + Logger: logger, + DB: db, + }) + require.NoError(t, err) + + validator, err := validation.New() + require.NoError(t, err) + h := Harness{ - t: t, - logger: logger, - srv: srv, + t: t, + Logger: logger, + srv: srv, + containers: containers, + validator: validator, + Keys: keyService, + DB: db, + // resources are seeded later + // nolint:exhaustruct + Resources: Resources{}, + Clock: clk, + + middleware: []zen.Middleware{ + zen.WithTracing(), + // zen.WithMetrics(svc.EventBuffer) + zen.WithLogging(logger), + 
zen.WithErrorHandling(), + zen.WithValidation(validator), + }, } + h.seed() return &h } -func (h *Harness) Register(route zen.Route) { +// Register registers a route with the harness. +// You can override the middleware by passing a list of middleware. +func (h *Harness) Register(route zen.Route, middleware ...zen.Middleware) { + + if len(middleware) == 0 { + middleware = h.middleware + } + + h.srv.RegisterRoute( + middleware, + route, + ) + +} + +func (h *Harness) seed() { + + rootWorkspace := entities.Workspace{ + ID: uid.New("test_ws"), + TenantID: "unkey", + Name: "unkey", + CreatedAt: time.Now(), + DeletedAt: time.Time{}, + Plan: entities.WorkspacePlanPro, + Enabled: true, + DeleteProtection: true, + BetaFeatures: make(map[string]interface{}), + Features: make(map[string]interface{}), + StripeCustomerID: "", + StripeSubscriptionID: "", + TrialEnds: time.Time{}, + PlanLockedUntil: time.Time{}, + } + + err := h.DB.InsertWorkspace(context.Background(), rootWorkspace) + require.NoError(h.t, err) + + rootKeyring := entities.Keyring{ + ID: uid.New("test_kr"), + WorkspaceID: rootWorkspace.ID, + StoreEncryptedKeys: false, + DefaultPrefix: "test", + DefaultBytes: 16, + CreatedAt: time.Now(), + UpdatedAt: time.Time{}, + DeletedAt: time.Time{}, + } + + err = h.DB.InsertKeyring(context.Background(), rootKeyring) + require.NoError(h.t, err) + + userWorkspace := entities.Workspace{ + ID: uid.New("test_ws"), + TenantID: "user", + Name: "user", + CreatedAt: time.Now(), + DeletedAt: time.Time{}, + Plan: entities.WorkspacePlanPro, + Enabled: true, + DeleteProtection: true, + BetaFeatures: make(map[string]interface{}), + Features: make(map[string]interface{}), + StripeCustomerID: "", + StripeSubscriptionID: "", + TrialEnds: time.Time{}, + PlanLockedUntil: time.Time{}, + } + + err = h.DB.InsertWorkspace(context.Background(), userWorkspace) + require.NoError(h.t, err) + + h.Resources = Resources{ + RootWorkspace: rootWorkspace, + RootKeyring: rootKeyring, + UserWorkspace: 
userWorkspace, + } + +} + +func (h *Harness) CreateRootKey() string { + + key := uid.New("test_root_key") + + err := h.DB.InsertKey(context.Background(), entities.Key{ + ID: uid.New("test_root_key"), + Hash: hash.Sha256(key), + WorkspaceID: h.Resources.RootWorkspace.ID, + ForWorkspaceID: h.Resources.UserWorkspace.ID, + KeyringID: h.Resources.RootKeyring.ID, + Start: key[:4], + Name: "test", + Identity: nil, + Meta: make(map[string]any), + CreatedAt: time.Now(), + UpdatedAt: time.Time{}, + DeletedAt: time.Time{}, + Enabled: true, + Environment: "", + Expires: time.Time{}, + Permissions: []string{}, + RemainingRequests: nil, + }) + require.NoError(h.t, err) - h.srv.RegisterRoute([]zen.Middleware{}, route) + return key } @@ -55,9 +221,11 @@ func UnmarshalBody[Body any](t *testing.T, r *httptest.ResponseRecorder, body *B } type TestResponse[TBody any] struct { - Status int - Headers http.Header - Body TBody + Status int + Headers http.Header + Body *TBody + ErrorBody *api.BaseError + RawBody string } func CallRoute[Req any, Res any](h *Harness, route zen.Route, headers http.Header, req Req) TestResponse[Res] { @@ -74,20 +242,31 @@ func CallRoute[Req any, Res any](h *Harness, route zen.Route, headers http.Heade if httpReq.Header == nil { httpReq.Header = http.Header{} } - if route.Method() == http.MethodPost { - httpReq.Header.Set("Content-Type", "application/json") - } h.srv.Mux().ServeHTTP(rr, httpReq) require.NoError(h.t, err) - var res Res - err = json.NewDecoder(rr.Body).Decode(&res) - require.NoError(h.t, err) + rawBody := rr.Body.Bytes() - return TestResponse[Res]{ - Status: rr.Code, - Headers: rr.Header(), - Body: res, + res := TestResponse[Res]{ + Status: rr.Code, + Headers: rr.Header(), + RawBody: string(rawBody), + Body: nil, + ErrorBody: nil, } + + if rr.Code < 400 { + var responseBody Res + err = json.Unmarshal(rawBody, &responseBody) + require.NoError(h.t, err) + res.Body = &responseBody + } else { + var errorBody api.BaseError + err = 
json.Unmarshal(rawBody, &errorBody) + require.NoError(h.t, err) + res.ErrorBody = &errorBody + } + + return res } diff --git a/go/pkg/zen/auth.go b/go/pkg/zen/auth.go index 435ce7491a..e3a0114d4e 100644 --- a/go/pkg/zen/auth.go +++ b/go/pkg/zen/auth.go @@ -14,7 +14,13 @@ func Bearer(s *Session) (string, error) { return "", fault.New("empty authorization header", fault.WithTag(fault.UNAUTHORIZED)) } - bearer := strings.TrimSuffix(header, "Bearer ") + header = strings.TrimSpace(header) + if !strings.HasPrefix(header, "Bearer ") { + return "", fault.New("invalid format", fault.WithTag(fault.UNAUTHORIZED), + fault.WithDesc("missing bearer prefix", "Your authorization header is missing the 'Bearer ' prefix.")) + } + + bearer := strings.TrimPrefix(header, "Bearer ") if bearer == "" { return "", fault.New("invalid token", fault.WithTag(fault.UNAUTHORIZED)) } diff --git a/go/pkg/zen/middleware_errors.go b/go/pkg/zen/middleware_errors.go index 8df9615a2f..aacf3380a4 100644 --- a/go/pkg/zen/middleware_errors.go +++ b/go/pkg/zen/middleware_errors.go @@ -11,6 +11,7 @@ func WithErrorHandling() Middleware { return func(next HandleFunc) HandleFunc { return func(s *Session) error { err := next(s) + if err == nil { return nil } @@ -25,6 +26,18 @@ func WithErrorHandling() Middleware { Status: s.responseStatus, Instance: nil, }) + + case fault.BAD_REQUEST: + return s.JSON(http.StatusBadRequest, api.BadRequestError{ + Title: "Bad Request", + Type: "https://unkey.com/docs/errors/bad_request", + Detail: fault.UserFacingMessage(err), + RequestId: s.requestID, + Status: http.StatusBadRequest, + Instance: nil, + Errors: []api.ValidationError{}, + }) + case fault.UNAUTHORIZED: return s.JSON(http.StatusUnauthorized, api.UnauthorizedError{ Title: "Unauthorized", diff --git a/go/pkg/zen/middleware_auth.go b/go/pkg/zen/middleware_openapi_validation.go similarity index 100% rename from go/pkg/zen/middleware_auth.go rename to go/pkg/zen/middleware_openapi_validation.go diff --git 
a/go/pkg/zen/middleware_tracing.go b/go/pkg/zen/middleware_tracing.go new file mode 100644 index 0000000000..12aaeed68d --- /dev/null +++ b/go/pkg/zen/middleware_tracing.go @@ -0,0 +1,23 @@ +package zen + +import ( + "github.com/unkeyed/unkey/go/pkg/tracing" +) + +func WithTracing() Middleware { + + return func(next HandleFunc) HandleFunc { + return func(s *Session) error { + ctx, span := tracing.Start(s.Context(), s.r.Pattern) + defer span.End() + + s.ctx = ctx + + err := next(s) + if err != nil { + tracing.RecordError(span, err) + } + return err + } + } +} diff --git a/go/pkg/zen/middleware_validate.go b/go/pkg/zen/middleware_validate.go deleted file mode 100644 index c5289fd1e8..0000000000 --- a/go/pkg/zen/middleware_validate.go +++ /dev/null @@ -1,35 +0,0 @@ -package zen - -import ( - "strings" - - "github.com/unkeyed/unkey/go/internal/services/keys" - "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/hash" -) - -func WithRootKeyAuth(svc keys.KeyService) Middleware { - return func(next HandleFunc) HandleFunc { - return func(s *Session) error { - - header := s.r.Header.Get("Authorization") - if header == "" { - return fault.New("empty authorization header", fault.WithTag(fault.UNAUTHORIZED)) - } - - bearer := strings.TrimSuffix(header, "Bearer ") - if bearer == "" { - return fault.New("invalid token", fault.WithTag(fault.UNAUTHORIZED)) - } - - key, err := svc.Verify(s.Context(), hash.Sha256(bearer)) - if err != nil { - return fault.Wrap(err) - } - - s.workspaceID = key.AuthorizedWorkspaceID - - return next(s) - } - } -} diff --git a/go/pkg/zen/server.go b/go/pkg/zen/server.go index f50231f62d..1a5728dcb0 100644 --- a/go/pkg/zen/server.go +++ b/go/pkg/zen/server.go @@ -58,6 +58,7 @@ func New(config Config) (*Server, error) { sessions: sync.Pool{ New: func() any { return &Session{ + ctx: context.Background(), workspaceID: "", requestID: "", w: nil, @@ -133,7 +134,7 @@ func (s *Server) RegisterRoute(middlewares []Middleware, route Route) { 
s.returnSession(sess) }() - err := sess.Init(w, r) + err := sess.init(w, r) if err != nil { s.logger.Error(context.Background(), "failed to init session") return diff --git a/go/pkg/zen/session.go b/go/pkg/zen/session.go index 94d59c3f00..2413b5f0b9 100644 --- a/go/pkg/zen/session.go +++ b/go/pkg/zen/session.go @@ -3,6 +3,7 @@ package zen import ( "context" "encoding/json" + "io" "net/http" "github.com/unkeyed/unkey/go/pkg/fault" @@ -16,6 +17,7 @@ import ( // All references to sessions, request bodies or anything within must not be // used outside of the handler. Make a copy of them if you need to. type Session struct { + ctx context.Context requestID string w http.ResponseWriter @@ -31,7 +33,8 @@ type Session struct { responseBody []byte } -func (s *Session) Init(w http.ResponseWriter, r *http.Request) error { +func (s *Session) init(w http.ResponseWriter, r *http.Request) error { + s.ctx = r.Context() s.requestID = uid.Request() s.w = w s.r = r @@ -41,7 +44,17 @@ func (s *Session) Init(w http.ResponseWriter, r *http.Request) error { } func (s *Session) Context() context.Context { - return s.r.Context() + return s.ctx + +} + +// AuthorizedWorkspaceID returns the workspaceID of the root key used as authentication mechanism. +// +// If the `WithRootKeyAuth` middleware is used, it is guaranteed to be populated. +// The request would've aborted and returned early if authentication failed. +// Otherwise an empty string is returned. +func (s *Session) AuthorizedWorkspaceID() string { + return s.workspaceID } // Request returns the underlying http.Request. 
@@ -56,7 +69,14 @@ func (s *Session) ResponseWriter() http.ResponseWriter { } func (s *Session) BindBody(dst any) error { - err := json.Unmarshal(s.requestBody, dst) + var err error + s.requestBody, err = io.ReadAll(s.r.Body) + if err != nil { + return fault.Wrap(err, fault.WithDesc("unable to read request body", "The request body is malformed.")) + } + defer s.r.Body.Close() + + err = json.Unmarshal(s.requestBody, dst) if err != nil { return fault.Wrap(err, fault.WithDesc("failed to unmarshal request body", "The request body was not valid json."), diff --git a/go/pkg/zen/validation/validator.go b/go/pkg/zen/validation/validator.go index 8db67a0ba7..a618840651 100644 --- a/go/pkg/zen/validation/validator.go +++ b/go/pkg/zen/validation/validator.go @@ -38,11 +38,21 @@ func New() (*Validator, error) { // nolint:wrapcheck return nil, fault.New("failed to create validator", messages...) } + valid, docErrors := v.ValidateDocument() + if !valid { + messages := make([]fault.Wrapper, len(docErrors)) + for i, e := range docErrors { + messages[i] = fault.WithDesc(e.Message, "") + } + + return nil, fault.New("openapi document is invalid", messages...) 
+ } return &Validator{ validator: v, }, nil } func (v *Validator) Validate(r *http.Request) (api.BadRequestError, bool) { + valid, errors := v.validator.ValidateHttpRequest(r) if !valid { valErr := api.BadRequestError{ @@ -51,14 +61,12 @@ func (v *Validator) Validate(r *http.Request) (api.BadRequestError, bool) { Instance: nil, Status: http.StatusBadRequest, RequestId: ctxutil.GetRequestId(r.Context()), - Type: "https://unkey.com/docs/api-reference/errors/TODO", + Type: "https://unkey.com/docs/errors/bad_request", Errors: []api.ValidationError{}, } if len(errors) >= 1 { - err := errors[0] - valErr.Title = err.Message - valErr.Detail = err.HowToFix + err := errors[0] for _, e := range err.SchemaValidationErrors { diff --git a/go/schema.json b/go/schema.json index 6f53e5a8e6..54f3511ccb 100644 --- a/go/schema.json +++ b/go/schema.json @@ -28,23 +28,18 @@ "type": "object", "description": "Cluster discovery configuration, only one supported: static, cloudmap", "properties": { - "awsCloudmap": { + "redis": { "type": "object", - "description": "Cloudmap cluster discovery configuration", + "description": "Redis cluster discovery configuration", "properties": { - "region": { + "url": { "type": "string", - "description": "Cloudmap region", - "minLength": 1 - }, - "serviceName": { - "type": "string", - "description": "Cloudmap service name", + "description": "Redis URL", "minLength": 1 } }, "additionalProperties": false, - "required": ["serviceName", "region"] + "required": ["url"] }, "static": { "type": "object",