Skip to content

Commit

Permalink
feat: #54 support secrets encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
bohdan-shulha committed Sep 22, 2024
1 parent ff01ef8 commit 367eab9
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 87 deletions.
143 changes: 143 additions & 0 deletions internal/app/ptah-agent/encryption.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package ptah_agent

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"

"github.com/pkg/errors"

"github.com/docker/docker/api/types/swarm"
t "github.com/ptah-sh/ptah-agent/internal/pkg/ptah-client"
)

type EncryptionKeyPair struct {
PrivateKey string `json:"private_key"`
PublicKey string `json:"public_key"`
}

// getEncryptionKey checks for existing key or generates a new one
func (e *taskExecutor) getEncryptionKey(ctx context.Context) (*EncryptionKeyPair, error) {
existingConfig, err := e.getConfigByName(ctx, "ptah_encryption_key")
if err != nil && !errors.Is(err, ErrConfigNotFound) {
return nil, fmt.Errorf("failed to check for existing encryption key: %v", err)
}

if existingConfig != nil {
var keyPair EncryptionKeyPair

err = json.Unmarshal(existingConfig.Spec.Data, &keyPair)

if err != nil {
return nil, fmt.Errorf("failed to unmarshal existing encryption key: %v", err)
}

return &keyPair, nil
}

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA key pair: %v", err)
}

privateKeyPEM := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}
privateKeyStr := string(pem.EncodeToMemory(privateKeyPEM))

publicKey, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %v", err)
}

publicKeyPEM := &pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKey,
}
publicKeyStr := string(pem.EncodeToMemory(publicKeyPEM))

keyPair := &EncryptionKeyPair{
PrivateKey: privateKeyStr,
PublicKey: publicKeyStr,
}

keyPairJSON, err := json.Marshal(keyPair)
if err != nil {
return nil, fmt.Errorf("failed to marshal encryption key pair: %v", err)
}

_, err = e.createDockerConfig(ctx, &t.CreateConfigReq{
SwarmConfigSpec: swarm.ConfigSpec{
Annotations: swarm.Annotations{
Name: "ptah_encryption_key",
Labels: map[string]string{},
},
Data: keyPairJSON,
},
})
if err != nil {
return nil, fmt.Errorf("failed to save encryption key to Docker config: %v", err)
}

return keyPair, nil
}

func (e *taskExecutor) decryptValue(ctx context.Context, encryptedValue string) (string, error) {
keyPair, err := e.getEncryptionKey(ctx)
if err != nil {
return "", errors.Wrap(err, "failed to get encryption key")
}

block, _ := pem.Decode([]byte(keyPair.PrivateKey))
if block == nil {
return "", errors.New("failed to parse PEM block containing the private key")
}

privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", errors.Wrap(err, "failed to parse private key")
}

encryptedBytes, err := base64.StdEncoding.DecodeString(encryptedValue)
if err != nil {
return "", errors.Wrap(err, "failed to decode encrypted value")
}

decryptedBytes, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey, encryptedBytes, []byte(""))
if err != nil {
return "", errors.Wrap(err, "failed to decrypt value")
}

return string(decryptedBytes), nil
}

func (e *taskExecutor) encryptValue(ctx context.Context, value string) (string, error) {
keyPair, err := e.getEncryptionKey(ctx)
if err != nil {
return "", errors.Wrap(err, "failed to get encryption key")
}

block, _ := pem.Decode([]byte(keyPair.PublicKey))
if block == nil {
return "", errors.New("failed to parse PEM block containing the public key")
}

publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return "", errors.Wrap(err, "failed to parse public key")
}

encryptedBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey.(*rsa.PublicKey), []byte(value), []byte(""))
if err != nil {
return "", errors.Wrap(err, "failed to encrypt value")
}

return base64.StdEncoding.EncodeToString(encryptedBytes), nil
}
79 changes: 56 additions & 23 deletions internal/app/ptah-agent/ptah_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
)

type Agent struct {
Version string
ptah *ptahClient.Client
rootDir string
docker *dockerClient.Client
caddy *caddyClient.Client
Version string
ptah *ptahClient.Client
rootDir string
docker *dockerClient.Client
caddy *caddyClient.Client
executor *taskExecutor
}

func New(version string, baseUrl string, ptahToken string, rootDir string) (*Agent, error) {
Expand All @@ -34,13 +35,27 @@ func New(version string, baseUrl string, ptahToken string, rootDir string) (*Age
ctx := context.Background()
docker.NegotiateAPIVersion(ctx)

return &Agent{
caddy := caddyClient.New("http://127.0.0.1:2019", http.DefaultClient)

// TODO: refactor to avoid duplication and circular dependency?
agent := &Agent{
Version: version,
ptah: ptahClient.New(baseUrl, ptahToken),
rootDir: rootDir,
caddy: caddyClient.New("http://127.0.0.1:2019", http.DefaultClient),
caddy: caddy,
docker: docker,
}, nil
executor: &taskExecutor{
docker: docker,
caddy: caddy,
rootDir: rootDir,
// TODO: use channel instead?
stopAgentFlag: false,
},
}

agent.executor.agent = agent

return agent, nil
}

func (a *Agent) sendStartedEvent(ctx context.Context) (*ptahClient.StartedRes, error) {
Expand Down Expand Up @@ -91,13 +106,30 @@ func (a *Agent) sendStartedEvent(ctx context.Context) (*ptahClient.StartedRes, e
})
}

workerJoinToken, err := a.executor.encryptValue(ctx, swarm.JoinTokens.Worker)
if err != nil {
return nil, err
}

managerJoinToken, err := a.executor.encryptValue(ctx, swarm.JoinTokens.Manager)
if err != nil {
return nil, err
}

startedReq.SwarmData = &ptahClient.SwarmData{
JoinTokens: ptahClient.JoinTokens{
Worker: swarm.JoinTokens.Worker,
Manager: swarm.JoinTokens.Manager,
Worker: workerJoinToken,
Manager: managerJoinToken,
},
ManagerNodes: managerNodes,
}

encryptionKey, err := a.executor.getEncryptionKey(ctx)
if err != nil {
return nil, err
}

startedReq.SwarmData.EncryptionKey = encryptionKey.PublicKey
}

log.Println("sending started event, base url", a.ptah.BaseUrl)
Expand All @@ -115,29 +147,30 @@ func (a *Agent) Start(ctx context.Context) error {
return err
}

executor := &taskExecutor{
docker: a.docker,
caddy: a.caddy,
rootDir: a.rootDir,
// TODO: use channel instead?
stopAgentFlag: false,
agent: a,
}

log.Println("connected to server, poll interval", settings.Settings.PollInterval)

consecutiveFailures := 0
maxConsecutiveFailures := 5

for {
taskID, task, err := a.getNextTask(ctx)
if err != nil {
log.Println("can't get the next task", err)
consecutiveFailures++

if taskID != 0 {
if taskID == 0 {
if consecutiveFailures >= maxConsecutiveFailures {
return fmt.Errorf("shutting down due to %d consecutive failures to get next task", maxConsecutiveFailures)
}
} else {
if err = a.ptah.FailTask(ctx, taskID, &ptahClient.TaskError{
Message: err.Error(),
}); err != nil {
log.Println("can't fail task", err)
}
}
} else {
consecutiveFailures = 0
}

if task == nil {
Expand All @@ -146,7 +179,7 @@ func (a *Agent) Start(ctx context.Context) error {
continue
}

result, err := executor.executeTask(ctx, task)
result, err := a.executor.executeTask(ctx, task)
// TODO: store the result to re-send it once connection to the ptah server is restored
if err == nil {
if err = a.ptah.CompleteTask(ctx, taskID, result); err != nil {
Expand All @@ -160,7 +193,7 @@ func (a *Agent) Start(ctx context.Context) error {
}
}

if executor.stopAgentFlag {
if a.executor.stopAgentFlag {
log.Println("received stop signal, shutting down gracefully")

break
Expand Down Expand Up @@ -191,7 +224,7 @@ func (a *Agent) getNextTask(ctx context.Context) (taskId int, task interface{},
func (a *Agent) ExecTasks(ctx context.Context, jsonFilePath string) error {
// Docker client should already be initialized and version negotiated in New()
if a.docker == nil {
return fmt.Errorf("Docker client not initialized")
return fmt.Errorf("docker client not initialized")
}

// Read the JSON file
Expand Down
25 changes: 6 additions & 19 deletions internal/app/ptah-agent/registry_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,18 @@ package ptah_agent
import (
"context"
"encoding/json"

"github.com/docker/docker/api/types/registry"
"github.com/pkg/errors"
t "github.com/ptah-sh/ptah-agent/internal/pkg/ptah-client"
)

func (e *taskExecutor) createRegistryAuth(ctx context.Context, req *t.CreateRegistryAuthReq) (*t.CreateRegistryAuthRes, error) {
if req.PrevConfigName != "" {
prev, _, err := e.docker.ConfigInspectWithRaw(ctx, req.PrevConfigName)
if err != nil {
return nil, err
}

var authConfig registry.AuthConfig
err = json.Unmarshal(prev.Spec.Data, &authConfig)
if err != nil {
return nil, err
}

if req.AuthConfigSpec.Username == "" {
req.AuthConfigSpec.Username = authConfig.Username
}

if req.AuthConfigSpec.Password == "" {
req.AuthConfigSpec.Password = authConfig.Password
}
decryptedPassword, err := e.decryptValue(ctx, req.AuthConfigSpec.Password)
if err != nil {
return nil, errors.Wrap(err, "failed to decrypt password")
}
req.AuthConfigSpec.Password = decryptedPassword

data, err := json.Marshal(req.AuthConfigSpec)
if err != nil {
Expand Down
30 changes: 9 additions & 21 deletions internal/app/ptah-agent/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,24 @@ import (
func (e *taskExecutor) createS3Storage(ctx context.Context, req *t.CreateS3StorageReq) (*t.CreateS3StorageRes, error) {
var res t.CreateS3StorageRes

if req.S3StorageSpec.AccessKey == "" || req.S3StorageSpec.SecretKey == "" {
if req.PrevConfigName == "" {
return nil, fmt.Errorf("create s3 storage: prev config name is empty - empty credentials")
}

prev, err := e.getConfigByName(ctx, req.PrevConfigName)
if err != nil {
return nil, err
}

var prevSpec t.S3StorageSpec
err = json.Unmarshal(prev.Spec.Data, &prevSpec)
if err != nil {
return nil, fmt.Errorf("create s3 storage: unmarshal prev config: %w", err)
}

req.S3StorageSpec.AccessKey = prevSpec.AccessKey
req.S3StorageSpec.SecretKey = prevSpec.SecretKey
decryptedSecretKey, err := e.decryptValue(ctx, req.S3StorageSpec.SecretKey)
if err != nil {
return nil, fmt.Errorf("create s3 storage: decrypt secret key: %w", err)
}

data, err := json.Marshal(req.S3StorageSpec)
decryptedSpec := req.S3StorageSpec
decryptedSpec.SecretKey = decryptedSecretKey

data, err := json.Marshal(decryptedSpec)
if err != nil {
return nil, err
return nil, fmt.Errorf("create s3 storage: marshal spec: %w", err)
}

req.SwarmConfigSpec.Data = data

config, err := e.docker.ConfigCreate(ctx, req.SwarmConfigSpec)
if err != nil {
return nil, err
return nil, fmt.Errorf("create s3 storage: create config: %w", err)
}

res.Docker.ID = config.ID
Expand Down
Loading

0 comments on commit 367eab9

Please sign in to comment.