diff --git a/.github/ISSUE_TEMPLATE/testplan.md b/.github/ISSUE_TEMPLATE/testplan.md index db9192b9fd858..1030bd0c0eb79 100644 --- a/.github/ISSUE_TEMPLATE/testplan.md +++ b/.github/ISSUE_TEMPLATE/testplan.md @@ -180,67 +180,114 @@ as well as an upgrade of the previous version of Teleport. - [ ] Interact with a cluster using `tsh` - These commands should ideally be tested for recording and non-recording modes as they are implemented in a different ways. - - - [ ] tsh ssh \ - - [ ] tsh ssh \ - - [ ] tsh ssh \ - - [ ] tsh ssh \ - - [ ] tsh ssh -A \ - - [ ] tsh ssh -A \ - - [ ] tsh ssh -A \ - - [ ] tsh ssh -A \ - - [ ] tsh ssh \ ls - - [ ] tsh ssh \ ls - - [ ] tsh ssh \ ls - - [ ] tsh ssh \ ls - - [ ] tsh join \ - - [ ] tsh join \ - - [ ] tsh play \ - - [ ] tsh play \ - - [ ] tsh play \ - - [ ] tsh play \ - - [ ] tsh scp \ - - [ ] tsh scp \ - - [ ] tsh scp \ - - [ ] tsh scp \ - - [ ] tsh ssh -L \ - - [ ] tsh ssh -L \ - - [ ] tsh ssh -L \ - - [ ] tsh ssh -L \ - - [ ] tsh ssh -R \ - - [ ] tsh ssh -R \ - - [ ] tsh ssh -R \ - - [ ] tsh ssh -R \ - - [ ] tsh ls - - [ ] tsh clusters + These commands should ideally be tested for recording and non-recording modes as they are implemented in a different ways. + Recording can be disabled by adding `session_recording: off` to `auth_service` in your config. A regular node refers to + a [Teleport SSH service](https://goteleport.com/docs/enroll-resources/server-access/getting-started/). An agentless node is an [OpenSSH server](https://goteleport.com/docs/enroll-resources/server-access/openssh/openssh-agentless) that has been enrolled into Teleport. A remote cluster is a leaf cluster that is connected to a root cluster via a [trusted cluster setup](https://goteleport.com/docs/admin-guides/management/admin/trustedclusters/). 
Here's a recommended setup for testing: + +``` + ┌───────────────┐ + │ │ + ┌►│ Regular Node │ +┌───────────────┐ ┌───────────────┐ │ │ │ +│ │ │ │ │ └───────────────┘ +│ Root Cluster ├───►│ Leaf Cluster ├─┤ +│ │ │ │ │ ┌───────────────┐ +└───────────────┘ └───────────────┘ │ │ │ + └►│ OpenSSH Node │ + │ │ + └───────────────┘ +``` + +When you want to test a non-remote-cluster, use the Leaf Cluster as your proxy target. + + - [ ] `tsh ssh ` + - [ ] `tsh ssh ` + - [ ] `tsh ssh ` + - [ ] `tsh ssh ` + +Test agent had been forwarded by running `ssh-add -L` and check that your teleport keys are listed. Each cluster requires the `permit-agent-forwarding` flag and the role you're assuming in the leaf cluster needs `Agent Forwarding` enabled. Example connection command: +`tsh ssh -A --proxy $PROXY --cluster $REMOTE_CLUSTER $USER@$NODE_NAME` + + - [ ] `tsh ssh -A ` + - [ ] `tsh ssh -A ` + - [ ] `tsh ssh -A ` + - [ ] `tsh ssh -A ` + - [ ] `tsh ssh ls` + - [ ] `tsh ssh ls` + - [ ] `tsh ssh ls` + - [ ] `tsh ssh ls` + - [ ] `tsh join ` + - [ ] `tsh join ` + +For `tsh play`, ensure the role you assume on the leaf cluster has `read` and `list` for the `session` resource. Example allow rule: +```yaml +spec: + allow: + rules: + - resources: + - session + verbs: + - read + - list +``` + + - [ ] `tsh play ` + - [ ] `tsh play ` + - [ ] `tsh play ` + - [ ] `tsh play ` + - [ ] `tsh scp ` + - [ ] `tsh scp ` + - [ ] `tsh scp ` + - [ ] `tsh scp ` + +This forwards the local port to the remote node, test this with a web server running on the remote node, e.g. `python3 -m http.server 8000` on the remote node, setup a tunnel to the node with `tsh ssh -L 9000:localhost:8000 `, then `curl http://localhost:9000` from your local machine. + + - [ ] `tsh ssh -L ` + - [ ] `tsh ssh -L ` + - [ ] `tsh ssh -L ` + - [ ] `tsh ssh -L ` + +`-R` forwards the remote port to the local machine, test this with a web server running on your local machine, e.g. 
`python3 -m http.server 8000`, setup a tunnel to the node with `tsh ssh -R 9000:localhost:8000 `, then `curl http://localhost:9000` from the remote node. + + - [ ] `tsh ssh -R ` + - [ ] `tsh ssh -R ` + - [ ] `tsh ssh -R ` + - [ ] `tsh ssh -R ` + - [ ] `tsh ls` + - [ ] `tsh clusters` - [ ] Interact with a cluster using `ssh` - Make sure to test both recording and regular proxy modes. - - [ ] ssh \ - - [ ] ssh \ - - [ ] ssh \ - - [ ] ssh \ - - [ ] ssh -A \ - - [ ] ssh -A \ - - [ ] ssh -A \ - - [ ] ssh -A \ - - [ ] ssh \ ls - - [ ] ssh \ ls - - [ ] ssh \ ls - - [ ] ssh \ ls - - [ ] scp \ - - [ ] scp \ - - [ ] scp \ - - [ ] scp \ - - [ ] ssh -L \ - - [ ] ssh -L \ - - [ ] ssh -L \ - - [ ] ssh -L \ - - [ ] ssh -R \ - - [ ] ssh -R \ - - [ ] ssh -R \ - - [ ] ssh -R \ + + Make sure to test both recording and regular proxy modes. Generate an [SSH config](https://goteleport.com/docs/reference/cli/tsh/#tsh-config), one per cluster. An SSH command will look something like this: + + `ssh -p 22 -F /path/to/generated/ssh_config @.` + + To test connecting to a remote cluster, use the root cluster's `ssh_config` and the name of the remote cluster for ``. + + - [ ] `ssh ` + - [ ] `ssh ` + - [ ] `ssh ` + - [ ] `ssh ` + - [ ] `ssh -A ` + - [ ] `ssh -A ` + - [ ] `ssh -A ` + - [ ] `ssh -A ` + - [ ] `ssh ls` + - [ ] `ssh ls` + - [ ] `ssh ls` + - [ ] `ssh ls` + - [ ] `scp ` + - [ ] `scp ` + - [ ] `scp ` + - [ ] `scp ` + - [ ] `ssh -L ` + - [ ] `ssh -L ` + - [ ] `ssh -L ` + - [ ] `ssh -L ` + - [ ] `ssh -R ` + - [ ] `ssh -R ` + - [ ] `ssh -R ` + - [ ] `ssh -R ` - [ ] Verify proxy jump functionality Log into leaf cluster via root, shut down the root proxy and verify proxy jump works. @@ -1177,16 +1224,15 @@ release/dev build. If you are building Teleport Connect in development mode, you config option `hardwareKeyAgent.enabled: true` and restart Connect. You can run a non-login `tsh` command to check if the agent is running. 
-Before logging in to Teleport Connect: +In `tsh`, without logging into Teleport Connect: - [ ] `tsh login` prompts for PIV PIN and touch without using the Hardware Key Agent - [ ] All other `tsh` commands prompt for PIN and touch via the Hardware Key Agent - [ ] Test a subset of the `tsh` commands from the test above - [ ] The command is displayed in the PIN and touch prompts -- [ ] Connecting with OpenSSH `ssh` prompts for PIN and touch via the hardware key agent - [ ] The PIN is cached for the configured duration between basic `tsh` commands (set `pin_cache_ttl` to something longer that 15s if needed) -After logging in to Teleport Connect: +In Teleport Connect: - [ ] Login prompts for PIN and touch - [ ] Server Access diff --git a/api/types/sessionrecording.go b/api/types/sessionrecording.go index 832bfcc59a8ec..4a10220505cbc 100644 --- a/api/types/sessionrecording.go +++ b/api/types/sessionrecording.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "iter" "slices" "strings" "time" @@ -43,6 +44,16 @@ type SessionRecordingConfig interface { // SetProxyChecksHostKeys sets if the proxy will check host keys. SetProxyChecksHostKeys(bool) + // GetEncrypted gets if session recordings should be encrypted or not. + GetEncrypted() bool + + // GetEncryptionKeys gets the encryption keys for the session recording config. + GetEncryptionKeys() []*AgeEncryptionKey + + // SetEncryptionKeys sets the encryption keys for the session recording config. + // It returns true if there was a change applied and false otherwise. + SetEncryptionKeys(iter.Seq[*AgeEncryptionKey]) bool + // Clone returns a copy of the resource. Clone() SessionRecordingConfig } @@ -163,6 +174,62 @@ func (c *SessionRecordingConfigV2) SetProxyChecksHostKeys(t bool) { c.Spec.ProxyChecksHostKeys = NewBoolOption(t) } +// GetEncrypted gets if session recordings should be encrypted or not. 
+func (c *SessionRecordingConfigV2) GetEncrypted() bool { + encryption := c.Spec.Encryption + if encryption == nil { + return false + } + + return encryption.Enabled +} + +// GetEncryptionKeys gets the encryption keys for the session recording config. +func (c *SessionRecordingConfigV2) GetEncryptionKeys() []*AgeEncryptionKey { + if c.Status != nil { + return c.Status.EncryptionKeys + } + + return nil +} + +// SetEncryptionKeys sets the encryption keys for the session recording config. +// It returns true if there was a change applied and false otherwise. +func (c *SessionRecordingConfigV2) SetEncryptionKeys(keys iter.Seq[*AgeEncryptionKey]) bool { + existingKeys := make(map[string]struct{}) + for _, key := range c.GetEncryptionKeys() { + existingKeys[string(key.PublicKey)] = struct{}{} + } + + var keysChanged bool + var newKeys []*AgeEncryptionKey + addedKeys := make(map[string]struct{}) + for key := range keys { + if !keysChanged { + if _, exists := existingKeys[string(key.PublicKey)]; !exists { + keysChanged = true + } + } + + if _, added := addedKeys[string(key.PublicKey)]; !added { + addedKeys[string(key.PublicKey)] = struct{}{} + newKeys = append(newKeys, key) + } + + } + + if !keysChanged || len(newKeys) == 0 || len(existingKeys) == len(addedKeys) { + return false + } + + if c.Status == nil { + c.Status = &SessionRecordingConfigStatus{} + } + c.Status.EncryptionKeys = newKeys + + return true +} + // Clone returns a copy of the resource. func (c *SessionRecordingConfigV2) Clone() SessionRecordingConfig { return utils.CloneProtoMsg(c) diff --git a/api/utils/keys/alias.go b/api/utils/keys/alias.go deleted file mode 100644 index 8323a9c9e53ff..0000000000000 --- a/api/utils/keys/alias.go +++ /dev/null @@ -1,26 +0,0 @@ -/* -Copyright 2025 Gravitational, Inc. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package keys - -import "github.com/gravitational/teleport/api/utils/keys/hardwarekey" - -// Temporary aliases for types moved to the hardwarekey or piv packages -// TODO(Joerger): Remove once /e no longer relies on them. - -// AttestationStatement is an attestation statement for a hardware private key -// that supports json marshaling through the standard json/encoding package. -type AttestationStatement = hardwarekey.AttestationStatement - -// AttestationStatementFromProto converts an AttestationStatement from its protobuf form. -var AttestationStatementFromProto = hardwarekey.AttestationStatementFromProto diff --git a/api/utils/keys/policy_piv.go b/api/utils/keys/policy_piv.go deleted file mode 100644 index 43b76aaad534e..0000000000000 --- a/api/utils/keys/policy_piv.go +++ /dev/null @@ -1,45 +0,0 @@ -//go:build piv && !pivtest - -/* -Copyright 2025 Gravitational, Inc. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package keys - -import ( - "github.com/go-piv/piv-go/v2/piv" -) - -// GetPrivateKeyPolicyFromAttestation returns the PrivateKeyPolicy satisfied by the given hardware key attestation. -// TODO(Joerger): Move to /e where this is used. -func GetPrivateKeyPolicyFromAttestation(att *piv.Attestation) PrivateKeyPolicy { - if att == nil { - return PrivateKeyPolicyNone - } - - isTouchPolicy := att.TouchPolicy == piv.TouchPolicyCached || - att.TouchPolicy == piv.TouchPolicyAlways - - isPINPolicy := att.PINPolicy == piv.PINPolicyOnce || - att.PINPolicy == piv.PINPolicyAlways - - switch { - case isPINPolicy && isTouchPolicy: - return PrivateKeyPolicyHardwareKeyTouchAndPIN - case isPINPolicy: - return PrivateKeyPolicyHardwareKeyPIN - case isTouchPolicy: - return PrivateKeyPolicyHardwareKeyTouch - default: - return PrivateKeyPolicyHardwareKey - } -} diff --git a/docs/config.json b/docs/config.json index c2056467b8769..3f5c83337f8b7 100644 --- a/docs/config.json +++ b/docs/config.json @@ -727,7 +727,12 @@ }, { "source": "/enroll-resources/workload-identity/workload-attestation/", - "destination": "/machine-workload-identity/workload-identity/workload-attestation/", + "destination": "/reference/workload-identity/workload-identity-api-and-workload-attestation/", + "permanent": true + }, + { + "source": "/machine-workload-identity/workload-identity/workload-attestation/", + "destination": "/reference/workload-identity/workload-identity-api-and-workload-attestation/", "permanent": true }, { diff --git a/docs/pages/machine-workload-identity/workload-identity/workload-attestation.mdx b/docs/pages/machine-workload-identity/workload-identity/workload-attestation.mdx deleted file mode 100644 index 077e9894e9c1e..0000000000000 --- a/docs/pages/machine-workload-identity/workload-identity/workload-attestation.mdx +++ /dev/null @@ -1,247 +0,0 @@ ---- -title: Workload Attestation -description: An overview of the Teleport Workload Identity Workload Attestation feature. 
---- - -Workload Attestation is the process completed by `tbot` to assert the identity -of a workload that has connected to the Workload API and requested certificates. -The information gathered during attestation is used to decide which, if any, -SPIFFE IDs should be encoded into an SVID and issued to the workload. - -Workload Attestors are the individual components that perform this attestation. -They use the process ID of the workload to gather information about the workload -from platform-specific APIs. For example, the Kubernetes Workload Attestor -queries the local Kubelet API to determine which Kubernetes pod the process -belongs to. - -The result of the attestation process is known as attestation metadata. This -attestation metadata is referred to by the rules you configured for `tbot`'s -Workload API service. For example, you may state that only workloads running in -a specific Kubernetes namespace should be issued a specific SPIFFE ID. - -Additionally, this metadata is included in the log messages output by `tbot` -when it issues an SVID. This allows you to audit the issuance of SVIDs and -understand why a specific SPIFFE ID was issued to a workload. - -## Unix - -The Unix Workload Attestor is the most basic attestor and allows you to restrict -the issuance of SVIDs to specific Unix processes based on a range of criteria. - -### Attestation Metadata - -The following metadata is produced by the Unix Workload Attestor and is -available to be used when configuring rules for `tbot`'s Workload API service: - -| Field | Description | -|-------------------|------------------------------------------------------------------------------| -| `unix.attested` | Indicates that the workload has been attested by the Unix Workload Attestor. | -| `unix.pid` | The process ID of the attested workload. | -| `unix.uid` | The effective user ID of the attested workload. | -| `unix.gid` | The effective primary group ID of the attested workload. 
| - -### Support for non-standard procfs mounting - -To resolve information about a process from the PID, the Unix Workload Attestor -reads information from the procfs filesystem. By default, it expects procfs to -be mounted at `/proc`. - -If procfs is mounted at a different location, you must configure the Unix -Workload Attestor to read from that alternative location by setting the -`HOST_PROC` environment variable. - -This is a sensitive configuration option, and you should ensure that it is -set correctly or not set at all. If misconfigured, an attacker could provide -falsified information about processes, and this could lead to the issuance of -SVIDs to unauthorized workloads. - -## Kubernetes - -The Kubernetes Workload Attestor allows you to restrict the issuance of SVIDs -to specific Kubernetes workloads based on a range of criteria. - -It works by first determining the pod ID for a given process ID and then by -querying the local kubelet API for details about that pod. - -### Attestation Metadata - -The following metadata is produced by the Kubernetes Workload Attestor and is -available to be used when configuring rules for `tbot`'s Workload API service: - -| Field | Description | -|------------------------------|------------------------------------------------------------------------------------| -| `kubernetes.attested` | Indicates that the workload has been attested by the Kubernetes Workload Attestor. | -| `kubernetes.namespace` | The namespace of the Kubernetes Pod. | -| `kubernetes.service_account` | The service account of the Kubernetes Pod. | -| `kubernetes.pod_name` | The name of the Kubernetes Pod. | - -### Deployment Guidance - -To use Kubernetes Workload Attestation, `tbot` must be deployed as a daemon -set. This is because the unix domain socket can only be accessed by pods on the -same node as the agent. 
Additionally, the daemon set must have the `hostPID` -property set to `true` to allow the agent to access information about -processes within other containers. - -The daemon set must also have a service account assigned that allows it to query -the Kubelet API. This is an example role with the required RBAC: - -```yaml -kind: ClusterRole -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: tbot -rules: - - resources: ["pods","nodes","nodes/proxy"] - apiGroups: [""] - verbs: ["get"] -``` - -Mapping the Workload API Unix domain socket into the containers of workloads -can be done in two ways: - -- Directly configuring a hostPath volume for the `tbot` daemonset and workloads - which will need to connect to it. -- Using [spiffe-csi-driver](https://github.com/spiffe/spiffe-csi). - -Example manifests for required Kubernetes resources: - -```yaml -kind: ClusterRole -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: tbot -rules: - - resources: ["pods","nodes","nodes/proxy"] - apiGroups: [""] - verbs: ["get"] ---- -kind: ClusterRoleBinding -apiVersion: rbac.authorization.k8s.io/v1 -metadata: - name: tbot -subjects: - - kind: ServiceAccount - name: tbot - namespace: default -roleRef: - kind: ClusterRole - name: tbot - apiGroup: rbac.authorization.k8s.io ---- -apiVersion: v1 -kind: ServiceAccount -metadata: - name: tbot - namespace: default ---- -apiVersion: v1 -kind: ConfigMap -metadata: - name: tbot-config - namespace: default -data: - tbot.yaml: | - version: v2 - onboarding: - join_method: kubernetes - # replace with the name of a join token you have created. - token: example-token - storage: - type: memory - # ensure this is configured to the address of your Teleport Proxy Service. 
- proxy_server: example.teleport.sh:443 - services: - - type: spiffe-workload-api - listen: unix:///run/tbot/sockets/workload.sock - attestor: - kubernetes: - enabled: true - kubelet: - # skip verification of the Kubelet API certificate as this is not - # usually issued by the cluster CA. - skip_verify: true - # replace the svid entries with the SPIFFE IDs that you wish to issue, - # using the `rules` blocks to restrict these to specific Kubernetes - # workloads. - svids: - - path: /my-service - rules: - - kubernetes: - namespace: default - service_account: example-sa ---- -apiVersion: apps/v1 -kind: DaemonSet -metadata: - name: tbot -spec: - selector: - matchLabels: - app: tbot - template: - metadata: - labels: - app: tbot - spec: - securityContext: - runAsUser: 0 - runAsGroup: 0 - hostPID: true - containers: - - name: tbot - image: public.ecr.aws/gravitational/tbot-distroless:(=teleport.version=) - imagePullPolicy: IfNotPresent - securityContext: - privileged: true - args: - - start - - -c - - /config/tbot.yaml - - --log-format - - json - volumeMounts: - - mountPath: /config - name: config - - mountPath: /var/run/secrets/tokens - name: join-sa-token - - name: tbot-sockets - mountPath: /run/tbot/sockets - readOnly: false - env: - - name: TELEPORT_NODE_NAME - valueFrom: - fieldRef: - fieldPath: spec.nodeName - - name: KUBERNETES_TOKEN_PATH - value: /var/run/secrets/tokens/join-sa-token - serviceAccountName: tbot - volumes: - - name: tbot-sockets - hostPath: - path: /run/tbot/sockets - type: DirectoryOrCreate - - name: config - configMap: - name: tbot-config - - name: join-sa-token - projected: - sources: - - serviceAccountToken: - path: join-sa-token - # 600 seconds is the minimum that Kubernetes supports. We - # recommend this value is used. - expirationSeconds: 600 - # `example.teleport.sh` must be replaced with the name of - # your Teleport cluster. 
- audience: example.teleport.sh -``` - -## Next steps - -- [Workload Identity Overview](./introduction.mdx): Overview of Teleport -Workload Identity. -- [Best Practices](./best-practices.mdx): Best practices for using Workload -Identity in Production. -- Read the [configuration reference](../../reference/machine-id/configuration.mdx) to explore -all the available configuration options. diff --git a/docs/pages/machine-workload-identity/workload-identity/workload-identity.mdx b/docs/pages/machine-workload-identity/workload-identity/workload-identity.mdx index db9dccac7c6b6..6f9b8a988c075 100644 --- a/docs/pages/machine-workload-identity/workload-identity/workload-identity.mdx +++ b/docs/pages/machine-workload-identity/workload-identity/workload-identity.mdx @@ -19,4 +19,5 @@ description: Securely issue flexible short-lived identities to your workloads - [Best Practices for Teleport Workload Identity](./best-practices.mdx): Answers common questions and describes best practices for using Teleport Workload Identity in production. - [JWT SVIDs](./jwt-svids.mdx): An overview of the JWT SVIDs issued by Teleport Workload Identity - [SPIFFE Federation](./federation.mdx): An overview of the Teleport Workload Identity SPIFFE Federation feature. -- [Workload Attestation](./workload-attestation.mdx): An overview of the Teleport Workload Identity Workload Attestation feature. +- [Workload Attestation](../../reference/workload-identity/workload-identity-api-and-workload-attestation.mdx): An overview of the Teleport Workload Identity Workload Attestation feature. +- [Workload Identity Resource](../../reference/workload-identity/workload-identity-resource.mdx): The full reference for the Workload Identity resource. 
diff --git a/go.mod b/go.mod index 98a5aaf94d1fa..4c556e40e2c76 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( cloud.google.com/go/storage v1.53.0 code.dny.dev/ssrf v0.2.0 connectrpc.com/connect v1.18.1 + filippo.io/age v1.2.1 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 @@ -603,7 +604,7 @@ replace ( github.com/microsoft/go-mssqldb => github.com/gravitational/go-mssqldb v1.8.1-teleport.2 github.com/opencontainers/selinux => github.com/gravitational/selinux v1.13.0-teleport github.com/redis/go-redis/v9 => github.com/gravitational/redis/v9 v9.6.1-teleport.1 - github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.2 + github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.4 ) // this package was included in google.golang.org/grpc but because it's still diff --git a/go.sum b/go.sum index 3ee413f1b636b..047a860ac7d17 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805 h1:u2qwJeEvnypw+OCPUHmoZE3IqwfuN5kgDfo5MLzpNM0= +c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805/go.mod h1:FomMrUJ2Lxt5jCLmZkG3FHa72zUprnhd3v/Z18Snm4w= cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= @@ -641,6 +643,8 @@ cuelang.org/go v0.12.1/go.mod h1:B4+kjvGGQnbkz+GuAv1dq/R308gTkp0sO28FdMrJ2Kw= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o= +filippo.io/age v1.2.1/go.mod 
h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= @@ -1556,8 +1560,8 @@ github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33 h1:VFE github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1 h1:Kt7aT9N7vbZmcMejGXnSAGap8TUwH3fMoHE8cQm14wc= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1/go.mod h1:n4RXV6T3SJ/vrJqmc4vBeHpaBspxWENDD67ssQlXXkg= -github.com/gravitational/predicate v1.3.2 h1:NAdaWihgGVS3nuKDZZHTda4wwNjGlvG4a+UYQ4f3gQc= -github.com/gravitational/predicate v1.3.2/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= +github.com/gravitational/predicate v1.3.4 h1:9N3JhBXNPcUh0w8DdlpnVmfnH9Z3xxbw43sD3E19VBE= +github.com/gravitational/predicate v1.3.4/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= github.com/gravitational/protobuf v1.3.2-teleport.2 h1:MO5eFXfGfDiAbBA7X8tDW2EMLfRWQVJMmK+MA6gV8AI= github.com/gravitational/protobuf v1.3.2-teleport.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gravitational/redis/v9 v9.6.1-teleport.1 h1:gPirfPKArN2nPhTKR3h9fnEg5YuYU933+CjlDJMo4H0= diff --git a/integrations/event-handler/go.mod b/integrations/event-handler/go.mod index e1b820c18b218..d9818abac63c4 100644 --- a/integrations/event-handler/go.mod +++ b/integrations/event-handler/go.mod @@ -386,5 +386,5 @@ replace ( github.com/keys-pub/go-libfido2 => github.com/gravitational/go-libfido2 v1.5.3-teleport.1 github.com/microsoft/go-mssqldb => github.com/gravitational/go-mssqldb v1.8.1-teleport.2 github.com/redis/go-redis/v9 => github.com/gravitational/redis/v9 v9.6.1-teleport.1 - github.com/vulcand/predicate 
=> github.com/gravitational/predicate v1.3.2 + github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.4 ) diff --git a/integrations/event-handler/go.sum b/integrations/event-handler/go.sum index 2b92e2cc5099a..525d9ac3e790f 100644 --- a/integrations/event-handler/go.sum +++ b/integrations/event-handler/go.sum @@ -482,8 +482,8 @@ github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33 h1:VFE github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1 h1:Kt7aT9N7vbZmcMejGXnSAGap8TUwH3fMoHE8cQm14wc= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1/go.mod h1:n4RXV6T3SJ/vrJqmc4vBeHpaBspxWENDD67ssQlXXkg= -github.com/gravitational/predicate v1.3.2 h1:NAdaWihgGVS3nuKDZZHTda4wwNjGlvG4a+UYQ4f3gQc= -github.com/gravitational/predicate v1.3.2/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= +github.com/gravitational/predicate v1.3.4 h1:9N3JhBXNPcUh0w8DdlpnVmfnH9Z3xxbw43sD3E19VBE= +github.com/gravitational/predicate v1.3.4/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= github.com/gravitational/protobuf v1.3.2-teleport.2 h1:MO5eFXfGfDiAbBA7X8tDW2EMLfRWQVJMmK+MA6gV8AI= github.com/gravitational/protobuf v1.3.2-teleport.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gravitational/redis/v9 v9.6.1-teleport.1 h1:gPirfPKArN2nPhTKR3h9fnEg5YuYU933+CjlDJMo4H0= diff --git a/integrations/terraform-mwi/go.mod b/integrations/terraform-mwi/go.mod index fd111c377add9..cc5a63ebb7d88 100644 --- a/integrations/terraform-mwi/go.mod +++ b/integrations/terraform-mwi/go.mod @@ -555,5 +555,5 @@ replace ( github.com/microsoft/go-mssqldb => github.com/gravitational/go-mssqldb v1.8.1-teleport.2 github.com/opencontainers/selinux => github.com/gravitational/selinux v1.13.0-teleport github.com/redis/go-redis/v9 => github.com/gravitational/redis/v9 v9.6.1-teleport.1 - 
github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.2 + github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.4 ) diff --git a/integrations/terraform-mwi/go.sum b/integrations/terraform-mwi/go.sum index 8453119b831a4..c2fa62bb82848 100644 --- a/integrations/terraform-mwi/go.sum +++ b/integrations/terraform-mwi/go.sum @@ -707,8 +707,8 @@ github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33 h1:VFE github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1 h1:Kt7aT9N7vbZmcMejGXnSAGap8TUwH3fMoHE8cQm14wc= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1/go.mod h1:n4RXV6T3SJ/vrJqmc4vBeHpaBspxWENDD67ssQlXXkg= -github.com/gravitational/predicate v1.3.2 h1:NAdaWihgGVS3nuKDZZHTda4wwNjGlvG4a+UYQ4f3gQc= -github.com/gravitational/predicate v1.3.2/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= +github.com/gravitational/predicate v1.3.4 h1:9N3JhBXNPcUh0w8DdlpnVmfnH9Z3xxbw43sD3E19VBE= +github.com/gravitational/predicate v1.3.4/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= github.com/gravitational/protobuf v1.3.2-teleport.2 h1:MO5eFXfGfDiAbBA7X8tDW2EMLfRWQVJMmK+MA6gV8AI= github.com/gravitational/protobuf v1.3.2-teleport.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gravitational/redis/v9 v9.6.1-teleport.1 h1:gPirfPKArN2nPhTKR3h9fnEg5YuYU933+CjlDJMo4H0= diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index 585e6b708ec73..14f72ab41ce39 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -471,7 +471,7 @@ replace ( github.com/keys-pub/go-libfido2 => github.com/gravitational/go-libfido2 v1.5.3-teleport.1 github.com/microsoft/go-mssqldb => github.com/gravitational/go-mssqldb v1.8.1-teleport.2 github.com/redis/go-redis/v9 => github.com/gravitational/redis/v9 v9.6.1-teleport.1 - 
github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.2 + github.com/vulcand/predicate => github.com/gravitational/predicate v1.3.4 ) // Doc generation tooling. diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 64c8044ac2c96..5372190497933 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -718,8 +718,8 @@ github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33 h1:VFE github.com/gravitational/kingpin/v2 v2.1.11-0.20230515143221-4ec6b70ecd33/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1 h1:Kt7aT9N7vbZmcMejGXnSAGap8TUwH3fMoHE8cQm14wc= github.com/gravitational/license v0.0.0-20250329001817-070456fa8ec1/go.mod h1:n4RXV6T3SJ/vrJqmc4vBeHpaBspxWENDD67ssQlXXkg= -github.com/gravitational/predicate v1.3.2 h1:NAdaWihgGVS3nuKDZZHTda4wwNjGlvG4a+UYQ4f3gQc= -github.com/gravitational/predicate v1.3.2/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= +github.com/gravitational/predicate v1.3.4 h1:9N3JhBXNPcUh0w8DdlpnVmfnH9Z3xxbw43sD3E19VBE= +github.com/gravitational/predicate v1.3.4/go.mod h1:cTQkp40X3YejTcWsZGvzAtfa28VXfBxT10H/Grt0Fzs= github.com/gravitational/protobuf v1.3.2-teleport.2 h1:MO5eFXfGfDiAbBA7X8tDW2EMLfRWQVJMmK+MA6gV8AI= github.com/gravitational/protobuf v1.3.2-teleport.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gravitational/redis/v9 v9.6.1-teleport.1 h1:gPirfPKArN2nPhTKR3h9fnEg5YuYU933+CjlDJMo4H0= diff --git a/lib/auth/accesspoint/accesspoint.go b/lib/auth/accesspoint/accesspoint.go index 4904076472e36..500dd7c15bb9a 100644 --- a/lib/auth/accesspoint/accesspoint.go +++ b/lib/auth/accesspoint/accesspoint.go @@ -74,6 +74,7 @@ type Config struct { AccessMonitoringRules services.AccessMonitoringRules AppSession services.AppSession Apps services.Apps + BotInstance services.BotInstance ClusterConfig services.ClusterConfiguration CrownJewels services.CrownJewels 
DatabaseObjects services.DatabaseObjects @@ -211,6 +212,7 @@ func NewCache(cfg Config) (*cache.Cache, error) { PluginStaticCredentials: cfg.PluginStaticCredentials, GitServers: cfg.GitServers, HealthCheckConfig: cfg.HealthCheckConfig, + BotInstanceService: cfg.BotInstance, } return cache.New(cfg.Setup(cacheCfg)) diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go index da50bf1b151cd..36246d4987140 100644 --- a/lib/auth/authclient/api.go +++ b/lib/auth/authclient/api.go @@ -1237,6 +1237,12 @@ type Cache interface { GetRelayServer(ctx context.Context, name string) (*presencev1.RelayServer, error) // ListRelayServers returns a paginated list of relay server heartbeats. ListRelayServers(ctx context.Context, pageSize int, pageToken string) (_ []*presencev1.RelayServer, nextPageToken string, _ error) + + // GetBotInstance returns the specified BotInstance resource. + GetBotInstance(ctx context.Context, botName, instanceID string) (*machineidv1.BotInstance, error) + + // ListBotInstances returns a page of BotInstance resources. 
+ ListBotInstances(ctx context.Context, botName string, pageSize int, lastToken string, search string) ([]*machineidv1.BotInstance, string, error) } type NodeWrapper struct { diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index c9c341edff62f..1e775ca891a22 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5318,6 +5318,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { botInstanceService, err := machineidv1.NewBotInstanceService(machineidv1.BotInstanceServiceConfig{ Authorizer: cfg.Authorizer, + Cache: cfg.AuthServer.Cache, Backend: cfg.AuthServer.Services.BotInstance, Clock: cfg.AuthServer.GetClock(), }) diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 1dc6e73a0204e..57d5b06fe9e65 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -552,6 +552,7 @@ func InitTestAuthCache(p TestAuthCacheParams) error { PluginStaticCredentials: p.AuthServer.Services.PluginStaticCredentials, GitServers: p.AuthServer.Services.GitServers, HealthCheckConfig: p.AuthServer.Services.HealthCheckConfig, + BotInstance: p.AuthServer.Services.BotInstance, }) if err != nil { return trace.Wrap(err) diff --git a/lib/auth/machineid/machineidv1/bot_instance_service.go b/lib/auth/machineid/machineidv1/bot_instance_service.go index db89348713956..ff1ccf7454e46 100644 --- a/lib/auth/machineid/machineidv1/bot_instance_service.go +++ b/lib/auth/machineid/machineidv1/bot_instance_service.go @@ -48,10 +48,20 @@ const ( ExpiryMargin = time.Minute * 5 ) +// BotInstancesCache is the subset of the cached resources that the Service queries. +type BotInstancesCache interface { + // GetBotInstance returns the specified BotInstance resource. + GetBotInstance(ctx context.Context, botName, instanceID string) (*pb.BotInstance, error) + + // ListBotInstances returns a page of BotInstance resources. 
+ ListBotInstances(ctx context.Context, botName string, pageSize int, lastToken string, search string) ([]*pb.BotInstance, string, error) +} + // BotInstanceServiceConfig holds configuration options for the BotInstance gRPC // service. type BotInstanceServiceConfig struct { Authorizer authz.Authorizer + Cache BotInstancesCache Backend services.BotInstance Logger *slog.Logger Clock clockwork.Clock @@ -64,6 +74,8 @@ func NewBotInstanceService(cfg BotInstanceServiceConfig) (*BotInstanceService, e return nil, trace.BadParameter("backend service is required") case cfg.Authorizer == nil: return nil, trace.BadParameter("authorizer is required") + case cfg.Cache == nil: + return nil, trace.BadParameter("cache service is required") } if cfg.Logger == nil { @@ -76,6 +88,7 @@ func NewBotInstanceService(cfg BotInstanceServiceConfig) (*BotInstanceService, e return &BotInstanceService{ logger: cfg.Logger, authorizer: cfg.Authorizer, + cache: cfg.Cache, backend: cfg.Backend, clock: cfg.Clock, }, nil @@ -87,6 +100,7 @@ type BotInstanceService struct { backend services.BotInstance authorizer authz.Authorizer + cache BotInstancesCache logger *slog.Logger clock clockwork.Clock } @@ -124,7 +138,7 @@ func (b *BotInstanceService) GetBotInstance(ctx context.Context, req *pb.GetBotI return nil, trace.Wrap(err) } - res, err := b.backend.GetBotInstance(ctx, req.BotName, req.InstanceId) + res, err := b.cache.GetBotInstance(ctx, req.BotName, req.InstanceId) if err != nil { return nil, trace.Wrap(err) } @@ -143,7 +157,7 @@ func (b *BotInstanceService) ListBotInstances(ctx context.Context, req *pb.ListB return nil, trace.Wrap(err) } - res, nextToken, err := b.backend.ListBotInstances(ctx, req.FilterBotName, int(req.PageSize), req.PageToken, req.FilterSearchTerm) + res, nextToken, err := b.cache.ListBotInstances(ctx, req.FilterBotName, int(req.PageSize), req.PageToken, req.FilterSearchTerm) if err != nil { return nil, trace.Wrap(err) } diff --git 
a/lib/auth/machineid/machineidv1/bot_instance_service_test.go b/lib/auth/machineid/machineidv1/bot_instance_service_test.go index 39a4afe815404..37a4d0a491d76 100644 --- a/lib/auth/machineid/machineidv1/bot_instance_service_test.go +++ b/lib/auth/machineid/machineidv1/bot_instance_service_test.go @@ -334,6 +334,7 @@ func TestBotInstanceServiceSubmitHeartbeat(t *testing.T) { backend := newBotInstanceBackend(t) service, err := NewBotInstanceService(BotInstanceServiceConfig{ Backend: backend, + Cache: backend, Authorizer: authz.AuthorizerFunc(func(ctx context.Context) (*authz.Context, error) { return &authz.Context{ Identity: identityGetterFn(func() tlsca.Identity { @@ -391,6 +392,7 @@ func TestBotInstanceServiceSubmitHeartbeat_HeartbeatLimit(t *testing.T) { backend := newBotInstanceBackend(t) service, err := NewBotInstanceService(BotInstanceServiceConfig{ Backend: backend, + Cache: backend, Authorizer: authz.AuthorizerFunc(func(ctx context.Context) (*authz.Context, error) { return &authz.Context{ Identity: identityGetterFn(func() tlsca.Identity { @@ -581,6 +583,7 @@ func newBotInstanceService( service, err := NewBotInstanceService(BotInstanceServiceConfig{ Authorizer: authorizer, Backend: backendService, + Cache: backendService, }) require.NoError(t, err) diff --git a/lib/auth/recordingencryption/age.go b/lib/auth/recordingencryption/age.go new file mode 100644 index 0000000000000..b4e6ac495c548 --- /dev/null +++ b/lib/auth/recordingencryption/age.go @@ -0,0 +1,119 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package recordingencryption + +import ( + "context" + + "filippo.io/age" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" +) + +// X25519Stanza is the default stanza type used by age. +const X25519Stanza = "X25519" + +// RecordingStanza is the type used for the identifying stanza added by RecordingRecipient. +const RecordingStanza = "Recording-X25519" + +// DecryptionKeyFinder returns an EncryptionKeyPair related to at least one of the given public keys to be used +// for file key unwrapping. +type DecryptionKeyFinder interface { + FindDecryptionKey(ctx context.Context, publicKeys ...[]byte) (*types.EncryptionKeyPair, error) +} + +// RecordingIdentity removes public keys from stanzas and passes the unwrap call to the default +// age.X25519Identity. +type RecordingIdentity struct { + ctx context.Context + keyFinder DecryptionKeyFinder +} + +// NewRecordingIdentity returns a RecordingIdentity that will use the given DecryptionKeyFinder in order to facilitate +// file key unwrapping. +func NewRecordingIdentity(ctx context.Context, keyFinder DecryptionKeyFinder) *RecordingIdentity { + return &RecordingIdentity{ + ctx: ctx, + keyFinder: keyFinder, + } +} + +// Unwrap uses the additional stanzas added by RecordingRecipient.Wrap in order to find a matching X25519 identity. 
+func (i *RecordingIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { + var publicKeys [][]byte + for _, stanza := range stanzas { + if stanza.Type != RecordingStanza { + continue + } + + if len(stanza.Args) != 1 { + continue + } + + publicKeys = append(publicKeys, []byte(stanza.Args[0])) + } + + pair, err := i.keyFinder.FindDecryptionKey(i.ctx, publicKeys...) + if err != nil { + return nil, trace.Wrap(err) + } + + identity, err := age.ParseX25519Identity(string(pair.PrivateKey)) + if err != nil { + return nil, trace.Wrap(err) + } + + return identity.Unwrap(stanzas) +} + +// RecordingRecipient adds the public key to the stanzas generated by the default age.X25519Recipient +type RecordingRecipient struct { + *age.X25519Recipient +} + +// ParseRecordingRecipient parses a Bech32 encoded age X25519 public key into a RecordingRecipient. +func ParseRecordingRecipient(s string) (*RecordingRecipient, error) { + recipient, err := age.ParseX25519Recipient(s) + if err != nil { + return nil, trace.Wrap(err) + } + + return &RecordingRecipient{X25519Recipient: recipient}, nil +} + +// Wrap a fileKey using the wrapped X25519Recipient. An additional stanza containing the bech32 encoded X25519 +// public key will be created to enable lookups during Unwrap. 
+func (r *RecordingRecipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { + stanzas, err := r.X25519Recipient.Wrap(fileKey) + if err != nil { + return nil, trace.Wrap(err) + } + + // a new stanza has to be added because modifying the original stanza and returning it to "normal" during + // Unwrap fails due to MAC errors + for _, stanza := range stanzas { + if stanza.Type == X25519Stanza { + stanzas = append(stanzas, &age.Stanza{ + Type: RecordingStanza, + Args: []string{r.String()}, + }) + } + } + + return stanzas, nil +} diff --git a/lib/auth/recordingencryption/age_test.go b/lib/auth/recordingencryption/age_test.go new file mode 100644 index 0000000000000..a7329e22e0b1d --- /dev/null +++ b/lib/auth/recordingencryption/age_test.go @@ -0,0 +1,111 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package recordingencryption_test + +import ( + "bytes" + "context" + "errors" + "io" + "testing" + + "filippo.io/age" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/recordingencryption" +) + +func TestRecordingAgePlugin(t *testing.T) { + ctx := t.Context() + keyFinder := newFakeKeyFinder() + recordingIdentity := recordingencryption.NewRecordingIdentity(ctx, keyFinder) + + ident, err := keyFinder.generateIdentity() + require.NoError(t, err) + + recipient, err := recordingencryption.ParseRecordingRecipient(ident.Recipient().String()) + require.NoError(t, err) + + out := bytes.NewBuffer(nil) + writer, err := age.Encrypt(out, recipient) + require.NoError(t, err) + + msg := []byte("testing age plugin for session recordings") + _, err = writer.Write(msg) + require.NoError(t, err) + + // writer must be closed to ensure data is flushed + err = writer.Close() + require.NoError(t, err) + + reader, err := age.Decrypt(out, recordingIdentity) + require.NoError(t, err) + plaintext, err := io.ReadAll(reader) + require.NoError(t, err) + + require.Equal(t, msg, plaintext) + + // running the same test with the raw recipient should fail because the + // the extra stanza added by RecordingRecipient won't be present and + // the private key won't be found + out.Reset() + writer, err = age.Encrypt(out, ident.Recipient()) + require.NoError(t, err) + _, err = writer.Write(msg) + require.NoError(t, err) + err = writer.Close() + require.NoError(t, err) + _, err = age.Decrypt(out, recordingIdentity) + require.Error(t, err) +} + +type fakeKeyFinder struct { + keys map[string]string +} + +func newFakeKeyFinder() *fakeKeyFinder { + return &fakeKeyFinder{ + keys: make(map[string]string), + } +} + +func (f *fakeKeyFinder) FindDecryptionKey(ctx context.Context, publicKeys ...[]byte) (*types.EncryptionKeyPair, error) { + for _, pubKey := range publicKeys { + key, ok := f.keys[string(pubKey)] + if !ok { + 
continue + } + + return &types.EncryptionKeyPair{ + PrivateKey: []byte(key), + PublicKey: pubKey, + }, nil + } + + return nil, errors.New("no accessible decryption key found") +} + +func (f *fakeKeyFinder) generateIdentity() (*age.X25519Identity, error) { + ident, err := age.GenerateX25519Identity() + if err != nil { + return nil, err + } + + f.keys[ident.Recipient().String()] = ident.String() + return ident, nil +} diff --git a/lib/auth/recordingencryption/manager.go b/lib/auth/recordingencryption/manager.go new file mode 100644 index 0000000000000..889ac9c624ced --- /dev/null +++ b/lib/auth/recordingencryption/manager.go @@ -0,0 +1,533 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package recordingencryption + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "iter" + "log/slog" + "slices" + "time" + + "filippo.io/age" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport" + recordingencryptionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingencryption/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/services" +) + +// KeyStore provides methods for interacting with encryption keys. +type KeyStore interface { + NewEncryptionKeyPair(ctx context.Context, purpose cryptosuites.KeyPurpose) (*types.EncryptionKeyPair, error) + GetDecrypter(ctx context.Context, keyPair *types.EncryptionKeyPair) (crypto.Decrypter, error) +} + +// ManagerConfig captures all of the dependencies required to instantiate a Manager. +type ManagerConfig struct { + Backend services.RecordingEncryption + ClusterConfig services.ClusterConfigurationInternal + KeyStore KeyStore + Logger *slog.Logger + LockConfig backend.RunWhileLockedConfig +} + +// NewManager returns a new Manager using the given ManagerConfig. 
+func NewManager(cfg ManagerConfig) (*Manager, error) { + switch { + case cfg.Backend == nil: + return nil, trace.BadParameter("backend is required") + case cfg.ClusterConfig == nil: + return nil, trace.BadParameter("cluster config is required") + case cfg.KeyStore == nil: + return nil, trace.BadParameter("key store is required") + } + + if cfg.Logger == nil { + cfg.Logger = slog.With(teleport.ComponentKey, "recording-encryption-manager") + } + + return &Manager{ + RecordingEncryption: cfg.Backend, + ClusterConfigurationInternal: cfg.ClusterConfig, + + keyStore: cfg.KeyStore, + lockConfig: cfg.LockConfig, + logger: cfg.Logger, + }, nil +} + +// A Manager wraps a services.RecordingEncryption and KeyStore in order to provide more complex operations +// than the CRUD methods exposed by services.RecordingEncryption. It primarily handles resolving RecordingEncryption +// state and searching for accessible decryption keys. +type Manager struct { + services.RecordingEncryption + services.ClusterConfigurationInternal + + logger *slog.Logger + lockConfig backend.RunWhileLockedConfig + keyStore KeyStore +} + +// CreateSessionRecordingConfig creates a new session recording configuration. If encryption is enabled then the +// recording encryption resource will also be resolved. 
+func (m *Manager) CreateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (sessionRecordingConfig types.SessionRecordingConfig, err error) { + err = backend.RunWhileLocked(ctx, m.lockConfig, func(ctx context.Context) error { + if cfg.GetEncrypted() { + encryption, err := m.resolveRecordingEncryption(ctx) + if err != nil { + return err + } + + _ = cfg.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys)) + } + + sessionRecordingConfig, err = m.ClusterConfigurationInternal.CreateSessionRecordingConfig(ctx, cfg) + if err != nil { + return trace.Wrap(err) + } + + return nil + }) + + return sessionRecordingConfig, trace.Wrap(err) +} + +// UpdateSessionRecordingConfig updates an existing session recording configuration. If encryption is enabled then +// the recording encryption resource will also be resolved. +func (m *Manager) UpdateSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (sessionRecordingConfig types.SessionRecordingConfig, err error) { + err = backend.RunWhileLocked(ctx, m.lockConfig, func(ctx context.Context) error { + if cfg.GetEncrypted() { + encryption, err := m.resolveRecordingEncryption(ctx) + if err != nil { + return err + } + + _ = cfg.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys)) + } + + sessionRecordingConfig, err = m.ClusterConfigurationInternal.UpdateSessionRecordingConfig(ctx, cfg) + if err != nil { + return trace.Wrap(err) + } + + return nil + }) + + return sessionRecordingConfig, trace.Wrap(err) +} + +// UpsertSessionRecordingConfig creates a new session recording configuration or overwrites an existing one. If +// encryption is enabled then the recording encryption resource will also be resolved. 
+func (m *Manager) UpsertSessionRecordingConfig(ctx context.Context, cfg types.SessionRecordingConfig) (sessionRecordingConfig types.SessionRecordingConfig, err error) { + err = backend.RunWhileLocked(ctx, m.lockConfig, func(ctx context.Context) error { + if cfg.GetEncrypted() { + encryption, err := m.resolveRecordingEncryption(ctx) + if err != nil { + return err + } + + _ = cfg.SetEncryptionKeys(getAgeEncryptionKeys(encryption.GetSpec().ActiveKeys)) + } + + sessionRecordingConfig, err = m.ClusterConfigurationInternal.UpsertSessionRecordingConfig(ctx, cfg) + if err != nil { + return trace.Wrap(err) + } + + return nil + }) + + return sessionRecordingConfig, trace.Wrap(err) +} + +// ensureActiveRecordingEncryption returns the configured RecordingEncryption resource if it exists with active keys. If it does not, +// then the resource will be created or updated with a new active keypair. The bool return value indicates whether or not +// a new pair was provisioned. +func (m *Manager) ensureActiveRecordingEncryption(ctx context.Context) (*recordingencryptionv1.RecordingEncryption, bool, error) { + persistFn := m.RecordingEncryption.UpdateRecordingEncryption + encryption, err := m.RecordingEncryption.GetRecordingEncryption(ctx) + if err != nil { + if !trace.IsNotFound(err) { + return encryption, false, trace.Wrap(err) + } + encryption = &recordingencryptionv1.RecordingEncryption{ + Spec: &recordingencryptionv1.RecordingEncryptionSpec{}, + } + persistFn = m.RecordingEncryption.CreateRecordingEncryption + } + + activeKeys := encryption.GetSpec().ActiveKeys + + // no keys present, need to generate the initial active keypair + if len(activeKeys) > 0 { + return encryption, false, nil + } + + keyEncryptionPair, err := m.keyStore.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping) + if err != nil { + return encryption, false, trace.Wrap(err, "generating wrapping key") + } + + ident, err := age.GenerateX25519Identity() + if err != nil { + return encryption, false, 
 trace.Wrap(err, "generating age encryption key") + } + + encryptedIdent, err := keyEncryptionPair.EncryptOAEP([]byte(ident.String())) + if err != nil { + return encryption, false, trace.Wrap(err, "wrapping encryption key") + } + + wrappedKey := recordingencryptionv1.WrappedKey{ + KeyEncryptionPair: keyEncryptionPair, + RecordingEncryptionPair: &types.EncryptionKeyPair{ + PrivateKeyType: types.PrivateKeyType_RAW, + PrivateKey: encryptedIdent, + PublicKey: []byte(ident.Recipient().String()), + }, + } + encryption.Spec.ActiveKeys = []*recordingencryptionv1.WrappedKey{&wrappedKey} + encryption, err = persistFn(ctx, encryption) + if err != nil { + return encryption, false, trace.Wrap(err) + } + fp := sha256.Sum256(wrappedKey.RecordingEncryptionPair.PublicKey) + m.logger.InfoContext(ctx, "no active keys, generated initial recording encryption pair", "public_fingerprint", hex.EncodeToString(fp[:])) + return encryption, true, nil +} + +var errWaitingForKey = errors.New("waiting for key to be fulfilled") + +// getRecordingEncryptionKeyPair returns the first active recording encryption key accessible to the configured key store. 
+func (m *Manager) getRecordingEncryptionKeyPair(ctx context.Context, keys []*recordingencryptionv1.WrappedKey) (*types.EncryptionKeyPair, error) { + var foundUnfulfilledKey bool + for _, key := range keys { + decrypter, err := m.keyStore.GetDecrypter(ctx, key.KeyEncryptionPair) + if err != nil { + continue + } + + // if we make it to this section the key is accessible to the current auth server + if key.RecordingEncryptionPair == nil { + foundUnfulfilledKey = true + continue + } + + decryptionKey, err := decrypter.Decrypt(rand.Reader, key.RecordingEncryptionPair.PrivateKey, nil) + if err != nil { + return nil, trace.Wrap(err, "decrypting known key") + } + + return &types.EncryptionKeyPair{ + PrivateKey: decryptionKey, + PublicKey: key.RecordingEncryptionPair.PublicKey, + }, nil + } + + if foundUnfulfilledKey { + return nil, trace.Wrap(errWaitingForKey) + } + + return nil, trace.NotFound("no accessible recording encryption pair found") +} + +// resolveRecordingEncryption examines the current state of the RecordingEncryption resource and advances it to the +// next state on behalf of the current auth server. +// +// When no active recording encryption key pairs exist, the first pair will be generated and wrapped using a new key +// encryption pair generated by the Manager's keystore. +// +// When at least one active keypair exists but none are accessible to the Manager's keystore, a new key encryption pair +// will be generated and saved without a recording encryption pair. This is an unfulfilled key that some other instance of +// Manager on another auth server will need to fulfill asynchronously. +// +// If at least one active key is accessible to the Manager's keystore, then unfulfilled keys (identified by missing +// recording encryption key pairs) will be fulfilled using their public key encryption keys. +// +// If there are no unfulfilled keys present, then nothing should be done. 
+func (m *Manager) resolveRecordingEncryption(ctx context.Context) (*recordingencryptionv1.RecordingEncryption, error) { + encryption, generatedKey, err := m.ensureActiveRecordingEncryption(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + if generatedKey { + m.logger.DebugContext(ctx, "created initial recording encryption key") + return encryption, nil + } + + activeKeys := encryption.GetSpec().ActiveKeys + recordingEncryptionPair, err := m.getRecordingEncryptionKeyPair(ctx, activeKeys) + if err != nil { + if errors.Is(err, errWaitingForKey) { + // do nothing + return encryption, nil + } + + if trace.IsNotFound(err) { + m.logger.InfoContext(ctx, "no accessible recording encryption keys, posting new key to be fulfilled") + keypair, err := m.keyStore.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping) + if err != nil { + return nil, trace.Wrap(err, "generating keypair for new wrapped key") + } + encryption.GetSpec().ActiveKeys = append(activeKeys, &recordingencryptionv1.WrappedKey{ + KeyEncryptionPair: keypair, + }) + + encryption, err = m.RecordingEncryption.UpdateRecordingEncryption(ctx, encryption) + return encryption, trace.Wrap(err, "updating session recording config") + } + + return nil, trace.Wrap(err) + } + + var shouldUpdate bool + for _, key := range activeKeys { + if key.RecordingEncryptionPair != nil { + continue + } + + encryptedKey, err := key.KeyEncryptionPair.EncryptOAEP(recordingEncryptionPair.PrivateKey) + if err != nil { + return encryption, trace.Wrap(err, "reencrypting decryption key") + } + + key.RecordingEncryptionPair = &types.EncryptionKeyPair{ + PrivateKey: encryptedKey, + PublicKey: recordingEncryptionPair.PublicKey, + } + + shouldUpdate = true + } + + if shouldUpdate { + m.logger.DebugContext(ctx, "fulfilling empty keys") + encryption, err = m.RecordingEncryption.UpdateRecordingEncryption(ctx, encryption) + if err != nil { + return encryption, trace.Wrap(err, "updating session recording config") + } + } + + return 
encryption, nil +} + +func (m *Manager) searchActiveKeys(ctx context.Context, activeKeys []*recordingencryptionv1.WrappedKey, publicKey []byte) (*types.EncryptionKeyPair, error) { + for _, key := range activeKeys { + if key.GetRecordingEncryptionPair() == nil { + continue + } + + if !slices.Equal(key.RecordingEncryptionPair.PublicKey, publicKey) { + continue + } + + decrypter, err := m.keyStore.GetDecrypter(ctx, key.KeyEncryptionPair) + if err != nil { + continue + } + + privateKey, err := decrypter.Decrypt(rand.Reader, key.RecordingEncryptionPair.PrivateKey, nil) + if err != nil { + return nil, trace.Wrap(err) + } + + return &types.EncryptionKeyPair{ + PrivateKey: privateKey, + PublicKey: key.RecordingEncryptionPair.PublicKey, + PrivateKeyType: key.RecordingEncryptionPair.PrivateKeyType, + }, nil + } + + return nil, trace.NotFound("no accessible decryption key found") +} + +// FindDecryptionKey returns the first accessible decryption key that matches one of the given public keys. +func (m *Manager) FindDecryptionKey(ctx context.Context, publicKeys ...[]byte) (*types.EncryptionKeyPair, error) { + encryption, err := m.RecordingEncryption.GetRecordingEncryption(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + // TODO (eriktate): search rotated keys as well once rotation is implemented + activeKeys := encryption.GetSpec().ActiveKeys + if len(publicKeys) == 0 { + return m.searchActiveKeys(ctx, activeKeys, nil) + } + + for _, publicKey := range publicKeys { + found, err := m.searchActiveKeys(ctx, activeKeys, publicKey) + if err != nil { + if trace.IsNotFound(err) { + continue + } + + if !slices.Equal(found.PublicKey, publicKey) { + continue + } + + decrypter, err := m.keyStore.GetDecrypter(ctx, found) + if err != nil { + if !trace.IsNotFound(err) { + m.logger.ErrorContext(ctx, "could not get decrypter from key store", "error", err) + } + continue + } + + privateKey, err := decrypter.Decrypt(rand.Reader, found.PrivateKey, nil) + if err != nil { + return nil, 
trace.Wrap(err) + } + + return &types.EncryptionKeyPair{ + PrivateKey: privateKey, + PublicKey: found.PublicKey, + PrivateKeyType: found.PrivateKeyType, + }, nil + } + + return found, nil + } + + return nil, trace.NotFound("no accessible decryption key found") +} + +func (m *Manager) Watch(ctx context.Context, events types.Events) (err error) { + // shouldRetryAfterJitterFn waits at most 5 seconds and returns a bool specifying whether or not + // execution should continue + shouldRetryAfterJitterFn := func() bool { + select { + case <-time.After(retryutils.SeventhJitter(time.Second * 5)): + return true + case <-ctx.Done(): + return false + } + } + + defer func() { + m.logger.InfoContext(ctx, "stopping encryption watcher", "error", err) + }() + + for { + watch, err := events.NewWatcher(ctx, types.Watch{ + Name: "recording_encryption_watcher", + Kinds: []types.WatchKind{ + { + Kind: types.KindRecordingEncryption, + }, + }, + }) + if err != nil { + m.logger.ErrorContext(ctx, "failed to create watcher, retrying", "error", err) + if !shouldRetryAfterJitterFn() { + return nil + } + continue + } + defer watch.Close() + + HandleEvents: + for { + select { + case ev := <-watch.Events(): + if err := m.handleEvent(ctx, ev, shouldRetryAfterJitterFn); err != nil { + m.logger.ErrorContext(ctx, "failure handling recording encryption event", "kind", ev.Resource.GetKind(), "error", err) + } + case <-watch.Done(): + if err := watch.Error(); err == nil { + return nil + } + + m.logger.ErrorContext(ctx, "watcher failed, retrying", "error", err) + if !shouldRetryAfterJitterFn() { + return nil + } + break HandleEvents + case <-ctx.Done(): + return nil + } + + } + } +} + +func (m *Manager) handleEvent(ctx context.Context, ev types.Event, shouldRetryFn func() bool) error { + if ev.Type != types.OpPut { + return nil + } + + if ev.Resource.GetKind() != types.KindRecordingEncryption { + return nil + } + + const retries = 3 + for retry := range retries { + err := backend.RunWhileLocked(ctx, 
m.lockConfig, func(ctx context.Context) error { + sessionRecordingConfig, err := m.GetSessionRecordingConfig(ctx) + if err != nil { + m.logger.ErrorContext(ctx, "failed to retrieve session_recording_config, retrying", "error", err) + return err + } + + if !sessionRecordingConfig.GetEncrypted() { + return nil + } + + if _, err := m.resolveRecordingEncryption(ctx); err != nil { + m.logger.ErrorContext(ctx, "failed to resolve recording encryption keys, retrying", "retry", retry, "retries_left", retries-retry, "error", err) + + return err + } + + return nil + }) + if err != nil && shouldRetryFn() { + continue + } + + return nil + } + + return trace.LimitExceeded("resolving recording encryption exceeded max retries") +} + +// getAgeEncryptionKeys returns an iterator of AgeEncryptionKeys from a list of WrappedKeys. This is for use in +// populating the EncryptionKeys field of SessionRecordingConfigStatus. +func getAgeEncryptionKeys(keys []*recordingencryptionv1.WrappedKey) iter.Seq[*types.AgeEncryptionKey] { + return func(yield func(*types.AgeEncryptionKey) bool) { + for _, key := range keys { + if key.RecordingEncryptionPair == nil { + continue + } + + if !yield(&types.AgeEncryptionKey{ + PublicKey: key.RecordingEncryptionPair.PublicKey, + }) { + return + } + } + } +} diff --git a/lib/auth/recordingencryption/manager_test.go b/lib/auth/recordingencryption/manager_test.go new file mode 100644 index 0000000000000..d803e5218f65d --- /dev/null +++ b/lib/auth/recordingencryption/manager_test.go @@ -0,0 +1,402 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package recordingencryption_test + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "io" + "sync" + "testing" + "time" + + "filippo.io/age" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + recordingencryptionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/recordingencryption/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/auth/recordingencryption" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/utils" +) + +type oaepDecrypter struct { + crypto.Decrypter + hash crypto.Hash +} + +func (d oaepDecrypter) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) ([]byte, error) { + return d.Decrypter.Decrypt(rand, msg, &rsa.OAEPOptions{ + Hash: d.hash, + }) +} + +type fakeKeyStore struct { + keyType types.PrivateKeyType // abusing this field as a way to simulate different auth servers +} + +func (f *fakeKeyStore) NewEncryptionKeyPair(ctx context.Context, purpose cryptosuites.KeyPurpose) (*types.EncryptionKeyPair, error) { + decrypter, err := cryptosuites.GenerateDecrypterWithAlgorithm(cryptosuites.RSA2048) + if err != nil { + return nil, err + } + + private, ok := decrypter.(*rsa.PrivateKey) + if !ok { + 
return nil, errors.New("expected RSA private key") + } + + privatePEM := pem.EncodeToMemory(&pem.Block{ + Type: keys.PKCS1PrivateKeyType, + Bytes: x509.MarshalPKCS1PrivateKey(private), + }) + + publicPEM := pem.EncodeToMemory(&pem.Block{ + Type: keys.PKCS1PublicKeyType, + Bytes: x509.MarshalPKCS1PublicKey(&private.PublicKey), + }) + + return &types.EncryptionKeyPair{ + PrivateKey: privatePEM, + PublicKey: publicPEM, + PrivateKeyType: f.keyType, + Hash: uint32(crypto.SHA256), + }, nil +} + +func (f *fakeKeyStore) GetDecrypter(ctx context.Context, keyPair *types.EncryptionKeyPair) (crypto.Decrypter, error) { + if keyPair.PrivateKeyType != f.keyType { + return nil, errors.New("could not access decrypter") + } + + block, _ := pem.Decode(keyPair.PrivateKey) + + private, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + return oaepDecrypter{Decrypter: private, hash: crypto.Hash(keyPair.Hash)}, nil +} + +func newLocalBackend( + t *testing.T, +) (context.Context, backend.Backend) { + t.Parallel() + ctx := t.Context() + clock := clockwork.NewRealClock() + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + bk := backend.NewSanitizer(mem) + return ctx, bk +} + +func newManagerConfig(t *testing.T, bk backend.Backend, keyType types.PrivateKeyType) recordingencryption.ManagerConfig { + recordingEncryptionService, err := local.NewRecordingEncryptionService(bk) + require.NoError(t, err) + + clusterConfigService, err := local.NewClusterConfigurationService(bk) + require.NoError(t, err) + + src := &types.SessionRecordingConfigV2{} + require.NoError(t, src.CheckAndSetDefaults()) + src.Spec.Encryption = &types.SessionRecordingEncryptionConfig{ + Enabled: true, + } + + return recordingencryption.ManagerConfig{ + Backend: recordingEncryptionService, + ClusterConfig: clusterConfigService, + KeyStore: &fakeKeyStore{keyType: keyType}, + Logger: utils.NewSlogLoggerForTests(), + LockConfig: 
backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: bk, + LockNameComponents: []string{"recording_encryption"}, + TTL: 5 * time.Second, + RetryInterval: 10 * time.Millisecond, + }, + }, + } +} + +// resolve is a proxy to Manager.resolveRecordingEncryption through calling UpsertSessionRecordingConfig +func resolve(ctx context.Context, service services.RecordingEncryption, manager *recordingencryption.Manager) (*recordingencryptionv1.RecordingEncryption, types.SessionRecordingConfig, error) { + req := types.SessionRecordingConfigV2{ + Spec: types.SessionRecordingConfigSpecV2{ + Encryption: &types.SessionRecordingEncryptionConfig{ + Enabled: true, + }, + }, + } + if err := req.CheckAndSetDefaults(); err != nil { + return nil, nil, trace.Wrap(err) + } + + src, err := manager.UpsertSessionRecordingConfig(ctx, &req) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + encryption, err := service.GetRecordingEncryption(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return encryption, src, nil +} + +func TestCreateUpdateSessionRecordingConfig(t *testing.T) { + ctx, bk := newLocalBackend(t) + + config := newManagerConfig(t, bk, types.PrivateKeyType_RAW) + manager, err := recordingencryption.NewManager(config) + require.NoError(t, err) + + req := &types.SessionRecordingConfigV2{} + require.NoError(t, req.CheckAndSetDefaults()) + req.Spec.Encryption = &types.SessionRecordingEncryptionConfig{ + Enabled: true, + } + + // create should provision initial keypair and write public key to SRC + src, err := manager.CreateSessionRecordingConfig(ctx, req) + require.NoError(t, err) + encryptionKeys := src.GetEncryptionKeys() + require.Len(t, encryptionKeys, 1) + + encryption, err := config.Backend.GetRecordingEncryption(ctx) + require.NoError(t, err) + activeKeys := encryption.GetSpec().GetActiveKeys() + require.Len(t, activeKeys, 1) + require.NotNil(t, activeKeys[0].RecordingEncryptionPair) + require.NotEmpty(t, 
activeKeys[0].RecordingEncryptionPair.PrivateKey) + require.NotEmpty(t, activeKeys[0].RecordingEncryptionPair.PublicKey) + require.NotNil(t, activeKeys[0].KeyEncryptionPair) + require.NotEmpty(t, activeKeys[0].KeyEncryptionPair.PrivateKey) + require.NotEmpty(t, activeKeys[0].KeyEncryptionPair.PublicKey) + + // update should change nothing + src, err = manager.UpdateSessionRecordingConfig(ctx, src) + require.NoError(t, err) + newEncryptionKeys := src.GetEncryptionKeys() + require.ElementsMatch(t, newEncryptionKeys, encryptionKeys) + + encryption, err = config.Backend.GetRecordingEncryption(ctx) + require.NoError(t, err) + newActiveKeys := encryption.GetSpec().GetActiveKeys() + require.ElementsMatch(t, newActiveKeys, activeKeys) +} + +func TestResolveRecordingEncryption(t *testing.T) { + // SETUP + ctx, bk := newLocalBackend(t) + + managerAType := types.PrivateKeyType_RAW + managerBType := types.PrivateKeyType_AWS_KMS + + configA := newManagerConfig(t, bk, managerAType) + configB := configA + configB.KeyStore = &fakeKeyStore{managerBType} + + managerA, err := recordingencryption.NewManager(configA) + require.NoError(t, err) + + managerB, err := recordingencryption.NewManager(configB) + require.NoError(t, err) + + service := configA.Backend + + // TEST + // CASE: service A first evaluation initializes recording encryption resource + encryption, src, err := resolve(ctx, service, managerA) + require.NoError(t, err) + activeKeys := encryption.GetSpec().GetActiveKeys() + + require.Len(t, activeKeys, 1) + require.Len(t, src.GetEncryptionKeys(), 1) + firstKey := activeKeys[0] + + // should generate a wrapped key with the initial recording encryption pair + require.NotNil(t, firstKey.KeyEncryptionPair) + require.NotNil(t, firstKey.RecordingEncryptionPair) + + // CASE: service B should generate an unfulfilled key since there's an existing recording encryption resource + encryption, src, err = resolve(ctx, service, managerB) + require.NoError(t, err) + + activeKeys = 
encryption.GetSpec().ActiveKeys + require.Len(t, activeKeys, 2) + require.Len(t, src.GetEncryptionKeys(), 1) + for _, key := range activeKeys { + require.NotNil(t, key.KeyEncryptionPair) + if key.KeyEncryptionPair.PrivateKeyType == managerAType { + require.NotNil(t, key.RecordingEncryptionPair) + } else { + require.Nil(t, key.RecordingEncryptionPair) + } + } + + // service B re-evaluting with an unfulfilled key should do nothing + encryption, src, err = resolve(ctx, service, managerB) + require.NoError(t, err) + activeKeys = encryption.GetSpec().ActiveKeys + require.Len(t, activeKeys, 2) + require.Len(t, src.GetEncryptionKeys(), 1) + for _, key := range activeKeys { + require.NotNil(t, key.KeyEncryptionPair) + if key.KeyEncryptionPair.PrivateKeyType == managerAType { + require.NotNil(t, key.RecordingEncryptionPair) + } else { + require.Nil(t, key.RecordingEncryptionPair) + } + } + + // CASE: service A evaluation should fulfill service B's key + encryption, src, err = resolve(ctx, service, managerA) + require.NoError(t, err) + activeKeys = encryption.GetSpec().ActiveKeys + require.Len(t, activeKeys, 2) + require.Len(t, src.GetEncryptionKeys(), 1) + for _, key := range activeKeys { + require.NotNil(t, key.KeyEncryptionPair) + require.NotNil(t, key.RecordingEncryptionPair) + } +} + +func TestResolveRecordingEncryptionConcurrent(t *testing.T) { + // SETUP + ctx, bk := newLocalBackend(t) + + managerAType := types.PrivateKeyType_RAW + managerBType := types.PrivateKeyType_AWS_KMS + serviceCType := types.PrivateKeyType_GCP_KMS + + configA := newManagerConfig(t, bk, managerAType) + configB := configA + configB.KeyStore = &fakeKeyStore{managerBType} + configC := configA + configC.KeyStore = &fakeKeyStore{serviceCType} + recordingEncryptionService := configA.Backend + managerA, err := recordingencryption.NewManager(configA) + require.NoError(t, err) + + managerB, err := recordingencryption.NewManager(configB) + require.NoError(t, err) + + serviceC, err := 
recordingencryption.NewManager(configC) + require.NoError(t, err) + + service := configA.Backend + resolveFn := func(manager *recordingencryption.Manager, wg *sync.WaitGroup) { + wg.Add(1) + go func() { + defer wg.Done() + resolve(ctx, service, manager) + require.NoError(t, err) + }() + } + + wg := sync.WaitGroup{} + resolveFn(managerA, &wg) + resolveFn(managerB, &wg) + resolveFn(serviceC, &wg) + wg.Wait() + + encryption, err := recordingEncryptionService.GetRecordingEncryption(ctx) + require.NoError(t, err) + + activeKeys := encryption.GetSpec().ActiveKeys + // each service should have an active wrapped key + require.Len(t, activeKeys, 3) + var fulfilledKeys int + for _, key := range activeKeys { + // all wrapped keys should have KeyEncryptionPairs + require.NotNil(t, key.KeyEncryptionPair) + require.NotEmpty(t, key.KeyEncryptionPair.PublicKey) + require.NotEmpty(t, key.KeyEncryptionPair.PrivateKey) + + if key.RecordingEncryptionPair != nil { + fulfilledKeys += 1 + } + } + + // only the first service to run should have a fulfilled wrapped key + require.Equal(t, 1, fulfilledKeys) +} + +func TestFindDecryptionKeyFromActiveKeys(t *testing.T) { + // SETUP + ctx, bk := newLocalBackend(t) + keyTypeA := types.PrivateKeyType_RAW + keyTypeB := types.PrivateKeyType_AWS_KMS + + configA := newManagerConfig(t, bk, keyTypeA) + configB := configA + configB.KeyStore = &fakeKeyStore{keyTypeB} + managerA, err := recordingencryption.NewManager(configA) + require.NoError(t, err) + + managerB, err := recordingencryption.NewManager(configB) + require.NoError(t, err) + + service := configA.Backend + _, _, err = resolve(ctx, service, managerA) + require.NoError(t, err) + + encryption, _, err := resolve(ctx, service, managerB) + require.NoError(t, err) + + activeKeys := encryption.GetSpec().ActiveKeys + require.Len(t, activeKeys, 2) + pubKey := activeKeys[0].RecordingEncryptionPair.PublicKey + + // fail to find private key for manager B because it is waiting for key fulfillment + _, err = 
managerB.FindDecryptionKey(ctx, pubKey) + require.Error(t, err) + + _, _, err = resolve(ctx, service, managerA) + require.NoError(t, err) + + // find private key for manager A because it provisioned the key + decryptionPair, err := managerA.FindDecryptionKey(ctx, pubKey) + require.NoError(t, err) + ident, err := age.ParseX25519Identity(string(decryptionPair.PrivateKey)) + require.NoError(t, err) + require.Equal(t, ident.Recipient().String(), string(pubKey)) + + // find private key for manager B after fulfillment + decryptionPair, err = managerB.FindDecryptionKey(ctx, pubKey) + require.NoError(t, err) + ident, err = age.ParseX25519Identity(string(decryptionPair.PrivateKey)) + require.NoError(t, err) + require.Equal(t, ident.Recipient().String(), string(pubKey)) +} diff --git a/lib/backend/firestore/firestorebk.go b/lib/backend/firestore/firestorebk.go index 24edbb0f27e01..70a38cf10fed0 100644 --- a/lib/backend/firestore/firestorebk.go +++ b/lib/backend/firestore/firestorebk.go @@ -225,7 +225,6 @@ func newRecordFromDoc(doc *firestore.DocumentSnapshot) (*record, error) { Timestamp: br.Timestamp, Expires: br.Expires, RevisionV2: br.RevisionV2, - snapShot: doc, } default: if err := doc.DataTo(&r); err != nil { @@ -241,11 +240,12 @@ func newRecordFromDoc(doc *firestore.DocumentSnapshot) (*record, error) { Value: []byte(rl.Value), Timestamp: rl.Timestamp, Expires: rl.Expires, - snapShot: doc, } } } + r.snapShot = doc + if r.RevisionV2 == "" { r.RevisionV1 = toRevisionV1(doc.UpdateTime) } diff --git a/lib/cache/bot_instance.go b/lib/cache/bot_instance.go new file mode 100644 index 0000000000000..1334799e5bada --- /dev/null +++ b/lib/cache/bot_instance.go @@ -0,0 +1,154 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. 
+// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "slices" + "strings" + + "github.com/gravitational/trace" + "google.golang.org/protobuf/proto" + + "github.com/gravitational/teleport/api/defaults" + machineidv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/clientutils" + "github.com/gravitational/teleport/lib/services" +) + +type botInstanceIndex string + +const ( + botInstanceNameIndex botInstanceIndex = "name" +) + +func keyForNameIndex(botInstance *machineidv1.BotInstance) string { + return makeNameIndexKey( + botInstance.GetSpec().GetBotName(), + botInstance.GetMetadata().GetName(), + ) +} + +func makeNameIndexKey(botName string, instanceID string) string { + return botName + "/" + instanceID +} + +func newBotInstanceCollection(upstream services.BotInstance, w types.WatchKind) (*collection[*machineidv1.BotInstance, botInstanceIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter upstream (BotInstance)") + } + + return &collection[*machineidv1.BotInstance, botInstanceIndex]{ + store: newStore( + proto.CloneOf[*machineidv1.BotInstance], + map[botInstanceIndex]func(*machineidv1.BotInstance) string{ + // Index on a combination of bot name and instance name + 
botInstanceNameIndex: keyForNameIndex, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*machineidv1.BotInstance, error) { + var out []*machineidv1.BotInstance + clientutils.IterateResources(ctx, + func(ctx context.Context, limit int, start string) ([]*machineidv1.BotInstance, string, error) { + return upstream.ListBotInstances(ctx, "", limit, start, "") + }, + func(hcc *machineidv1.BotInstance) error { + out = append(out, hcc) + return nil + }, + ) + return out, nil + }, + watch: w, + }, nil +} + +// GetBotInstance returns the specified BotInstance resource. +func (c *Cache) GetBotInstance(ctx context.Context, botName, instanceID string) (*machineidv1.BotInstance, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetBotInstance") + defer span.End() + + getter := genericGetter[*machineidv1.BotInstance, botInstanceIndex]{ + cache: c, + collection: c.collections.botInstances, + index: botInstanceNameIndex, + upstreamGet: func(ctx context.Context, _ string) (*machineidv1.BotInstance, error) { + return c.Config.BotInstanceService.GetBotInstance(ctx, botName, instanceID) + }, + } + + out, err := getter.get(ctx, makeNameIndexKey(botName, instanceID)) + return out, trace.Wrap(err) +} + +// ListBotInstances returns a page of BotInstance resources. 
+func (c *Cache) ListBotInstances(ctx context.Context, botName string, pageSize int, lastToken string, search string) ([]*machineidv1.BotInstance, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListBotInstances") + defer span.End() + + lister := genericLister[*machineidv1.BotInstance, botInstanceIndex]{ + cache: c, + collection: c.collections.botInstances, + index: botInstanceNameIndex, + defaultPageSize: defaults.DefaultChunkSize, + upstreamList: func(ctx context.Context, limit int, start string) ([]*machineidv1.BotInstance, string, error) { + return c.Config.BotInstanceService.ListBotInstances(ctx, botName, limit, start, search) + }, + filter: func(b *machineidv1.BotInstance) bool { + return matchBotInstance(b, botName, search) + }, + nextToken: func(b *machineidv1.BotInstance) string { + return keyForNameIndex(b) + }, + } + out, next, err := lister.list(ctx, + pageSize, + lastToken, + ) + return out, next, trace.Wrap(err) +} + +func matchBotInstance(b *machineidv1.BotInstance, botName string, search string) bool { + // If updating this, ensure it's consistent with the upstream search logic in `lib/services/local/bot_instance.go`. 
+ + if botName != "" && b.Spec.BotName != botName { + return false + } + + if search == "" { + return true + } + + latestHeartbeats := b.GetStatus().GetLatestHeartbeats() + heartbeat := b.Status.InitialHeartbeat // Use initial heartbeat as a fallback + if len(latestHeartbeats) > 0 { + heartbeat = latestHeartbeats[len(latestHeartbeats)-1] + } + + values := []string{ + b.Spec.BotName, + b.Spec.InstanceId, + } + + if heartbeat != nil { + values = append(values, heartbeat.Hostname, heartbeat.JoinMethod, heartbeat.Version, "v"+heartbeat.Version) + } + + return slices.ContainsFunc(values, func(val string) bool { + return strings.Contains(strings.ToLower(val), strings.ToLower(search)) + }) +} diff --git a/lib/cache/bot_instance_test.go b/lib/cache/bot_instance_test.go new file mode 100644 index 0000000000000..6395eed93f104 --- /dev/null +++ b/lib/cache/bot_instance_test.go @@ -0,0 +1,222 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cache + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + "github.com/gravitational/teleport/api/types" +) + +// TestBotInstanceCache tests that CRUD operations on bot instances resources are +// replicated from the backend to the cache. +func TestBotInstanceCache(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources153(t, p, testFuncs153[*machineidv1.BotInstance]{ + newResource: func(key string) (*machineidv1.BotInstance, error) { + return &machineidv1.BotInstance{ + Kind: types.KindBotInstance, + Version: types.V1, + Metadata: &headerv1.Metadata{}, + Spec: &machineidv1.BotInstanceSpec{ + BotName: "bot-1", + InstanceId: key, + }, + Status: &machineidv1.BotInstanceStatus{}, + }, nil + }, + cacheGet: func(ctx context.Context, key string) (*machineidv1.BotInstance, error) { + return p.cache.GetBotInstance(ctx, "bot-1", key) + }, + cacheList: func(ctx context.Context) ([]*machineidv1.BotInstance, error) { + results, _, err := p.cache.ListBotInstances(ctx, "", 0, "", "") + return results, err + }, + create: func(ctx context.Context, resource *machineidv1.BotInstance) error { + _, err := p.botInstanceService.CreateBotInstance(ctx, resource) + return err + }, + list: func(ctx context.Context) ([]*machineidv1.BotInstance, error) { + results, _, err := p.botInstanceService.ListBotInstances(ctx, "", 0, "", "") + return results, err + }, + update: func(ctx context.Context, bi *machineidv1.BotInstance) error { + _, err := p.botInstanceService.PatchBotInstance(ctx, "bot-1", bi.Metadata.GetName(), func(_ *machineidv1.BotInstance) (*machineidv1.BotInstance, error) { + return bi, nil + }) + return err + }, + delete: func(ctx context.Context, key string) error { + 
return p.botInstanceService.DeleteBotInstance(ctx, "bot-1", key) + }, + deleteAll: func(ctx context.Context) error { + return p.botInstanceService.DeleteAllBotInstances(ctx) + }, + }) +} + +// TestBotInstanceCachePaging tests that items from the cache are paginated. +func TestBotInstanceCachePaging(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for _, n := range []int{5, 1, 3, 4, 2} { + _, err := p.botInstanceService.CreateBotInstance(ctx, &machineidv1.BotInstance{ + Kind: types.KindBotInstance, + Version: types.V1, + Metadata: &headerv1.Metadata{}, + Spec: &machineidv1.BotInstanceSpec{ + BotName: "bot-1", + InstanceId: "instance-" + strconv.Itoa(n), + }, + Status: &machineidv1.BotInstanceStatus{}, + }) + require.NoError(t, err) + } + + // Let the cache catch up + require.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := p.cache.GetBotInstance(ctx, "bot-1", "instance-2") + require.NoError(t, err) + }, 2*time.Second, 10*time.Millisecond) + + // page size equal to total items + results, nextPageToken, err := p.cache.ListBotInstances(ctx, "", 0, "", "") + require.NoError(t, err) + require.Empty(t, nextPageToken) + require.Len(t, results, 5) + require.Equal(t, "instance-1", results[0].GetMetadata().GetName()) + require.Equal(t, "instance-2", results[1].GetMetadata().GetName()) + require.Equal(t, "instance-3", results[2].GetMetadata().GetName()) + require.Equal(t, "instance-4", results[3].GetMetadata().GetName()) + require.Equal(t, "instance-5", results[4].GetMetadata().GetName()) + + // page size smaller than total items + results, nextPageToken, err = p.cache.ListBotInstances(ctx, "", 3, "", "") + require.NoError(t, err) + require.Equal(t, "bot-1/instance-4", nextPageToken) + require.Len(t, results, 3) + require.Equal(t, "instance-1", results[0].GetMetadata().GetName()) + require.Equal(t, "instance-2", results[1].GetMetadata().GetName()) + require.Equal(t, "instance-3", 
results[2].GetMetadata().GetName()) + + // next page + results, nextPageToken, err = p.cache.ListBotInstances(ctx, "", 3, nextPageToken, "") + require.NoError(t, err) + require.Empty(t, nextPageToken) + require.Len(t, results, 2) + require.Equal(t, "instance-4", results[0].GetMetadata().GetName()) + require.Equal(t, "instance-5", results[1].GetMetadata().GetName()) +} + +// TestBotInstanceCacheBotFilter tests that cache items are filtered by bot name. +func TestBotInstanceCacheBotFilter(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for n := range 2 { + for m := range 5 { + _, err := p.botInstanceService.CreateBotInstance(ctx, &machineidv1.BotInstance{ + Kind: types.KindBotInstance, + Version: types.V1, + Metadata: &headerv1.Metadata{}, + Spec: &machineidv1.BotInstanceSpec{ + BotName: "bot-" + strconv.Itoa(n+1), + InstanceId: "instance-" + strconv.Itoa((n+1)*(m+1)), + }, + Status: &machineidv1.BotInstanceStatus{}, + }) + require.NoError(t, err) + } + } + + // Let the cache catch up + require.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := p.cache.GetBotInstance(ctx, "bot-2", "instance-10") + require.NoError(t, err) + }, 2*time.Second, 10*time.Millisecond) + + results, _, err := p.cache.ListBotInstances(ctx, "bot-2", 0, "", "") + require.NoError(t, err) + require.Len(t, results, 5) + + for _, b := range results { + require.Equal(t, "bot-2", b.GetSpec().GetBotName()) + } +} + +// TestBotInstanceCacheSearchFilter tests that cache items are filtered by search query. 
+func TestBotInstanceCacheSearchFilter(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + for n := range 10 { + instance := &machineidv1.BotInstance{ + Kind: types.KindBotInstance, + Version: types.V1, + Metadata: &headerv1.Metadata{}, + Spec: &machineidv1.BotInstanceSpec{ + BotName: "bot-1", + InstanceId: "instance-" + strconv.Itoa(n+1), + }, + Status: &machineidv1.BotInstanceStatus{ + LatestHeartbeats: []*machineidv1.BotInstanceStatusHeartbeat{ + { + Hostname: "host-" + strconv.Itoa(n%2), + }, + }, + }, + } + + _, err := p.botInstanceService.CreateBotInstance(ctx, instance) + require.NoError(t, err) + } + + // Let the cache catch up + require.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := p.cache.GetBotInstance(ctx, "bot-1", "instance-10") + require.NoError(t, err) + }, 2*time.Second, 10*time.Millisecond) + + results, _, err := p.cache.ListBotInstances(ctx, "", 0, "", "host-1") + require.NoError(t, err) + require.Len(t, results, 5) +} diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 93638a776b376..af618d0af9622 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -210,6 +210,7 @@ func ForAuth(cfg Config) Config { {Kind: types.KindWorkloadIdentity}, {Kind: types.KindHealthCheckConfig}, {Kind: types.KindRelayServer}, + {Kind: types.KindBotInstance}, } cfg.QueueSize = defaults.AuthQueueSize // We don't want to enable partial health for auth cache because auth uses an event stream @@ -744,6 +745,8 @@ type Config struct { GitServers services.GitServerGetter // HealthCheckConfig is a health check config service. 
HealthCheckConfig services.HealthCheckConfigReader + // BotInstanceService is the upstream service that we're caching + BotInstanceService services.BotInstance } // CheckAndSetDefaults checks parameters and sets default values diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index f6f1a26e36cc8..6a8a99ab58d9e 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -149,6 +149,7 @@ type testPack struct { gitServers *local.GitServerService workloadIdentity *local.WorkloadIdentityService healthCheckConfig *local.HealthCheckConfigService + botInstanceService *local.BotInstanceService } // testFuncs are functions to support testing an object in a cache. @@ -427,6 +428,11 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) { return nil, trace.Wrap(err) } + p.botInstanceService, err = local.NewBotInstanceService(p.backend, p.backend.Clock()) + if err != nil { + return nil, trace.Wrap(err) + } + return p, nil } @@ -483,6 +489,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption) GitServers: p.gitServers, HealthCheckConfig: p.healthCheckConfig, WorkloadIdentity: p.workloadIdentity, + BotInstanceService: p.botInstanceService, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -758,6 +765,7 @@ func TestCompletenessInit(t *testing.T) { EventsC: p.eventsC, GitServers: p.gitServers, HealthCheckConfig: p.healthCheckConfig, + BotInstanceService: p.botInstanceService, })) require.NoError(t, err) @@ -845,6 +853,7 @@ func TestCompletenessReset(t *testing.T) { EventsC: p.eventsC, GitServers: p.gitServers, HealthCheckConfig: p.healthCheckConfig, + BotInstanceService: p.botInstanceService, })) require.NoError(t, err) @@ -1003,6 +1012,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) { neverOK: true, // ensure reads are never healthy GitServers: p.gitServers, HealthCheckConfig: p.healthCheckConfig, + BotInstanceService: p.botInstanceService, })) require.NoError(t, err) @@ -1100,6 
+1110,7 @@ func initStrategy(t *testing.T) { EventsC: p.eventsC, GitServers: p.gitServers, HealthCheckConfig: p.healthCheckConfig, + BotInstanceService: p.botInstanceService, })) require.NoError(t, err) @@ -1860,6 +1871,7 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) { scopedrole.KindScopedRole: types.Resource153ToLegacy(&accessv1.ScopedRole{}), scopedrole.KindScopedRoleAssignment: types.Resource153ToLegacy(&accessv1.ScopedRoleAssignment{}), types.KindRelayServer: types.ProtoResource153ToLegacy(new(presencev1.RelayServer)), + types.KindBotInstance: types.ProtoResource153ToLegacy(new(machineidv1.BotInstance)), } for name, cfg := range cases { @@ -1925,6 +1937,8 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) { require.Empty(t, cmp.Diff(resource.(types.Resource153UnwrapperT[*accessv1.ScopedRoleAssignment]).UnwrapT(), uw.UnwrapT(), protocmp.Transform())) case types.Resource153UnwrapperT[*presencev1.RelayServer]: require.Empty(t, cmp.Diff(resource.(types.Resource153UnwrapperT[*presencev1.RelayServer]).UnwrapT(), uw.UnwrapT(), protocmp.Transform())) + case types.Resource153UnwrapperT[*machineidv1.BotInstance]: + require.Empty(t, cmp.Diff(resource.(types.Resource153UnwrapperT[*machineidv1.BotInstance]).UnwrapT(), uw.UnwrapT(), protocmp.Transform())) default: require.Empty(t, cmp.Diff(resource, event.Resource)) } diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 3311edbb3b4bc..799e7c907e5f7 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -139,6 +139,7 @@ type collections struct { secReports *collection[*secreports.Report, securityReportIndex] secReportsStates *collection[*secreports.ReportState, securityReportStateIndex] relayServers *collection[*presencev1.RelayServer, relayServerIndex] + botInstances *collection[*machineidv1.BotInstance, botInstanceIndex] } // isKnownUncollectedKind is true if a resource kind is not stored in @@ -731,6 +732,14 @@ func setupCollections(c Config) (*collections, error) { } 
out.relayServers = collect out.byKind[resourceKind] = out.relayServers + case types.KindBotInstance: + collect, err := newBotInstanceCollection(c.BotInstanceService, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.botInstances = collect + out.byKind[resourceKind] = out.botInstances default: if _, ok := out.byKind[resourceKind]; !ok { return nil, trace.BadParameter("resource %q is not supported", watch.Kind) diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 7dda432e5e9f6..24da1fcbeed08 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -2782,7 +2782,7 @@ func ConfigureOpenSSH(clf *CommandLineFlags, cfg *servicecfg.Config) error { cfg.Hostname = hostname cfg.OpenSSH.InstanceAddr = clf.Address cfg.OpenSSH.AdditionalPrincipals = []string{hostname, clf.Address} - for _, principal := range strings.Split(clf.AdditionalPrincipals, ",") { + for principal := range strings.SplitSeq(clf.AdditionalPrincipals, ",") { if principal == "" { continue } diff --git a/lib/config/configuration_test.go b/lib/config/configuration_test.go index 3ea7cab3b1682..722e527b2edc6 100644 --- a/lib/config/configuration_test.go +++ b/lib/config/configuration_test.go @@ -241,7 +241,7 @@ func TestBooleanParsing(t *testing.T) { } for i, tc := range testCases { msg := fmt.Sprintf("test case %v", i) - conf, err := ReadFromString(base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(` + conf, err := ReadFromString(base64.StdEncoding.EncodeToString(fmt.Appendf(nil, ` teleport: advertise_ip: 10.10.10.1 proxy_service: @@ -250,7 +250,7 @@ proxy_service: auth_service: enabled: yes disconnect_expired_cert: %v -`, tc.s, tc.s)))) +`, tc.s, tc.s))) require.NoError(t, err, msg) require.Equal(t, tc.b, conf.Proxy.TrustXForwardedFor.Value(), msg) require.Equal(t, tc.b, conf.Auth.DisconnectExpiredCert.Value, msg) @@ -270,13 +270,13 @@ func TestDuration(t *testing.T) { } for i, tc := range testCases { comment := fmt.Sprintf("test case %v", i) - 
conf, err := ReadFromString(base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(` + conf, err := ReadFromString(base64.StdEncoding.EncodeToString(fmt.Appendf(nil, ` teleport: advertise_ip: 10.10.10.1 auth_service: enabled: yes client_idle_timeout: %v -`, tc.s)))) +`, tc.s))) require.NoError(t, err, comment) require.Equal(t, tc.d, conf.Auth.ClientIdleTimeout.Value(), comment) } @@ -1313,7 +1313,7 @@ func TestTunnelStrategy(t *testing.T) { config string readErr require.ErrorAssertionFunc applyErr require.ErrorAssertionFunc - tunnelStrategy interface{} + tunnelStrategy any }{ { desc: "Ensure default is used when no tunnel strategy is given", @@ -1376,7 +1376,7 @@ func TestTunnelStrategy(t *testing.T) { err = ApplyFileConfig(conf, cfg) tc.applyErr(t, err) - var actualStrategy interface{} + var actualStrategy any if cfg.Auth.NetworkingConfig == nil { } else if s := cfg.Auth.NetworkingConfig.GetAgentMeshTunnelStrategy(); s != nil { actualStrategy = s @@ -1947,10 +1947,10 @@ func TestMergingCAPinConfig(t *testing.T) { t.Run(tt.desc, func(t *testing.T) { clf := CommandLineFlags{ CAPins: tt.cliPins, - ConfigString: base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf( + ConfigString: base64.StdEncoding.EncodeToString(fmt.Appendf(nil, configWithCAPins, tt.configPins, - ))), + )), } cfg := servicecfg.MakeDefaultConfig() require.Empty(t, cfg.CAPins) @@ -2597,7 +2597,7 @@ func TestAppsCLF(t *testing.T) { inAppURI: "", inLegacyAppFlags: true, outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "application name (--app-name) and URI (--app-uri) flags are both required to join application proxy to the cluster") }, @@ -2608,7 +2608,7 @@ func TestAppsCLF(t *testing.T) { inAppName: "", inAppURI: "", outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t 
require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "to join application proxy to the cluster provide application name (--name) and either URI (--uri) or Cloud type (--cloud)") }, @@ -2620,7 +2620,7 @@ func TestAppsCLF(t *testing.T) { inAppURI: "http://localhost:8080", inLegacyAppFlags: true, outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "application name (--app-name) is required to join application proxy to the cluster") }, @@ -2631,7 +2631,7 @@ func TestAppsCLF(t *testing.T) { inAppName: "", inAppURI: "http://localhost:8080", outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "to join application proxy to the cluster provide application name (--name)") }, @@ -2643,7 +2643,7 @@ func TestAppsCLF(t *testing.T) { inAppURI: "", inLegacyAppFlags: true, outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "URI (--app-uri) flag is required to join application proxy to the cluster") }, @@ -2654,7 +2654,7 @@ func TestAppsCLF(t *testing.T) { inAppName: "foo", inAppURI: "", outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "to join application proxy to the cluster provide URI (--uri) or Cloud type (--cloud)") }, @@ -2693,7 +2693,7 @@ func TestAppsCLF(t *testing.T) { inAppName: "-foo", inAppURI: "http://localhost:8080", outApps: nil, - 
requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "application name \"-foo\" must be a lower case valid DNS subdomain: https://goteleport.com/docs/enroll-resources/application-access/guides/connecting-apps/#application-name") }, @@ -2702,7 +2702,7 @@ func TestAppsCLF(t *testing.T) { desc: "missing uri", inAppName: "foo", outApps: nil, - requireError: func(t require.TestingT, err error, i ...interface{}) { + requireError: func(t require.TestingT, err error, i ...any) { require.True(t, trace.IsBadParameter(err)) require.ErrorContains(t, err, "missing application \"foo\" URI") }, @@ -3071,7 +3071,6 @@ func TestDatabaseCLIFlags(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.desc, func(t *testing.T) { t.Parallel() @@ -3670,11 +3669,11 @@ jamf_service: func TestAuthHostedPlugins(t *testing.T) { t.Parallel() - badParameter := func(t require.TestingT, err error, msgAndArgs ...interface{}) { + badParameter := func(t require.TestingT, err error, msgAndArgs ...any) { require.Error(t, err) require.True(t, trace.IsBadParameter(err), `expected "bad parameter", but got %v`, err) } - notExist := func(t require.TestingT, err error, msgAndArgs ...interface{}) { + notExist := func(t require.TestingT, err error, msgAndArgs ...any) { require.Error(t, err) require.ErrorIs(t, err, os.ErrNotExist, `expected "does not exist", but got %v`, err) } @@ -4024,7 +4023,7 @@ func TestApplyOktaConfig(t *testing.T) { EnabledFlag: "yes", }, }, - errAssertionFunc: func(tt require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter("okta_service is enabled but no api_endpoint is specified")) }, }, @@ -4037,7 +4036,7 @@ func TestApplyOktaConfig(t *testing.T) { }, APIEndpoint: `bad%url`, }, - errAssertionFunc: func(tt 
require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter(`malformed URL bad%%url`)) }, }, @@ -4050,7 +4049,7 @@ func TestApplyOktaConfig(t *testing.T) { }, APIEndpoint: `http://`, }, - errAssertionFunc: func(tt require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter("api_endpoint has no host")) }, }, @@ -4063,7 +4062,7 @@ func TestApplyOktaConfig(t *testing.T) { }, APIEndpoint: `//hostname`, }, - errAssertionFunc: func(tt require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter("api_endpoint has no scheme")) }, }, @@ -4075,7 +4074,7 @@ func TestApplyOktaConfig(t *testing.T) { }, APIEndpoint: "https://test-endpoint", }, - errAssertionFunc: func(tt require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter("okta_service is enabled but no api_token_path is specified")) }, }, @@ -4088,7 +4087,7 @@ func TestApplyOktaConfig(t *testing.T) { APIEndpoint: "https://test-endpoint", APITokenPath: "/non-existent/path", }, - errAssertionFunc: func(tt require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter("error trying to find file %s", i...)) }, }, @@ -4104,7 +4103,7 @@ func TestApplyOktaConfig(t *testing.T) { SyncAccessListsFlag: "yes", }, }, - errAssertionFunc: func(tt require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(tt require.TestingT, err error, i ...any) { require.ErrorIs(t, err, trace.BadParameter("default owners must be set when access list import is enabled")) }, }, @@ -4124,7 +4123,7 @@ func TestApplyOktaConfig(t *testing.T) { }, }, 
}, - errAssertionFunc: func(t require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "error parsing group filter: ^admin-.[[[*$") }, }, @@ -4144,7 +4143,7 @@ func TestApplyOktaConfig(t *testing.T) { }, }, }, - errAssertionFunc: func(t require.TestingT, err error, i ...interface{}) { + errAssertionFunc: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "error parsing app filter: ^admin-.[[[*$") }, }, diff --git a/lib/config/database_test.go b/lib/config/database_test.go index fe5902ed99914..13f30012eb0b3 100644 --- a/lib/config/database_test.go +++ b/lib/config/database_test.go @@ -355,7 +355,6 @@ func TestMakeDatabaseConfig(t *testing.T) { } for name, tt := range tests { - tt := tt t.Run(name, func(t *testing.T) { t.Parallel() configString, err := MakeDatabaseAgentConfigString(tt.flags) diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 42bb529c57d35..da1997b83bb47 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -552,7 +552,7 @@ type LogFormat struct { ExtraFields []string `yaml:"extra_fields,omitempty"` } -func (l *Log) UnmarshalYAML(unmarshal func(interface{}) error) error { +func (l *Log) UnmarshalYAML(unmarshal func(any) error) error { // the next two lines are needed because of an infinite loop issue // https://github.com/go-yaml/yaml/issues/107 type logYAML Log @@ -1516,11 +1516,7 @@ func (ssh *SSH) X11ServerConfig() (*x11.ServerConfig, error) { cfg.DisplayOffset = x11.DefaultDisplayOffset if ssh.X11.DisplayOffset != nil { - cfg.DisplayOffset = int(*ssh.X11.DisplayOffset) - - if cfg.DisplayOffset > x11.MaxDisplayNumber { - cfg.DisplayOffset = x11.MaxDisplayNumber - } + cfg.DisplayOffset = min(int(*ssh.X11.DisplayOffset), x11.MaxDisplayNumber) } cfg.MaxDisplay = cfg.DisplayOffset + x11.DefaultMaxDisplays diff --git a/lib/config/fileconf_test.go b/lib/config/fileconf_test.go index 
b6cc274be0aa4..e64003a624002 100644 --- a/lib/config/fileconf_test.go +++ b/lib/config/fileconf_test.go @@ -60,7 +60,7 @@ discovery_service: // cfgMap is a shorthand for a type that can hold the nested key-value // representation of a parsed YAML file. -type cfgMap map[interface{}]interface{} +type cfgMap map[any]any // editConfig takes the minimal YAML configuration file, de-serializes it into a // nested key-value dictionary suitable for manipulation by a test case, @@ -79,8 +79,8 @@ func editConfig(t *testing.T, mutate func(cfg cfgMap)) []byte { // requireEqual creates an assertion function with a bound `expected` value // for use with table-driven tests -func requireEqual(expected interface{}) require.ValueAssertionFunc { - return func(t require.TestingT, actual interface{}, msgAndArgs ...interface{}) { +func requireEqual(expected any) require.ValueAssertionFunc { + return func(t require.TestingT, actual any, msgAndArgs ...any) { require.Equal(t, expected, actual, msgAndArgs...) } } @@ -236,8 +236,8 @@ func TestAuthenticationSection(t *testing.T) { "second_factor": "u2f", "u2f": cfgMap{ "app_id": "https://graviton:3080", - "facets": []interface{}{"https://graviton:3080"}, - "device_attestation_cas": []interface{}{ + "facets": []any{"https://graviton:3080"}, + "device_attestation_cas": []any{ "testdata/u2f_attestation_ca.pam", "-----BEGIN CERTIFICATE-----\nfake certificate\n-----END CERTIFICATE-----", }, @@ -266,11 +266,11 @@ func TestAuthenticationSection(t *testing.T) { "second_factor": "webauthn", "webauthn": cfgMap{ "rp_id": "example.com", - "attestation_allowed_cas": []interface{}{ + "attestation_allowed_cas": []any{ "testdata/u2f_attestation_ca.pam", "-----BEGIN CERTIFICATE-----\nfake certificate1\n-----END CERTIFICATE-----", }, - "attestation_denied_cas": []interface{}{ + "attestation_denied_cas": []any{ "-----BEGIN CERTIFICATE-----\nfake certificate2\n-----END CERTIFICATE-----", "testdata/u2f_attestation_ca.pam", }, @@ -300,7 +300,7 @@ func 
TestAuthenticationSection(t *testing.T) { "second_factor": "on", "u2f": cfgMap{ "app_id": "https://example.com", - "facets": []interface{}{ + "facets": []any{ "https://example.com", }, }, @@ -457,7 +457,7 @@ func TestAuthenticationSection(t *testing.T) { "signature_algorithm_suite": "balanced-v0", } }, - expectError: func(t require.TestingT, err error, msgAndArgs ...interface{}) { + expectError: func(t require.TestingT, err error, msgAndArgs ...any) { require.ErrorContains(t, err, "invalid value: balanced-v0") }, }, { @@ -1020,7 +1020,7 @@ func TestHardwareKeyConfig(t *testing.T) { }, } }, - expectParseError: func(t require.TestingT, err error, i ...interface{}) { + expectParseError: func(t require.TestingT, err error, i ...any) { require.Error(t, err) require.True(t, trace.IsBadParameter(err), "got err = %v, want BadParameter", err) }, @@ -1197,7 +1197,7 @@ func TestX11Config(t *testing.T) { "max_display": 100, } }, - expectConfigError: func(t require.TestingT, err error, i ...interface{}) { + expectConfigError: func(t require.TestingT, err error, i ...any) { require.Error(t, err) require.True(t, trace.IsBadParameter(err), "got err = %v, want BadParameter", err) }, diff --git a/lib/events/athena/consumer.go b/lib/events/athena/consumer.go index 4c51ec53f0fa8..85cd6b6cf1ade 100644 --- a/lib/events/athena/consumer.go +++ b/lib/events/athena/consumer.go @@ -490,7 +490,7 @@ func (s *sqsMessagesCollector) fromSQS(ctx context.Context) { ) wg.Add(s.cfg.noOfWorkers) - for i := 0; i < s.cfg.noOfWorkers; i++ { + for i := range s.cfg.noOfWorkers { go func(i int) { defer wg.Done() for { @@ -904,7 +904,7 @@ func (c *consumer) deleteMessagesFromQueue(ctx context.Context, handles []string var wg sync.WaitGroup // Start the worker goroutines - for i := 0; i < noOfWorkers; i++ { + for range noOfWorkers { wg.Add(1) go func() { defer wg.Done() @@ -935,10 +935,7 @@ func (c *consumer) deleteMessagesFromQueue(ctx context.Context, handles []string // Batch the receipt handles and send 
them to the worker pool. for i := 0; i < len(handles); i += maxDeleteBatchSize { - end := i + maxDeleteBatchSize - if end > len(handles) { - end = len(handles) - } + end := min(i+maxDeleteBatchSize, len(handles)) workerCh <- handles[i:end] } close(workerCh) diff --git a/lib/events/athena/consumer_test.go b/lib/events/athena/consumer_test.go index b51b87ae7d6bc..59103ae87b7aa 100644 --- a/lib/events/athena/consumer_test.go +++ b/lib/events/athena/consumer_test.go @@ -91,8 +91,7 @@ func Test_consumer_sqsMessagesCollector(t *testing.T) { c := newSqsMessagesCollector(cfg) eventsChan := c.getEventsChan() - readSQSCtx, readCancel := context.WithCancel(context.Background()) - defer readCancel() + readSQSCtx := t.Context() go c.fromSQS(readSQSCtx) // receiver is used to read messages from eventsChan. @@ -164,8 +163,7 @@ func Test_consumer_sqsMessagesCollector(t *testing.T) { eventsChan := c.getEventsChan() - readSQSCtx, readCancel := context.WithCancel(context.Background()) - defer readCancel() + readSQSCtx := t.Context() go c.fromSQS(readSQSCtx) @@ -210,8 +208,7 @@ func Test_consumer_sqsMessagesCollector(t *testing.T) { eventsChan := c.getEventsChan() - readSQSCtx, readCancel := context.WithCancel(context.Background()) - defer readCancel() + readSQSCtx := t.Context() go c.fromSQS(readSQSCtx) @@ -221,7 +218,7 @@ func Test_consumer_sqsMessagesCollector(t *testing.T) { // When over 100 unique days are sent eventsToSend := make([]apievents.AuditEvent, 0, 101) - for i := 0; i < 101; i++ { + for i := range 101 { day := fclock.Now().Add(time.Duration(i) * 24 * time.Hour) eventsToSend = append(eventsToSend, &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent, Time: day}, AppMetadata: apievents.AppMetadata{AppName: "app1"}}) } @@ -666,7 +663,7 @@ func TestDeleteMessagesFromQueue(t *testing.T) { handlesGen := func(n int) []string { out := make([]string, 0, n) - for i := 0; i < n; i++ { + for i := range n { out = append(out, fmt.Sprintf("handle-%d", i)) } 
return out diff --git a/lib/events/athena/querier_test.go b/lib/events/athena/querier_test.go index 270d6c36110f1..578f05e63328b 100644 --- a/lib/events/athena/querier_test.go +++ b/lib/events/athena/querier_test.go @@ -90,7 +90,7 @@ func TestSearchEvents(t *testing.T) { sliceOfDummyEvents := func(noOfEvents int) []apievents.AuditEvent { out := make([]apievents.AuditEvent, 0, noOfEvents) - for i := 0; i < noOfEvents; i++ { + for range noOfEvents { out = append(out, &apievents.AppCreate{ Metadata: apievents.Metadata{ ID: uuid.NewString(), diff --git a/lib/events/auditlog_test.go b/lib/events/auditlog_test.go index 7b70a2df24bd4..95cc08da4cb03 100644 --- a/lib/events/auditlog_test.go +++ b/lib/events/auditlog_test.go @@ -138,7 +138,7 @@ func TestConcurrentStreaming(t *testing.T) { // on the download that the first one started streams := 2 errors := make(chan error, streams) - for i := 0; i < streams; i++ { + for range streams { go func() { eventsC, errC := alog.StreamSessionEvents(ctx, sid, 0) for { @@ -157,7 +157,7 @@ func TestConcurrentStreaming(t *testing.T) { // This test just verifies that the streamer does not panic when multiple // concurrent streams are waiting on the same download to complete. 
- for i := 0; i < streams; i++ { + for range streams { <-errors } } diff --git a/lib/events/azsessions/azsessions.go b/lib/events/azsessions/azsessions.go index 1e8e3cffaae96..2c6ad40250f93 100644 --- a/lib/events/azsessions/azsessions.go +++ b/lib/events/azsessions/azsessions.go @@ -411,10 +411,7 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload return trace.Wrap(err) } - m := batchSize - if len(parts[i:]) < batchSize { - m = len(parts[i:]) - } + m := min(len(parts[i:]), batchSize) for _, part := range parts[i : i+m] { if err := batch.Delete(partName(upload, part.Number), nil); err != nil { diff --git a/lib/events/dynamoevents/dynamoevents.go b/lib/events/dynamoevents/dynamoevents.go index ee2297558f21a..317e5016349e9 100644 --- a/lib/events/dynamoevents/dynamoevents.go +++ b/lib/events/dynamoevents/dynamoevents.go @@ -471,7 +471,7 @@ func (l *Log) configureTable(ctx context.Context, svc *applicationautoscaling.Cl // Define scaling policy. Defines the ratio of {read,write} consumed capacity to // provisioned capacity DynamoDB will try and maintain. - for i := 0; i < 2; i++ { + for i := range 2 { if _, err := svc.PutScalingPolicy(ctx, &applicationautoscaling.PutScalingPolicyInput{ PolicyName: aws.String(p.readPolicy), PolicyType: autoscalingtypes.PolicyTypeTargetTrackingScaling, @@ -807,7 +807,7 @@ func (l *Log) GetEventExportChunks(ctx context.Context, req *auditlogpb.GetEvent // The reason that this doesn't fill in the values as literals within the list is to prevent injection attacks. 
func eventFilterList(amount int) string { var eventTypes []string - for i := 0; i < amount; i++ { + for i := range amount { eventTypes = append(eventTypes, fmt.Sprintf(":eventType%d", i)) } return "(" + strings.Join(eventTypes, ", ") + ")" @@ -1059,7 +1059,7 @@ func getSubPageCheckpoint(e *event) (string, error) { func (l *Log) SearchSessionEvents(ctx context.Context, req events.SearchSessionEventsRequest) ([]apievents.AuditEvent, string, error) { filter := searchEventsFilter{eventTypes: events.SessionRecordingEvents} if req.Cond != nil { - params := condFilterParams{attrValues: make(map[string]interface{}), attrNames: make(map[string]string)} + params := condFilterParams{attrValues: make(map[string]any), attrNames: make(map[string]string)} expr, err := fromWhereExpr(req.Cond, ¶ms) if err != nil { return nil, "", trace.Wrap(err) @@ -1085,7 +1085,7 @@ type searchEventsFilter struct { } type condFilterParams struct { - attrValues map[string]interface{} + attrValues map[string]any attrNames map[string]string } @@ -1115,7 +1115,7 @@ func fromWhereExpr(cond *types.WhereExpr, params *condFilterParams) (string, err return fmt.Sprintf("NOT (%s)", inner), nil } - addAttrValue := func(in interface{}) string { + addAttrValue := func(in any) string { for k, v := range params.attrValues { if in == v { return k @@ -1280,10 +1280,7 @@ func (l *Log) deleteAllItems(ctx context.Context) error { } for len(requests) > 0 { - top := 25 - if top > len(requests) { - top = len(requests) - } + top := min(25, len(requests)) chunk := requests[:top] requests = requests[top:] @@ -1471,7 +1468,7 @@ dateLoop: for i, date := range l.dates { l.checkpoint.Date = date - attributes := map[string]interface{}{ + attributes := map[string]any{ ":date": date, ":start": l.fromUTC.Unix(), ":end": l.toUTC.Unix(), @@ -1556,7 +1553,7 @@ dateLoop: func (l *eventsFetcher) QueryBySessionIDIndex(ctx context.Context, sessionID string, filterExpr *string) (values []event, err error) { query := "SessionID = :id" - 
attributes := map[string]interface{}{ + attributes := map[string]any{ ":id": sessionID, } for i, eventType := range l.filter.eventTypes { diff --git a/lib/events/dynamoevents/dynamoevents_test.go b/lib/events/dynamoevents/dynamoevents_test.go index a176cff26a354..783c12e694797 100644 --- a/lib/events/dynamoevents/dynamoevents_test.go +++ b/lib/events/dynamoevents/dynamoevents_test.go @@ -159,7 +159,7 @@ func TestSizeBreak(t *testing.T) { blob := randStringAlpha(eventSize) const eventCount int = 10 - for i := 0; i < eventCount; i++ { + for i := range eventCount { err := tt.suite.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ Method: events.LoginMethodSAML, Status: apievents.Status{Success: true}, @@ -168,7 +168,7 @@ func TestSizeBreak(t *testing.T) { Type: events.UserLoginEvent, Time: tt.suite.Clock.Now().UTC().Add(time.Second * time.Duration(i)), }, - IdentityAttributes: apievents.MustEncodeMap(map[string]interface{}{"test.data": blob}), + IdentityAttributes: apievents.MustEncodeMap(map[string]any{"test.data": blob}), }) require.NoError(t, err) } @@ -233,7 +233,7 @@ func TestLargeTableRetrieve(t *testing.T) { tt := setupDynamoContext(t) const eventCount = 4000 - for i := 0; i < eventCount; i++ { + for range eventCount { err := tt.suite.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ Method: events.LoginMethodSAML, Status: apievents.Status{Success: true}, @@ -251,7 +251,7 @@ func TestLargeTableRetrieve(t *testing.T) { err error ) ctx := context.Background() - for i := 0; i < dynamoDBLargeQueryRetries; i++ { + for range dynamoDBLargeQueryRetries { time.Sleep(tt.suite.QueryDelay) history, _, err = tt.suite.Log.SearchEvents(ctx, events.SearchEventsRequest{ @@ -282,14 +282,14 @@ func TestFromWhereExpr(t *testing.T) { R: &types.WhereExpr{Contains: types.WhereExpr2{L: &types.WhereExpr{Field: "participants"}, R: &types.WhereExpr{Literal: "test-user"}}}, }} - params := condFilterParams{attrNames: map[string]string{}, attrValues: 
map[string]interface{}{}} + params := condFilterParams{attrNames: map[string]string{}, attrValues: map[string]any{}} expr, err := fromWhereExpr(cond, ¶ms) require.NoError(t, err) require.Equal(t, "(NOT ((FieldsMap.#condName0 = :condValue0) OR (FieldsMap.#condName0 = :condValue1))) AND (contains(FieldsMap.#condName1, :condValue2))", expr) require.Equal(t, condFilterParams{ attrNames: map[string]string{"#condName0": "login", "#condName1": "participants"}, - attrValues: map[string]interface{}{":condValue0": "root", ":condValue1": "admin", ":condValue2": "test-user"}, + attrValues: map[string]any{":condValue0": "root", ":condValue1": "admin", ":condValue2": "test-user"}, }, params) } @@ -511,8 +511,8 @@ func TestSearchEventsLimitEndOfDay(t *testing.T) { const eventCount int = 10 // create events for two days - for dayDiff := 0; dayDiff < 2; dayDiff++ { - for i := 0; i < eventCount; i++ { + for dayDiff := range 2 { + for i := range eventCount { err := tt.suite.Log.EmitAuditEvent(ctx, &apievents.UserLogin{ Method: events.LoginMethodSAML, Status: apievents.Status{Success: true}, @@ -521,7 +521,7 @@ func TestSearchEventsLimitEndOfDay(t *testing.T) { Type: events.UserLoginEvent, Time: tt.suite.Clock.Now().UTC().Add(time.Hour*24*time.Duration(dayDiff) + time.Second*time.Duration(i)), }, - IdentityAttributes: apievents.MustEncodeMap(map[string]interface{}{"test.data": blob}), + IdentityAttributes: apievents.MustEncodeMap(map[string]any{"test.data": blob}), }) require.NoError(t, err) } diff --git a/lib/events/dynamoevents/legacy_test.go b/lib/events/dynamoevents/legacy_test.go index 0fd63182c04f1..4f109be40ec1b 100644 --- a/lib/events/dynamoevents/legacy_test.go +++ b/lib/events/dynamoevents/legacy_test.go @@ -33,7 +33,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "binary field": { attributeJSON: `{ "B": "dGVzdAo=", "BOOL": null, "BS": null, "L": null, "M": null, "N": null, "NS": null, "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - 
expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberB{}, i1) attr, _ := i1.(*types.AttributeValueMemberB) // Parsed binaries include line feed character. @@ -43,7 +43,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "bool field": { attributeJSON: `{ "B": null, "BOOL": true, "BS": null, "L": null, "M": null, "N": null, "NS": null, "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberBOOL{}, i1) attr, _ := i1.(*types.AttributeValueMemberBOOL) require.True(t, attr.Value) @@ -52,7 +52,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "binary set field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": ["aGVsbG8K", "d29ybGQK"], "L": null, "M": null, "N": null, "NS": null, "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberBS{}, i1) attr, _ := i1.(*types.AttributeValueMemberBS) // Parsed binaries include line feed character. 
@@ -62,7 +62,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "list field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": [{"S": "hello"}, {"S": "world"}], "M": null, "N": null, "NS": null, "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberL{}, i1) attr, _ := i1.(*types.AttributeValueMemberL) require.Len(t, attr.Value, 2) @@ -73,7 +73,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "map field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": null, "M": {"name": { "B": null, "BOOL": null, "BS": null, "L": null, "M": null, "N": null, "NS": null, "NULL": null, "S": "test", "SS": null }}, "N": null, "NS": null, "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberM{}, i1) attr, _ := i1.(*types.AttributeValueMemberM) require.Len(t, attr.Value, 1) @@ -88,7 +88,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "number field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": null, "M": null, "N": "123.4", "NS": null, "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberN{}, i1) attr, _ := i1.(*types.AttributeValueMemberN) require.Equal(t, "123.4", attr.Value) @@ -97,7 +97,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "number set field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": null, "M": null, "N": null, "NS": ["123", 
"4.5"], "NULL": null, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberNS{}, i1) attr, _ := i1.(*types.AttributeValueMemberNS) require.ElementsMatch(t, []string{"123", "4.5"}, attr.Value) @@ -106,7 +106,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "null field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": null, "M": null, "N": null, "NS": null, "NULL": true, "S": null, "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberNULL{}, i1) attr, _ := i1.(*types.AttributeValueMemberNULL) require.True(t, attr.Value) @@ -115,7 +115,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "string field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": null, "M": null, "N": null, "NS": null, "NULL": null, "S": "test", "SS": null }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberS{}, i1) attr, _ := i1.(*types.AttributeValueMemberS) require.Equal(t, "test", attr.Value) @@ -124,7 +124,7 @@ func TestParseLegacyDynamoAttributes(t *testing.T) { "string set field": { attributeJSON: `{ "B": null, "BOOL": null, "BS": null, "L": null, "M": null, "N": null, "NS": null, "NULL": null, "S": null, "SS": ["hello", "world"] }`, expectConvertError: require.NoError, - expectedAttribute: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + expectedAttribute: func(tt require.TestingT, i1 any, i2 ...any) { require.IsType(t, &types.AttributeValueMemberSS{}, 
i1) attr, _ := i1.(*types.AttributeValueMemberSS) require.ElementsMatch(t, []string{"hello", "world"}, attr.Value) diff --git a/lib/events/emitter_test.go b/lib/events/emitter_test.go index c7b4cc2076d77..f34b07a5dc7c5 100644 --- a/lib/events/emitter_test.go +++ b/lib/events/emitter_test.go @@ -73,8 +73,7 @@ func TestProtoStreamer(t *testing.T) { }, } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() for i, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -175,7 +174,7 @@ func TestAsyncEmitter(t *testing.T) { require.NoError(t, err) } - for i := 0; i < len(evts); i++ { + for i := range evts { select { case event := <-chanEmitter.C(): require.Equal(t, evts[i], event) diff --git a/lib/events/events_test.go b/lib/events/events_test.go index af74ffa593372..031fdae0907e8 100644 --- a/lib/events/events_test.go +++ b/lib/events/events_test.go @@ -273,7 +273,7 @@ func TestJSON(t *testing.T) { type testCase struct { name string json string - event interface{} + event any } testCases := []testCase{ { @@ -506,20 +506,20 @@ func TestJSON(t *testing.T) { UserMetadata: apievents.UserMetadata{ User: "bob@example.com", }, - IdentityAttributes: apievents.MustEncodeMap(map[string]interface{}{ + IdentityAttributes: apievents.MustEncodeMap(map[string]any{ "followers_url": "https://api.github.com/users/bob/followers", "err": nil, "public_repos": 20, "site_admin": false, - "app_metadata": map[string]interface{}{"roles": []interface{}{"example/admins", "example/devc"}}, - "emails": []interface{}{ - map[string]interface{}{ + "app_metadata": map[string]any{"roles": []any{"example/admins", "example/devc"}}, + "emails": []any{ + map[string]any{ "email": "bob@example.com", "primary": true, "verified": true, "visibility": "public", }, - map[string]interface{}{ + map[string]any{ "email": "bob@alternative.com", "primary": false, "verified": true, @@ -1003,7 +1003,6 @@ func TestJSON(t *testing.T) { }, } for _, tc := range testCases { - 
tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -1084,7 +1083,7 @@ func setProtoFields(msg proto.Message) { m := msg.ProtoReflect() fields := m.Descriptor().Fields() - for i := 0; i < fields.Len(); i++ { + for i := range fields.Len() { fd := fields.Get(i) if m.Has(fd) { continue diff --git a/lib/events/eventstest/generate.go b/lib/events/eventstest/generate.go index 25d2e2070d9ef..48d43bcf20bc6 100644 --- a/lib/events/eventstest/generate.go +++ b/lib/events/eventstest/generate.go @@ -125,8 +125,7 @@ func GenerateTestSession(params SessionParams) []apievents.AuditEvent { } genEvents := []apievents.AuditEvent{&sessionStart} - i := int64(0) - for i = 0; i < params.PrintEvents; i++ { + for i := range params.PrintEvents { event := &apievents.SessionPrint{ Metadata: apievents.Metadata{ Index: i + 1, @@ -144,8 +143,7 @@ func GenerateTestSession(params SessionParams) []apievents.AuditEvent { genEvents = append(genEvents, event) } - i++ - sessionEnd.Metadata.Index = i + sessionEnd.Metadata.Index = int64(len(genEvents)) genEvents = append(genEvents, &sessionEnd) return genEvents diff --git a/lib/events/eventstest/uploader.go b/lib/events/eventstest/uploader.go index 99cee3b1e3f94..21daa3f04b9ad 100644 --- a/lib/events/eventstest/uploader.go +++ b/lib/events/eventstest/uploader.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "io" + "slices" "sort" "sync" "time" @@ -210,9 +211,7 @@ func (m *MemoryUploader) GetParts(uploadID string) ([][]byte, error) { for partNumber := range up.parts { partNumbers = append(partNumbers, partNumber) } - sort.Slice(partNumbers, func(i, j int) bool { - return partNumbers[i] < partNumbers[j] - }) + slices.Sort(partNumbers) for _, partNumber := range partNumbers { sortedParts = append(sortedParts, up.parts[partNumber].data) } @@ -242,9 +241,7 @@ func (m *MemoryUploader) ListParts(ctx context.Context, upload events.StreamUplo for partNumber := range up.parts { partNumbers = append(partNumbers, partNumber) } - sort.Slice(partNumbers, 
func(i, j int) bool { - return partNumbers[i] < partNumbers[j] - }) + slices.Sort(partNumbers) for _, partNumber := range partNumbers { sortedParts = append(sortedParts, events.StreamPart{Number: partNumber}) } diff --git a/lib/events/export/date_exporter.go b/lib/events/export/date_exporter.go index 2ce6d1413356c..fe288792d9ec0 100644 --- a/lib/events/export/date_exporter.go +++ b/lib/events/export/date_exporter.go @@ -22,6 +22,7 @@ import ( "cmp" "context" "log/slog" + "maps" "sync" "sync/atomic" "time" @@ -138,9 +139,7 @@ func (s *DateExporterState) Clone() DateExporterState { Cursors: make(map[string]string, len(s.Cursors)), } copy(cloned.Completed, s.Completed) - for chunk, cursor := range s.Cursors { - cloned.Cursors[chunk] = cursor - } + maps.Copy(cloned.Cursors, s.Cursors) return cloned } @@ -343,12 +342,12 @@ func (e *DateExporter) run(ctx context.Context) { // to halt. func (e *DateExporter) waitForInflightChunks() { // acquire all semaphore tokens to block until all inflight chunks have been processed - for i := 0; i < e.cfg.Concurrency; i++ { + for range e.cfg.Concurrency { e.sem <- struct{}{} } // release all semaphore tokens - for i := 0; i < e.cfg.Concurrency; i++ { + for range e.cfg.Concurrency { <-e.sem } } diff --git a/lib/events/export/date_exporter_test.go b/lib/events/export/date_exporter_test.go index 44f405ab6d173..738d14307aace 100644 --- a/lib/events/export/date_exporter_test.go +++ b/lib/events/export/date_exporter_test.go @@ -228,7 +228,6 @@ func testDateExporterBasics(t *testing.T, randomFlake bool, batch bool) { // TestDateExporterResume verifies non-trivial exporter resumption behavior, with and without // random flake. 
- func TestDateExporterResume(t *testing.T) { t.Parallel() for _, randomFlake := range []bool{false, true} { @@ -332,8 +331,6 @@ func testDateExporterResume(t *testing.T, randomFlake bool) { // get the final state of the exporter state := exporter.GetState() - fmt.Printf("cursors=%+v\n", state.Cursors) - // recreate exporter with state from previous run exporter, err = NewDateExporter(DateExporterConfig{ Client: clt, diff --git a/lib/events/export/exporter_test.go b/lib/events/export/exporter_test.go index 715626609159f..fdf4238af6787 100644 --- a/lib/events/export/exporter_test.go +++ b/lib/events/export/exporter_test.go @@ -21,6 +21,7 @@ package export import ( "context" "fmt" + "slices" "sync" "testing" "time" @@ -133,7 +134,7 @@ func testExportAll(t *testing.T, tc exportTestCase) { getExported := func() []*auditlogpb.ExportEventUnstructured { exportedMu.Lock() defer exportedMu.Unlock() - return append([]*auditlogpb.ExportEventUnstructured(nil), exported...) + return slices.Clone(exported) } var idleOnce sync.Once diff --git a/lib/events/filelog.go b/lib/events/filelog.go index 508bfa5d1ceeb..25bb4b259fca6 100644 --- a/lib/events/filelog.go +++ b/lib/events/filelog.go @@ -28,6 +28,7 @@ import ( "log/slog" "os" "path/filepath" + "slices" "sort" "strings" "sync" @@ -589,13 +590,7 @@ func (l *FileLog) findInFile(path string, filter searchEventsFilter) ([]EventFie l.logger.WarnContext(context.Background(), "invalid JSON in line found", "file", path, "line_number", lineNo) continue } - accepted := len(filter.eventTypes) == 0 - for _, eventType := range filter.eventTypes { - if ef.GetString(EventType) == eventType { - accepted = true - break - } - } + accepted := len(filter.eventTypes) == 0 || slices.Contains(filter.eventTypes, ef.GetString(EventType)) if !accepted { continue } @@ -657,7 +652,7 @@ func (f ByTimeAndIndex) Swap(i, j int) { } // getTime converts json time to string -func getTime(v interface{}) time.Time { +func getTime(v any) time.Time { sval, ok := 
v.(string) if !ok { return time.Time{} @@ -669,7 +664,7 @@ func getTime(v interface{}) time.Time { return t } -func getEventIndex(v interface{}) float64 { +func getEventIndex(v any) float64 { switch val := v.(type) { case float64: return val diff --git a/lib/events/filelog_test.go b/lib/events/filelog_test.go index 7ad6a0b7e2195..bf594d2625f35 100644 --- a/lib/events/filelog_test.go +++ b/lib/events/filelog_test.go @@ -273,7 +273,7 @@ func TestLargeEvent(t *testing.T) { // are many very small string fields that will require quoting. func makeLargeMongoQuery() (string, error) { record := map[string]string{"_id": `{"$oid":"63a0dd6da68baaeb828581fe"}`} - for i := 0; i < 100; i++ { + for i := range 100 { t := fmt.Sprintf("%v", i) record[t] = t } diff --git a/lib/events/filesessions/fileasync_chaos_test.go b/lib/events/filesessions/fileasync_chaos_test.go index 26e7a2b522911..e493c9d4caee0 100644 --- a/lib/events/filesessions/fileasync_chaos_test.go +++ b/lib/events/filesessions/fileasync_chaos_test.go @@ -50,8 +50,7 @@ func TestChaosUpload(t *testing.T) { t.Skip("Skipping chaos test in short mode.") } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() eventsC := make(chan events.UploadEvent, 100) memUploader := eventstest.NewMemoryUploader(eventsC) @@ -132,7 +131,7 @@ func TestChaosUpload(t *testing.T) { err error } streamsCh := make(chan streamState, parallelStreams) - for i := 0; i < parallelStreams; i++ { + for range parallelStreams { go func() { inEvents := eventstest.GenerateTestSession(eventstest.SessionParams{PrintEvents: 4096}) sid := inEvents[0].(events.SessionMetadataGetter).GetSessionID() @@ -162,7 +161,7 @@ func TestChaosUpload(t *testing.T) { // wait for all streams to be completed streams := make(map[string]streamState) - for i := 0; i < parallelStreams; i++ { + for range parallelStreams { select { case status := <-streamsCh: require.NoError(t, status.err) @@ -174,7 +173,7 @@ func TestChaosUpload(t *testing.T) 
{ require.Len(t, streams, parallelStreams) - for i := 0; i < parallelStreams; i++ { + for range parallelStreams { select { case event := <-eventsC: require.NoError(t, event.Error) diff --git a/lib/events/filesessions/fileasync_test.go b/lib/events/filesessions/fileasync_test.go index a015bde8ec15c..15811091accdd 100644 --- a/lib/events/filesessions/fileasync_test.go +++ b/lib/events/filesessions/fileasync_test.go @@ -84,7 +84,7 @@ func TestUploadParallel(t *testing.T) { sessions := make(map[string][]apievents.AuditEvent) - for i := 0; i < 5; i++ { + for range 5 { fileStreamer, err := NewStreamer(p.scanDir) require.NoError(t, err) @@ -406,7 +406,7 @@ func TestUploadBackoff(t *testing.T) { attempts := 10 var prev time.Time var diffs []time.Duration - for i := 0; i < attempts; i++ { + for i := range attempts { // wait for the upload event var event events.UploadEvent select { @@ -611,7 +611,7 @@ func runResume(t *testing.T, testCase resumeTestCase) { t.Fatalf("Timeout waiting for async upload, try `go test -v` to get more logs for details") } - for i := 0; i < testCase.retries; i++ { + for i := range testCase.retries { if testCase.onRetry != nil { testCase.onRetry(t, i, uploader) } diff --git a/lib/events/filesessions/filestream.go b/lib/events/filesessions/filestream.go index d58d60e054d26..2cb5ac941921d 100644 --- a/lib/events/filesessions/filestream.go +++ b/lib/events/filesessions/filestream.go @@ -151,7 +151,7 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload } unlock, err := utils.FSTryWriteLock(uploadPath) Loop: - for i := 0; i < 3; i++ { + for range 3 { switch { case err == nil: break Loop diff --git a/lib/events/firestoreevents/firestoreevents.go b/lib/events/firestoreevents/firestoreevents.go index 6c1e5c6eb72dd..6ee9a1ab5ba5d 100644 --- a/lib/events/firestoreevents/firestoreevents.go +++ b/lib/events/firestoreevents/firestoreevents.go @@ -498,7 +498,7 @@ func (l *Log) query( // Iterate over the documents in the query. 
// The iterator is limited to [limit] documents so in order to know if we // have more pages to read when filtering, we can read only [limit] documents. - for i := 0; i < limit; i++ { + for range limit { docSnap, err := fstoreIterator.Next() if errors.Is(err, iterator.Done) { // iterator.Done is returned when there are no more documents to read. diff --git a/lib/events/firestoreevents/firestoreevents_test.go b/lib/events/firestoreevents/firestoreevents_test.go index 6b34a5cdfc0f3..f5b720f87d7ff 100644 --- a/lib/events/firestoreevents/firestoreevents_test.go +++ b/lib/events/firestoreevents/firestoreevents_test.go @@ -53,7 +53,7 @@ func setupFirestoreContext(t *testing.T) *firestoreContext { fakeClock := clockwork.NewFakeClock() config := EventsConfig{} - config.SetFromParams(map[string]interface{}{ + config.SetFromParams(map[string]any{ "collection_name": "tp-events-test", "project_id": "tp-testproj", "endpoint": "localhost:8618", diff --git a/lib/events/gcssessions/gcsstream.go b/lib/events/gcssessions/gcsstream.go index 8e1d47bbbeb5f..b20c3a834a911 100644 --- a/lib/events/gcssessions/gcsstream.go +++ b/lib/events/gcssessions/gcsstream.go @@ -216,7 +216,7 @@ func (h *Handler) cleanupUpload(ctx context.Context, upload events.StreamUpload) func (h *Handler) partsToObjects(upload events.StreamUpload, parts []events.StreamPart) []*storage.ObjectHandle { objects := make([]*storage.ObjectHandle, len(parts)) bucket := h.gcsClient.Bucket(h.Config.Bucket) - for i := 0; i < len(parts); i++ { + for i := range parts { objects[i] = bucket.Object(h.partPath(upload, parts[i].Number)) } return objects diff --git a/lib/events/pgevents/utils_test.go b/lib/events/pgevents/utils_test.go index 532e75fb40a8d..b2938ea37becb 100644 --- a/lib/events/pgevents/utils_test.go +++ b/lib/events/pgevents/utils_test.go @@ -32,7 +32,7 @@ import ( func TestPaginationKeyRoundtrip(t *testing.T) { t.Parallel() - for i := 0; i < 1000; i++ { + for range 1000 { var b [24]byte _, err := rand.Read(b[:]) 
require.NoError(t, err) @@ -45,7 +45,7 @@ func TestPaginationKeyRoundtrip(t *testing.T) { require.Equal(t, startKey, toNextKey(eventTime, eventID)) } - for i := 0; i < 1000; i++ { + for range 1000 { var b [8]byte _, err := rand.Read(b[:]) require.NoError(t, err) diff --git a/lib/events/search_limiter_test.go b/lib/events/search_limiter_test.go index d6eede465071a..bbd393954c5c0 100644 --- a/lib/events/search_limiter_test.go +++ b/lib/events/search_limiter_test.go @@ -43,7 +43,7 @@ func TestSearchEventsLimiter(t *testing.T) { }, }) require.NoError(t, err) - for i := 0; i < 20; i++ { + for range 20 { require.NoError(t, s.EmitAuditEvent(context.Background(), &apievents.AccessRequestCreate{})) } }) @@ -63,7 +63,7 @@ func TestSearchEventsLimiter(t *testing.T) { someDate := clockwork.NewFakeClock().Now().UTC() ctx := context.Background() - for i := 0; i < burst; i++ { + for i := range burst { var err error // rate limit is shared between both search endpoints. if i%2 == 0 { diff --git a/lib/events/session_writer.go b/lib/events/session_writer.go index 15dc4fd72adbc..479962dee41dd 100644 --- a/lib/events/session_writer.go +++ b/lib/events/session_writer.go @@ -547,7 +547,7 @@ func (a *SessionWriter) tryResumeStream() (apievents.Stream, error) { } var resumedStream apievents.Stream start := time.Now() - for i := 0; i < FastAttempts; i++ { + for i := range FastAttempts { var streamType string if a.lastStatus == nil { // The stream was either never created or has failed to receive the @@ -605,7 +605,7 @@ func (a *SessionWriter) updateStatus(status apievents.StreamStatus) { return } lastIndex := -1 - for i := 0; i < len(a.buffer); i++ { + for i := range a.buffer { if status.LastEventIndex < a.buffer[i].GetAuditEvent().GetIndex() { break } diff --git a/lib/events/sessionlog.go b/lib/events/sessionlog.go index e0b43a50a1671..c4b032fad3cb8 100644 --- a/lib/events/sessionlog.go +++ b/lib/events/sessionlog.go @@ -53,7 +53,7 @@ func (f *gzipWriter) Close() error { // so it makes 
sense to reset the writer and reuse the // internal buffers to avoid too many objects on the heap var writerPool = sync.Pool{ - New: func() interface{} { + New: func() any { w, _ := gzip.NewWriterLevel(io.Discard, gzip.BestSpeed) return w }, diff --git a/lib/events/setter_test.go b/lib/events/setter_test.go index 8ad8a396f927f..64801de4ed075 100644 --- a/lib/events/setter_test.go +++ b/lib/events/setter_test.go @@ -37,7 +37,7 @@ func TestPreparerIncrementalIndex(t *testing.T) { }) require.NoError(t, err) - for i := 0; i < 10; i++ { + for i := range 10 { e, err := preparer.PrepareSessionEvent(generateEvent()) require.NoError(t, err) require.Equal(t, int64(i), e.GetAuditEvent().GetIndex(), "unexpected event index") @@ -56,7 +56,7 @@ func TestPreparerTimeBasedIndex(t *testing.T) { require.NoError(t, err) var lastIndex int64 - for i := 0; i < 9; i++ { + for range 9 { clock.Advance(time.Second) e, err := preparer.PrepareSessionEvent(generateEvent()) require.NoError(t, err) @@ -90,7 +90,7 @@ func TestPreparerTimeBasedIndexCollisions(t *testing.T) { }) require.NoError(t, err) - for i := 0; i < 9; i++ { + for range 9 { clock.Advance(time.Second) evtOne, err := preparerOne.PrepareSessionEvent(generateEvent()) require.NoError(t, err) diff --git a/lib/events/stream.go b/lib/events/stream.go index 5792bc5ad9f37..33148b6f406fa 100644 --- a/lib/events/stream.go +++ b/lib/events/stream.go @@ -571,10 +571,7 @@ func (w *sliceWriter) receiveAndUpload() error { } case <-flushCh: now := clock.Now().UTC() - inactivityPeriod := now.Sub(lastEvent) - if inactivityPeriod < 0 { - inactivityPeriod = 0 - } + inactivityPeriod := max(now.Sub(lastEvent), 0) if inactivityPeriod >= w.proto.cfg.InactivityFlushPeriod { // inactivity period exceeded threshold, // there is no need to schedule a timer until the next @@ -769,7 +766,7 @@ func (w *sliceWriter) startUpload(partNumber int64, slice *slice) (*activeUpload return } - for i := 0; i < defaults.MaxIterationLimit; i++ { + for i := range 
defaults.MaxIterationLimit { log := log.With("attempt", i) part, err := w.proto.cfg.Uploader.UploadPart(w.proto.cancelCtx, w.proto.cfg.Upload, partNumber, reader) diff --git a/lib/events/stream_test.go b/lib/events/stream_test.go index a9d0ee1c94f60..85ad36ebc8508 100644 --- a/lib/events/stream_test.go +++ b/lib/events/stream_test.go @@ -206,8 +206,7 @@ func TestProtoStreamLargeEvent(t *testing.T) { // TestReadCorruptedRecording tests that the streamer can successfully decode the kind of corrupted // recordings that some older bugged versions of teleport might end up producing when under heavy load/throttling. func TestReadCorruptedRecording(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() f, err := os.Open("testdata/corrupted-session") require.NoError(t, err) diff --git a/lib/events/test/suite.go b/lib/events/test/suite.go index df228c94d44e3..ccd6b1fa9bf93 100644 --- a/lib/events/test/suite.go +++ b/lib/events/test/suite.go @@ -387,7 +387,7 @@ func (s *EventsSuite) EventPagination(t *testing.T) { } Outer: - for i := 0; i < len(names); i++ { + for range names { arr, checkpoint, err = s.Log.SearchEvents(ctx, events.SearchEventsRequest{ From: baseTime2, To: baseTime2.Add(time.Second), diff --git a/lib/kube/grpc/grpc_test.go b/lib/kube/grpc/grpc_test.go index 94a7bd38ac858..5f380f5aef9e7 100644 --- a/lib/kube/grpc/grpc_test.go +++ b/lib/kube/grpc/grpc_test.go @@ -593,7 +593,6 @@ func TestListKubernetesResources(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() _, restCfg := testCtx.GenTestKubeClientTLSCert(t, tt.args.user.GetName(), "") diff --git a/lib/kube/grpc/websocket_client_test.go b/lib/kube/grpc/websocket_client_test.go index 444de6a9eab85..1f9480a064e85 100644 --- a/lib/kube/grpc/websocket_client_test.go +++ b/lib/kube/grpc/websocket_client_test.go @@ -28,6 +28,7 @@ import ( "net" "net/http" "net/url" + "slices" "strings" "sync" 
"sync/atomic" @@ -188,13 +189,7 @@ func (e *wsStreamClient) Stream(options clientremotecommand.StreamOptions) error defer conn.Close() streamingProto := conn.Subprotocol() - found := false - for _, p := range supportedProtocols { - if p == streamingProto { - found = true - break - } - } + found := slices.Contains(supportedProtocols, streamingProto) if !found { return fmt.Errorf("unsupported streaming protocol: %q", streamingProto) } @@ -218,13 +213,7 @@ func (e *wsStreamClient) ForwardPorts() error { defer conn.Close() streamingProto := conn.Subprotocol() - found := false - for _, p := range supportedProtocols { - if p == streamingProto { - found = true - break - } - } + found := slices.Contains(supportedProtocols, streamingProto) if !found { return fmt.Errorf("unsupported streaming protocol: %q", streamingProto) } diff --git a/lib/kube/kubeconfig/kubeconfig.go b/lib/kube/kubeconfig/kubeconfig.go index c77f9ce39afff..9bbaa36982a0f 100644 --- a/lib/kube/kubeconfig/kubeconfig.go +++ b/lib/kube/kubeconfig/kubeconfig.go @@ -290,7 +290,7 @@ func setContext(contexts map[string]*clientcmdapi.Context, name, cluster, auth, if kubeName != "" { newContext.Extensions[teleportKubeClusterNameExtension] = &runtime.Unknown{ // We need to wrap the kubeName in quotes to make sure it is parsed as a string. 
- Raw: []byte(fmt.Sprintf("%q", kubeName)), + Raw: fmt.Appendf(nil, "%q", kubeName), } } diff --git a/lib/kube/kubeconfig/kubeconfig_test.go b/lib/kube/kubeconfig/kubeconfig_test.go index 7a100d496f57e..02267919adbd9 100644 --- a/lib/kube/kubeconfig/kubeconfig_test.go +++ b/lib/kube/kubeconfig/kubeconfig_test.go @@ -326,7 +326,7 @@ func TestUpdateWithExec(t *testing.T) { LocationOfOrigin: kubeconfigPath, Extensions: map[string]runtime.Object{ teleportKubeClusterNameExtension: &runtime.Unknown{ - Raw: []byte(fmt.Sprintf("%q", kubeCluster)), + Raw: fmt.Appendf(nil, "%q", kubeCluster), ContentType: "application/json", }, }, @@ -398,7 +398,7 @@ func TestUpdateWithExecAndProxy(t *testing.T) { LocationOfOrigin: kubeconfigPath, Extensions: map[string]runtime.Object{ teleportKubeClusterNameExtension: &runtime.Unknown{ - Raw: []byte(fmt.Sprintf("%q", kubeCluster)), + Raw: fmt.Appendf(nil, "%q", kubeCluster), ContentType: "application/json", }, }, diff --git a/lib/kube/proxy/auth_test.go b/lib/kube/proxy/auth_test.go index 0d7cc19a37bc4..37d8c0f7e460a 100644 --- a/lib/kube/proxy/auth_test.go +++ b/lib/kube/proxy/auth_test.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "path/filepath" + "slices" "testing" "github.com/google/go-cmp/cmp" @@ -97,17 +98,11 @@ func (c *mockSARClient) Create(_ context.Context, sar *authzapi.SelfSubjectAcces } var verbAllowed, resourceAllowed bool - for _, v := range c.allowedVerbs { - if v == sar.Spec.ResourceAttributes.Verb { - verbAllowed = true - break - } + if slices.Contains(c.allowedVerbs, sar.Spec.ResourceAttributes.Verb) { + verbAllowed = true } - for _, r := range c.allowedResources { - if r == sar.Spec.ResourceAttributes.Resource { - resourceAllowed = true - break - } + if slices.Contains(c.allowedResources, sar.Spec.ResourceAttributes.Resource) { + resourceAllowed = true } sar.Status.Allowed = verbAllowed && resourceAllowed diff --git a/lib/kube/proxy/exec_test.go b/lib/kube/proxy/exec_test.go index cc0ed45908870..eaf9bc37bbe59 100644 --- 
a/lib/kube/proxy/exec_test.go +++ b/lib/kube/proxy/exec_test.go @@ -349,7 +349,6 @@ func TestExecMissingGETPermissionError(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() const errorCode = http.StatusForbidden @@ -544,7 +543,6 @@ func TestExecWebsocketEndToEndErrReturn(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go index 8405c6e217d9b..3bcbb81e3e93b 100644 --- a/lib/kube/proxy/forwarder_test.go +++ b/lib/kube/proxy/forwarder_test.go @@ -930,7 +930,6 @@ func TestSetupImpersonationHeaders(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.desc, func(t *testing.T) { var kubeCreds kubeCreds if !tt.isProxy { @@ -1307,8 +1306,7 @@ func (m *mockSemaphoreClient) GetRole(ctx context.Context, name string) (types.R } func TestKubernetesConnectionLimit(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() type testCase struct { name string @@ -1463,7 +1461,7 @@ func TestKubernetesLicenseEnforcement(t *testing.T) { string(entitlements.K8s): {Enabled: false}, }, }, - assertErrFunc: func(tt require.TestingT, err error, i ...interface{}) { + assertErrFunc: func(tt require.TestingT, err error, i ...any) { require.Error(tt, err) var kubeErr *kubeerrors.StatusError require.ErrorAs(tt, err, &kubeErr) @@ -1474,7 +1472,6 @@ func TestKubernetesLicenseEnforcement(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() // creates a Kubernetes service with a configured cluster pointing to mock api server diff --git a/lib/kube/proxy/kube_creds_test.go b/lib/kube/proxy/kube_creds_test.go index ca2f537e6de05..f9ac76e63cef0 100644 --- a/lib/kube/proxy/kube_creds_test.go +++ b/lib/kube/proxy/kube_creds_test.go @@ -377,7 +377,7 @@ func Test_DynamicKubeCreds(t *testing.T) { 
case <-time.After(5 * time.Second): t.Fatalf("timeout waiting for cluster to be ready") } - for i := 0; i < 10; i++ { + for i := range 10 { require.Equal(t, got.getKubeRestConfig().CAData, []byte(fixtures.TLSCACertPEM)) require.NoError(t, tt.args.validateBearerToken(got.getKubeRestConfig().BearerToken)) require.Equal(t, tt.wantAddr, got.getTargetAddr()) diff --git a/lib/kube/proxy/moderated_sessions_test.go b/lib/kube/proxy/moderated_sessions_test.go index 1b2abc5704ea2..0b0e15b9ff6db 100644 --- a/lib/kube/proxy/moderated_sessions_test.go +++ b/lib/kube/proxy/moderated_sessions_test.go @@ -246,7 +246,6 @@ func TestModeratedSessions(t *testing.T) { }, } for _, tt := range tests { - tt := tt if tt.want.sessionEndEvent { numberOfExpectedSessionEnds++ } @@ -649,7 +648,6 @@ func TestInteractiveSessionsNoAuth(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() diff --git a/lib/kube/proxy/portforward_test.go b/lib/kube/proxy/portforward_test.go index 7dc7a147e22b7..b218aa7f21c7b 100644 --- a/lib/kube/proxy/portforward_test.go +++ b/lib/kube/proxy/portforward_test.go @@ -480,7 +480,6 @@ func TestPortForwardUnderlyingProtocol(t *testing.T) { }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() kubeMock, err := testingkubemock.NewKubeAPIMock( diff --git a/lib/kube/proxy/portforward_websocket.go b/lib/kube/proxy/portforward_websocket.go index 4ca01ea473062..beaec43c43647 100644 --- a/lib/kube/proxy/portforward_websocket.go +++ b/lib/kube/proxy/portforward_websocket.go @@ -73,7 +73,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error { // One pair of (Data,Error) channels per port. channels := make([]wsstream.ChannelType, 2*len(ports)) - for i := 0; i < len(channels); i++ { + for i := range channels { channels[i] = wsstream.ReadWriteChannel } @@ -107,7 +107,7 @@ func runPortForwardingWebSocket(req portForwardRequest) error { // Create the websocket stream pairs. 
streamPairs := make([]*websocketChannelPair, len(ports)) - for i := 0; i < len(ports); i++ { + for i := range ports { var ( dataStream = streams[2*i+portForwardDataChannel] errorStream = streams[2*i+portForwardErrorChannel] @@ -171,7 +171,7 @@ func extractTargetPortsFromStrings(portsStrings []string) ([]uint16, error) { if len(portString) == 0 { return nil, trace.BadParameter("query parameter %q cannot be empty", PortHeader) } - for _, p := range strings.Split(portString, ",") { + for p := range strings.SplitSeq(portString, ",") { port, err := parsePortString(p) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/kube/proxy/resource_deletecollection.go b/lib/kube/proxy/resource_deletecollection.go index 8bf2e15c57c07..3bb95180b2260 100644 --- a/lib/kube/proxy/resource_deletecollection.go +++ b/lib/kube/proxy/resource_deletecollection.go @@ -277,7 +277,7 @@ func getItemsUsingReflection(obj runtime.Object) (getItemsUsingReflectionOutput, func setItemsUsingReflection(itemsR reflect.Value, underlyingType reflect.Type, items []kubeObjectInterface) { // make a new slice of the same type as the original one. slice := reflect.MakeSlice(itemsR.Type(), len(items), len(items)) - for i := 0; i < len(items); i++ { + for i := range items { item := items[i] // convert the item to the underlying type of the slice. 
// this is needed because items is a slice of pointers that diff --git a/lib/kube/proxy/resource_filters_test.go b/lib/kube/proxy/resource_filters_test.go index 22bdede072282..156c5311b4544 100644 --- a/lib/kube/proxy/resource_filters_test.go +++ b/lib/kube/proxy/resource_filters_test.go @@ -167,7 +167,7 @@ func Test_filterBuffer(t *testing.T) { require.NoError(t, err) data := &bytes.Buffer{} name := filepath.Base(tt.args.dataFile) - err = temp.ExecuteTemplate(data, name, map[string]interface{}{ + err = temp.ExecuteTemplate(data, name, map[string]any{ "Kind": teleToKubeResource[r].obj, "API": teleToKubeResource[r].api, }, diff --git a/lib/kube/proxy/resource_rbac_test.go b/lib/kube/proxy/resource_rbac_test.go index 9a9dd6d196247..ab999b090abe5 100644 --- a/lib/kube/proxy/resource_rbac_test.go +++ b/lib/kube/proxy/resource_rbac_test.go @@ -996,7 +996,6 @@ func TestDeletePodCollectionRBAC(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() requestID := kubetypes.UID(uuid.NewString()) diff --git a/lib/kube/proxy/response_rewriter_test.go b/lib/kube/proxy/response_rewriter_test.go index 6460eb27e1e28..fe0f8e20df199 100644 --- a/lib/kube/proxy/response_rewriter_test.go +++ b/lib/kube/proxy/response_rewriter_test.go @@ -147,7 +147,6 @@ func TestErrorRewriter(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() // generate a kube client with user certs for auth diff --git a/lib/kube/proxy/server_test.go b/lib/kube/proxy/server_test.go index 31db4978d93f3..684773af44b53 100644 --- a/lib/kube/proxy/server_test.go +++ b/lib/kube/proxy/server_test.go @@ -182,7 +182,7 @@ func TestMTLSClientCAs(t *testing.T) { // 100 additional CAs registered, all CAs should be sent to the client in // the handshake. 
t.Run("100 CAs", func(t *testing.T) { - for i := 0; i < 100; i++ { + for i := range 100 { addCA(t, fmt.Sprintf("cluster-%d", i)) } testDial(t, 101) diff --git a/lib/kube/proxy/sess_test.go b/lib/kube/proxy/sess_test.go index 34bcd0ec198f2..825d04b6e0257 100644 --- a/lib/kube/proxy/sess_test.go +++ b/lib/kube/proxy/sess_test.go @@ -92,7 +92,6 @@ func TestSessionEndError(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() var ( diff --git a/lib/kube/proxy/streamproto/proto.go b/lib/kube/proxy/streamproto/proto.go index 75f12b15e8b22..c333b24091427 100644 --- a/lib/kube/proxy/streamproto/proto.go +++ b/lib/kube/proxy/streamproto/proto.go @@ -177,7 +177,7 @@ func (s *SessionStream) readTask() { // If it's a close error, we want to send a message to the stdout if s.isClient && errors.As(err, &closeErr) && closeErr.Text != "" { select { - case s.in <- []byte(fmt.Sprintf("\r\n---\r\nConnection closed: %v\r\n", closeErr.Text)): + case s.in <- fmt.Appendf(nil, "\r\n---\r\nConnection closed: %v\r\n", closeErr.Text): case <-s.done: return } diff --git a/lib/kube/proxy/websocket_client_test.go b/lib/kube/proxy/websocket_client_test.go index 75fb57b5133ad..47a674ea15406 100644 --- a/lib/kube/proxy/websocket_client_test.go +++ b/lib/kube/proxy/websocket_client_test.go @@ -28,6 +28,7 @@ import ( "net" "net/http" "net/url" + "slices" "strings" "sync" "sync/atomic" @@ -182,13 +183,7 @@ func (e *wsStreamClient) Stream(options clientremotecommand.StreamOptions) error defer conn.Close() streamingProto := conn.Subprotocol() - found := false - for _, p := range supportedProtocols { - if p == streamingProto { - found = true - break - } - } + found := slices.Contains(supportedProtocols, streamingProto) if !found { return fmt.Errorf("unsupported streaming protocol: %q", streamingProto) } @@ -212,13 +207,7 @@ func (e *wsStreamClient) ForwardPorts() error { defer conn.Close() streamingProto := conn.Subprotocol() - found := false - for _, p 
:= range supportedProtocols { - if p == streamingProto { - found = true - break - } - } + found := slices.Contains(supportedProtocols, streamingProto) if !found { return fmt.Errorf("unsupported streaming protocol: %q", streamingProto) } diff --git a/lib/kube/token/source_test.go b/lib/kube/token/source_test.go index 9d3e5fd5a4092..3c5863f1256eb 100644 --- a/lib/kube/token/source_test.go +++ b/lib/kube/token/source_test.go @@ -79,7 +79,7 @@ func TestGetIDToken(t *testing.T) { name: "no-token-no-var", getEnv: fakeGetEnv(""), readFile: fakeReadFile("foobarbizz", "/custom"), - assertError: func(t require.TestingT, err error, i ...interface{}) { + assertError: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, kubernetesDefaultTokenPath+": no such file") }, }, @@ -87,7 +87,7 @@ func TestGetIDToken(t *testing.T) { name: "no-token-with-var", getEnv: fakeGetEnv("/custom"), readFile: fakeReadFile("foobarbizz", kubernetesDefaultTokenPath), - assertError: func(t require.TestingT, err error, i ...interface{}) { + assertError: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "/custom: no such file") }, }, diff --git a/lib/kube/token/validator_test.go b/lib/kube/token/validator_test.go index 70d68fddb766d..c50455681eecc 100644 --- a/lib/kube/token/validator_test.go +++ b/lib/kube/token/validator_test.go @@ -365,7 +365,6 @@ func TestIDTokenValidator_Validate(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.token, func(t *testing.T) { // Fill value of raw to avoid duplication in test table if tt.wantResult != nil { diff --git a/lib/service/service.go b/lib/service/service.go index 4a713980fcb50..1cabe0238f1bc 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2724,6 +2724,7 @@ func (process *TeleportProcess) newAccessCacheForServices(cfg accesspoint.Config cfg.PluginStaticCredentials = services.PluginStaticCredentials cfg.GitServers = services.GitServers cfg.HealthCheckConfig = 
services.HealthCheckConfig + cfg.BotInstance = services.BotInstance return accesspoint.NewCache(cfg) } diff --git a/lib/services/local/bot_instance.go b/lib/services/local/bot_instance.go index 0a31674719838..aec504d0309db 100644 --- a/lib/services/local/bot_instance.go +++ b/lib/services/local/bot_instance.go @@ -110,27 +110,41 @@ func (b *BotInstanceService) ListBotInstances(ctx context.Context, botName strin } r, nextToken, err := service.ListResourcesWithFilter(ctx, pageSize, lastKey, func(item *machineidv1.BotInstance) bool { - latestHeartbeats := item.GetStatus().GetLatestHeartbeats() - heartbeat := item.Status.InitialHeartbeat // Use initial heartbeat as a fallback - if len(latestHeartbeats) > 0 { - heartbeat = latestHeartbeats[len(latestHeartbeats)-1] - } + return matchBotInstance(item, botName, search) + }) - values := []string{ - item.Spec.BotName, - item.Spec.InstanceId, - } + return r, nextToken, trace.Wrap(err) +} - if heartbeat != nil { - values = append(values, heartbeat.Hostname, heartbeat.JoinMethod, heartbeat.Version, "v"+heartbeat.Version) - } +func matchBotInstance(b *machineidv1.BotInstance, botName string, search string) bool { + // If updating this, ensure it's consistent with the cache search logic in `lib/cache/bot_instance.go`. 
- return slices.ContainsFunc(values, func(val string) bool { - return strings.Contains(strings.ToLower(val), strings.ToLower(search)) - }) - }) + if botName != "" && b.Spec.BotName != botName { + return false + } - return r, nextToken, trace.Wrap(err) + if search == "" { + return true + } + + latestHeartbeats := b.GetStatus().GetLatestHeartbeats() + heartbeat := b.Status.InitialHeartbeat // Use initial heartbeat as a fallback + if len(latestHeartbeats) > 0 { + heartbeat = latestHeartbeats[len(latestHeartbeats)-1] + } + + values := []string{ + b.Spec.BotName, + b.Spec.InstanceId, + } + + if heartbeat != nil { + values = append(values, heartbeat.Hostname, heartbeat.JoinMethod, heartbeat.Version, "v"+heartbeat.Version) + } + + return slices.ContainsFunc(values, func(val string) bool { + return strings.Contains(strings.ToLower(val), strings.ToLower(search)) + }) } // DeleteBotInstance deletes a specific bot instance matching the given bot name @@ -140,6 +154,11 @@ func (b *BotInstanceService) DeleteBotInstance(ctx context.Context, botName, ins return trace.Wrap(serviceWithPrefix.DeleteResource(ctx, instanceID)) } +// DeleteAllBotInstances deletes all bot instances for all bots +func (b *BotInstanceService) DeleteAllBotInstances(ctx context.Context) error { + return trace.Wrap(b.service.DeleteAllResources(ctx)) +} + // PatchBotInstance uses the supplied function to patch the bot instance // matching the given (botName, instanceID) key and persists the patched // resource. 
It will make multiple attempts if a `CompareFailed` error is diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index f91d2a27e3dfb..4fc2fb2b0b4bc 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -501,7 +501,19 @@ func (c *Client) startInputStreaming(stopCh chan struct{}) error { continue } - c.UpdateClientActivity() + // If the message was due to user input, then we update client activity + // in order to refresh the client_idle_timeout checks. + // + // Note: we count some of the directory sharing messages as client activity + // because we don't want a session to be closed due to inactivity during a large + // file transfer. + switch msg.(type) { + case tdp.KeyboardButton, tdp.MouseMove, tdp.MouseButton, tdp.MouseWheel, + tdp.SharedDirectoryAnnounce, tdp.SharedDirectoryInfoResponse, + tdp.SharedDirectoryReadResponse, tdp.SharedDirectoryWriteResponse: + + c.UpdateClientActivity() + } if withheldResize != nil { c.cfg.Logger.DebugContext(context.Background(), "Sending withheld screen size to client") diff --git a/lib/vnet/ssh_proxy.go b/lib/vnet/ssh_proxy.go index 9dd2c2a0235e6..df82175c25fd7 100644 --- a/lib/vnet/ssh_proxy.go +++ b/lib/vnet/ssh_proxy.go @@ -19,13 +19,12 @@ package vnet import ( "context" "errors" + "io" "log/slog" "sync" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" - - "github.com/gravitational/teleport/lib/utils" ) // sshConn represents an established SSH client or server connection. @@ -171,60 +170,134 @@ func proxyChannel( return } - // Copy channel requests in both directions concurrently. If either fails or - // exits it will cancel the context so that utils.ProxyConn below will close - // both channels so the other goroutine can also exit. + // Copy channel data and requests from the incoming channel to the target + // channel, and vice-versa. 
+ target := newSSHChan(targetChan, targetChanRequests, slog.With("direction", "client->target")) + incoming := newSSHChan(incomingChan, incomingChanRequests, slog.With("direction", "target->client")) + var wg sync.WaitGroup wg.Add(2) - ctx, cancel := context.WithCancel(ctx) go func() { - proxyChannelRequests(ctx, log, targetChan, incomingChanRequests, cancel) - cancel() + target.writeFrom(ctx, incoming) wg.Done() }() go func() { - proxyChannelRequests(ctx, log, incomingChan, targetChanRequests, cancel) - cancel() + incoming.writeFrom(ctx, target) wg.Done() }() + wg.Wait() +} - // ProxyConn copies channel data bidirectionally. If the context is - // canceled it will terminate, it always closes both channels before - // returning. - if err := utils.ProxyConn(ctx, incomingChan, targetChan); err != nil && - !utils.IsOKNetworkError(err) && !errors.Is(err, context.Canceled) { - log.DebugContext(ctx, "Unexpected error proxying channel data", "error", err) +// sshChan manages all writes to an SSH channel and handles closing the channel +// once no more data or requests will be written to it. +type sshChan struct { + ch ssh.Channel + requests <-chan *ssh.Request + log *slog.Logger +} + +func newSSHChan(ch ssh.Channel, requests <-chan *ssh.Request, log *slog.Logger) *sshChan { + return &sshChan{ + ch: ch, + requests: requests, + log: log, } +} + +// writeFrom writes channel data and requests from the source to this SSH channel. +// +// In the happy path it waits for: +// - channel data reads from source to return EOF +// - the source request channel to be closed +// and then closes this channel. +// +// Channel data reads from source can return EOF at any time if it has sent +// SSH_MSG_CHANNEL_EOF but it is still valid to send more channel requests +// after this. +// +// If an unrecoverable error is encountered it immediately closes both +// channels. 
+func (c *sshChan) writeFrom(ctx context.Context, source *sshChan) { + // Close the channel after all data and request writes are complete. + defer c.ch.Close() - // Wait for all goroutines to terminate. + var wg sync.WaitGroup + wg.Add(2) + go func() { + c.writeDataFrom(ctx, source) + wg.Done() + }() + go func() { + c.writeRequestsFrom(ctx, source) + wg.Done() + }() wg.Wait() } -func proxyChannelRequests( - ctx context.Context, - log *slog.Logger, - targetChan ssh.Channel, - reqs <-chan *ssh.Request, - closeChannels func(), -) { - log = log.With("request_layer", "channel") +// writeDataFrom writes channel data from source to this SSH channel. +// It handles standard channel data and extended channel data of type stderr. +func (c *sshChan) writeDataFrom(ctx context.Context, source *sshChan) { + // Close the channel for writes only after both the standard and stderr + // streams are finished writing. + defer c.ch.CloseWrite() + + errors := make(chan error, 2) + go func() { + _, err := io.Copy(c.ch, source.ch) + errors <- err + }() + go func() { + _, err := io.Copy(c.ch.Stderr(), source.ch.Stderr()) + errors <- err + }() + + // Read both errors to make sure both goroutines terminate, but only do + // anything on the first non-nil error, the second error is likely either + // the same as the first one or caused by closing the channel. + handledError := false + for range 2 { + err := <-errors + if err != nil && !handledError { + handledError = true + // Failed to write channel data from source to this channel. This was + // not an EOF from source or io.Copy would have returned nil. The + // stream might be missing data so close both channels. + // + // This should also unblock the stderr stream if the regular stream + // returned an error, and vice-versa. + c.log.ErrorContext(ctx, "Fatal error proxying SSH channel data", "error", err) + c.ch.Close() + source.ch.Close() + } + } +} + +// writeRequestsFrom forwards channel requests from source to this SSH channel. 
+func (c *sshChan) writeRequestsFrom(ctx context.Context, source *sshChan) { + log := c.log.With("request_layer", "channel") sendRequest := func(name string, wantReply bool, payload []byte) (bool, []byte, error) { - ok, err := targetChan.SendRequest(name, wantReply, payload) + ok, err := c.ch.SendRequest(name, wantReply, payload) // Replies to channel requests never have a payload. return ok, nil, err } - proxyRequests(ctx, log, sendRequest, reqs, closeChannels) + // Must forcibly close both channels if there was a fatal error proxying + // channel requests so that we don't continue in a bad state. + onFatalError := func() { + c.ch.Close() + source.ch.Close() + } + proxyRequests(ctx, log, sendRequest, source.requests, onFatalError) } func proxyGlobalRequests( ctx context.Context, targetConn ssh.Conn, reqs <-chan *ssh.Request, - closeConnections func(), + onFatalError func(), ) { log := log.With("request_layer", "global") sendRequest := targetConn.SendRequest - proxyRequests(ctx, log, sendRequest, reqs, closeConnections) + proxyRequests(ctx, log, sendRequest, reqs, onFatalError) } func proxyRequests( @@ -232,7 +305,7 @@ func proxyRequests( log *slog.Logger, sendRequest func(name string, wantReply bool, payload []byte) (bool, []byte, error), reqs <-chan *ssh.Request, - closeRequestSources func(), + onFatalError func(), ) { for req := range reqs { log := log.With("request_type", req.Type) @@ -240,23 +313,20 @@ func proxyRequests( ok, reply, err := sendRequest(req.Type, req.WantReply, req.Payload) if err != nil { // We failed to send the request, the target must be dead. - log.DebugContext(ctx, "Failed to forward SSH request", "request_type", req.Type, "error", err) - // Close both connections or channels to clean up but we must - // continue handling requests on the chan until it is closed by - // crypto/ssh. 
- closeRequestSources() - _ = req.Reply(false, nil) - continue + log.DebugContext(ctx, "Failed to forward SSH request", "error", err) + onFatalError() + req.Reply(false, nil) + ssh.DiscardRequests(reqs) + return } if err := req.Reply(ok, reply); err != nil { // A reply was expected and returned by the target but we failed to // forward it back, the connection that initiated the request must // be dead. - log.DebugContext(ctx, "Failed to reply to SSH request", "request_type", req.Type, "error", err) - // Close both connections or channels to clean up but we must - // continue handling requests on the chan until it is closed by - // crypto/ssh. - closeRequestSources() + log.DebugContext(ctx, "Failed to reply to SSH request", "error", err) + onFatalError() + ssh.DiscardRequests(reqs) + return } } } diff --git a/lib/vnet/ssh_proxy_test.go b/lib/vnet/ssh_proxy_test.go index 4c9b0e026ab80..d7eb8313c8a11 100644 --- a/lib/vnet/ssh_proxy_test.go +++ b/lib/vnet/ssh_proxy_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "net" + "sync" "testing" "github.com/gravitational/trace" @@ -108,11 +109,17 @@ func testSSHConnection(t *testing.T, dial dialer) { } func testConnectionToSshEchoServer(t *testing.T, sshConn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { - go ssh.DiscardRequests(reqs) + requestStreamEnded := make(chan struct{}) + go func() { + ssh.DiscardRequests(reqs) + close(requestStreamEnded) + }() + chanStreamEnded := make(chan struct{}) go func() { for newChan := range chans { newChan.Reject(ssh.Prohibited, "test") } + close(chanStreamEnded) }() // Try sending some global requests. 
@@ -136,6 +143,26 @@ func testConnectionToSshEchoServer(t *testing.T, sshConn ssh.Conn, chans <-chan t.Run("echo channel 2", func(t *testing.T) { testEchoChannel(t, sshConn) }) + + t.Run("closing", func(t *testing.T) { + // Send a request that causes the target server to close the connection + // immediately and make sure channel reads are unblocked, and the global + // request and channel request streams end. + ch, reqs, err := sshConn.OpenChannel("echo", nil) + require.NoError(t, err) + go ssh.DiscardRequests(reqs) + readErr := make(chan error) + go func() { + var b [1]byte + _, err := ch.Read(b[:]) + readErr <- err + }() + _, _, err = sshConn.SendRequest("close", false, nil) + require.NoError(t, err) + require.ErrorIs(t, <-readErr, io.EOF) + <-requestStreamEnded + <-chanStreamEnded + }) } func testGlobalRequests(t *testing.T, conn ssh.Conn) { @@ -156,7 +183,11 @@ func testGlobalRequests(t *testing.T, conn ssh.Conn) { func testEchoChannel(t *testing.T, conn ssh.Conn) { ch, reqs, err := conn.OpenChannel("echo", nil) require.NoError(t, err) - go ssh.DiscardRequests(reqs) + requestStreamEnded := make(chan struct{}) + go func() { + ssh.DiscardRequests(reqs) + close(requestStreamEnded) + }() defer ch.Close() // Try sending a message over the SSH channel and asserting that it is @@ -170,16 +201,43 @@ func testEchoChannel(t *testing.T, conn ssh.Conn) { require.Equal(t, len(msg), n) require.Equal(t, msg, buf[:n]) + // Try sending a message over stderr and asserting that it is echoed back. + _, err = ch.Stderr().Write(msg) + require.NoError(t, err) + n, err = ch.Stderr().Read(buf[:]) + require.NoError(t, err) + require.Equal(t, len(msg), n) + require.Equal(t, msg, buf[:n]) + // Try sending a channel request that expects a reply. reply, err := ch.SendRequest("echo", true, nil) require.NoError(t, err) require.True(t, reply) + // Close the channel for writes of in-band data and send another channel + // request, which should succeed. 
+ require.NoError(t, ch.CloseWrite()) + reply, err = ch.SendRequest("echo", true, nil) + require.NoError(t, err) + require.True(t, reply) + // The test server replies false to channel requests with type other than // "echo". reply, err = ch.SendRequest("unknown", true, nil) require.NoError(t, err) require.False(t, reply) + + // Send a channel request that causes the server to close the channel and + // make sure channel reads get unblocked and the incoming request stream ends. + readErr := make(chan error) + go func() { + _, err := ch.Read(buf[:]) + readErr <- err + }() + _, err = ch.SendRequest("close", false, nil) + require.NoError(t, err) + require.ErrorIs(t, <-readErr, io.EOF) + <-requestStreamEnded } type dialer interface { @@ -282,7 +340,7 @@ func runTestSSHServerInstance(tcpConn net.Conn, cfg *ssh.ServerConfig) error { return trace.Wrap(err) } go func() { - handleEchoRequests(reqs) + handleSSHRequests(reqs, sshConn.Close) sshConn.Close() }() handleEchoChannels(chans) @@ -290,17 +348,6 @@ func runTestSSHServerInstance(tcpConn net.Conn, cfg *ssh.ServerConfig) error { return nil } -func handleEchoRequests(reqs <-chan *ssh.Request) { - for req := range reqs { - switch req.Type { - case "echo": - req.Reply(true, req.Payload) - default: - req.Reply(false, nil) - } - } -} - func handleEchoChannels(chans <-chan ssh.NewChannel) { for newChan := range chans { switch newChan.ChannelType() { @@ -317,8 +364,33 @@ func handleEchoChannel(newChan ssh.NewChannel) { if err != nil { return } - go handleEchoRequests(reqs) - io.Copy(ch, ch) + go handleSSHRequests(reqs, ch.Close) + defer ch.CloseWrite() + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(ch, ch) + wg.Done() + }() + go func() { + io.Copy(ch.Stderr(), ch.Stderr()) + wg.Done() + }() + wg.Wait() +} + +func handleSSHRequests(reqs <-chan *ssh.Request, closeSource func() error) { + defer closeSource() + for req := range reqs { + switch req.Type { + case "echo": + req.Reply(true, req.Payload) + case "close": + 
closeSource() + default: + req.Reply(false, nil) + } + } } func sshServerConfig(t *testing.T) *ssh.ServerConfig { diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 31f5be490e29e..8fd300dd41ee9 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -25,6 +25,7 @@ import ( "errors" "log/slog" "net/http" + "net/url" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -40,6 +41,7 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/sso" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/desktop" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -384,21 +386,42 @@ func (h *Handler) performSessionMFACeremony( span.End() }() + // channelID is used by the front end to differentiate between separate ongoing SSO challenges. + channelID := uuid.NewString() + mfaCeremony := &mfa.Ceremony{ - PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt { + CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge, + SSOMFACeremonyConstructor: func(_ context.Context) (mfa.SSOMFACeremony, error) { + u, err := url.Parse(sso.WebMFARedirect) + if err != nil { + return nil, trace.Wrap(err) + } + u.RawQuery = url.Values{"channel_id": {channelID}}.Encode() + return &sso.MFACeremony{ + ClientCallbackURL: u.String(), + ProxyAddress: h.PublicProxyAddr(), + }, nil + }, + PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt { return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - codec := tdpMFACodec{} + // Convert from proto to JSON types. 
+ var challenge client.MFAAuthenticateChallenge + if chal.WebauthnChallenge != nil { + challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge) + } - if chal.WebauthnChallenge == nil { - return nil, trace.AccessDenied("Desktop access requires WebAuthn MFA, please register a WebAuthn device to connect") + if chal.SSOChallenge != nil { + challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge) + challenge.SSOChallenge.ChannelID = channelID } + + if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil { + return nil, trace.AccessDenied("Only WebAuthn and SSO MFA methods are supported on the web, please register a supported MFA method to connect to this desktop") + } + // Send the challenge over the socket. - msg, err := codec.Encode( - &client.MFAAuthenticateChallenge{ - WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge), - }, - defaults.WebsocketMFAChallenge, - ) + var codec tdpMFACodec + msg, err := codec.Encode(&challenge, defaults.WebsocketMFAChallenge) if err != nil { return nil, trace.Wrap(err) } @@ -457,7 +480,6 @@ func (h *Handler) performSessionMFACeremony( return assertion, nil }) }, - CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge, } result, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{ @@ -552,7 +574,6 @@ func proxyWebsocketConn(ctx context.Context, ws *websocket.Conn, wds *tls.Conn, go monitorLatency(ctx, clockwork.NewRealClock(), ws, pinger, latency.ReporterFunc(func(ctx context.Context, stats latency.Statistics) error { - log.DebugContext(ctx, "sending latency stats", "client", stats.Client, "server", stats.Server) return trace.Wrap(tdpConnProxy.SendToClient(tdp.LatencyStats{ ClientLatency: uint32(stats.Client), ServerLatency: uint32(stats.Server), diff --git a/lib/web/scripts/install/install.sh b/lib/web/scripts/install/install.sh index 29892b7c55551..71d9d3f70934d 100755 --- 
a/lib/web/scripts/install/install.sh +++ b/lib/web/scripts/install/install.sh @@ -328,13 +328,13 @@ install_teleport() { esac # select install method based on distribution - # if ID is debian derivate, run apt-get + # if ID is debian derivative, run apt-get case "$ID" in debian | ubuntu | kali | linuxmint | pop | raspbian | neon | zorin | parrot | elementary) install_via_apt_get ;; # if ID is amazon Linux 2/RHEL/etc, run yum - centos | rhel | amzn) + centos | rhel | rocky | almalinux | amzn) install_via_yum "$ID" ;; sles) diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 3abc495d96eb4..2489073369366 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -624,16 +624,11 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc, proxyAddr string) *mfa.Ceremony { // channelID is used by the front end to differentiate between separate ongoing SSO challenges. 
- var channelID string + channelID := uuid.NewString() return &mfa.Ceremony{ CreateAuthenticateChallenge: createAuthenticateChallenge, SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) { - id, err := uuid.NewRandom() - if err != nil { - return nil, trace.Wrap(err) - } - channelID = id.String() u, err := url.Parse(sso.WebMFARedirect) if err != nil { diff --git a/web/packages/shared/libs/tdp/client.ts b/web/packages/shared/libs/tdp/client.ts index a7435ef93db30..af169cfb32fe6 100644 --- a/web/packages/shared/libs/tdp/client.ts +++ b/web/packages/shared/libs/tdp/client.ts @@ -342,7 +342,7 @@ export class TdpClient extends EventEmitter { this.handleRdpConnectionActivated(buffer); break; case MessageType.RDP_FASTPATH_PDU: - this.handleRdpFastPathPDU(buffer); + this.handleRdpFastPathPdu(buffer); break; case MessageType.CLIENT_SCREEN_SPEC: this.handleClientScreenSpec(buffer); @@ -484,8 +484,8 @@ export class TdpClient extends EventEmitter { this.emit(TdpClientEvent.TDP_CLIENT_SCREEN_SPEC, spec); } - handleRdpFastPathPDU(buffer: ArrayBufferLike) { - let rdpFastPathPDU = this.codec.decodeRdpFastPathPDU(buffer); + handleRdpFastPathPdu(buffer: ArrayBufferLike) { + let rdpFastPathPdu = this.codec.decodeRdpFastPathPdu(buffer); // This should never happen but let's catch it with an error in case it does. 
if (!this.fastPathProcessor) { @@ -493,13 +493,13 @@ export class TdpClient extends EventEmitter { } this.fastPathProcessor.process( - rdpFastPathPDU, + rdpFastPathPdu, this, (bmpFrame: BitmapFrame) => { this.emit(TdpClientEvent.TDP_BMP_FRAME, bmpFrame); }, (responseFrame: ArrayBuffer) => { - this.sendRdpResponsePDU(responseFrame); + this.sendRdpResponsePdu(responseFrame); }, (data: ImageData | boolean, hotspot_x?: number, hotspot_y?: number) => { this.emit(TdpClientEvent.POINTER, { data, hotspot_x, hotspot_y }); @@ -821,8 +821,8 @@ export class TdpClient extends EventEmitter { this.sendClientScreenSpec(spec); }; - sendRdpResponsePDU(responseFrame: ArrayBufferLike) { - this.send(this.codec.encodeRdpResponsePDU(responseFrame)); + sendRdpResponsePdu(responseFrame: ArrayBufferLike) { + this.send(this.codec.encodeRdpResponsePdu(responseFrame)); } // Emits a warning event, but keeps the socket open. diff --git a/web/packages/shared/libs/tdp/codec.ts b/web/packages/shared/libs/tdp/codec.ts index 7c38e12541e5d..75eb8061e8b4f 100644 --- a/web/packages/shared/libs/tdp/codec.ts +++ b/web/packages/shared/libs/tdp/codec.ts @@ -745,7 +745,7 @@ export default class Codec { } // | message type (30) | data_length uint32 | data []byte | - encodeRdpResponsePDU(responseFrame: ArrayBufferLike): Message { + encodeRdpResponsePdu(responseFrame: ArrayBufferLike): Message { const bufLen = BYTE_LEN + UINT_32_LEN + responseFrame.byteLength; const buffer = new ArrayBuffer(bufLen); const view = new DataView(buffer); @@ -894,7 +894,7 @@ export default class Codec { } // | message type (29) | data_length uint32 | data []byte | - decodeRdpFastPathPDU(buffer: ArrayBufferLike): RdpFastPathPdu { + decodeRdpFastPathPdu(buffer: ArrayBufferLike): RdpFastPathPdu { const dv = new DataView(buffer); let offset = 0; offset += BYTE_LEN; // eat message type diff --git a/web/packages/teleport/index.html b/web/packages/teleport/index.html index 34d6ce56782c2..917ff53c548cb 100644 --- 
a/web/packages/teleport/index.html +++ b/web/packages/teleport/index.html @@ -10,8 +10,15 @@ + diff --git a/web/packages/teleport/src/Discover/SelectResource/resources/resources.tsx b/web/packages/teleport/src/Discover/SelectResource/resources/resources.tsx index 5f2e5c29cb3b5..ed23a0627745d 100644 --- a/web/packages/teleport/src/Discover/SelectResource/resources/resources.tsx +++ b/web/packages/teleport/src/Discover/SelectResource/resources/resources.tsx @@ -68,7 +68,16 @@ export const SERVERS: SelectResourceSpec[] = [ id: DiscoverGuideId.ServerLinuxRhelCentos, name: 'RHEL 8+/CentOS Stream 9+', kind: ResourceKind.Server, - keywords: [...baseServerKeywords, 'rhel', 'redhat', 'centos', 'linux'], + keywords: [ + ...baseServerKeywords, + 'rhel', + 'redhat', + 'centos', + 'linux', + 'rocky', + 'alma', + 'almalinux', + ], icon: 'linux', event: DiscoverEventResource.Server, platform: Platform.Linux, diff --git a/web/packages/teleport/src/lib/tdp/playerClient.ts b/web/packages/teleport/src/lib/tdp/playerClient.ts index f22b4deb94dae..eed18671ec0c5 100644 --- a/web/packages/teleport/src/lib/tdp/playerClient.ts +++ b/web/packages/teleport/src/lib/tdp/playerClient.ts @@ -208,7 +208,7 @@ export class PlayerClient extends TdpClient { // RDP response PDUs to the server during playback, which is unnecessary // and breaks the playback system. // eslint-disable-next-line unused-imports/no-unused-vars - sendRdpResponsePDU(responseFrame: ArrayBuffer) { + sendRdpResponsePdu(responseFrame: ArrayBuffer) { return; }