Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions service/kas/access/rewrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func err403(s string) error {
return connect.NewError(connect.CodePermissionDenied, errors.Join(ErrUser, status.Error(codes.PermissionDenied, s)))
}

func err500(s string) error {
return connect.NewError(connect.CodeInternal, errors.Join(ErrInternal, status.Error(codes.Internal, s)))
}

func generateHMACDigest(ctx context.Context, msg, key []byte, logger logger.Logger) ([]byte, error) {
mac := hmac.New(sha256.New, key)
_, err := mac.Write(msg)
Expand Down Expand Up @@ -212,11 +216,13 @@ func extractSRTBody(ctx context.Context, headers http.Header, in *kaspb.RewrapRe
err = protojson.Unmarshal([]byte(rbString), &requestBody)
// if there are no requests then it could be a v1 request
if err != nil || len(requestBody.GetRequests()) == 0 {
logger.WarnContext(ctx, "invalid request body! checking v1 SRT")
requestBody, err = extractAndConvertV1SRTBody([]byte(rbString))
if err != nil {
var errv1 error
requestBody, errv1 = extractAndConvertV1SRTBody([]byte(rbString))
if errv1 != nil {
logger.WarnContext(ctx, "invalid SRT", "err.v1", errv1, "err.v2", err)
return nil, false, err400("invalid request body")
}
logger.DebugContext(ctx, "legacy v1 SRT", "err.v2", err)
isV1 = true
}
logger.DebugContext(ctx, "extracted request body", slog.String("rewrap.body", requestBody.String()), slog.Any("rewrap.srt", rbString))
Expand Down Expand Up @@ -596,15 +602,16 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
return "", results
}

privateKeyHandle, publicKeyHandle, err := p.CryptoProvider.GenerateEphemeralKasKeys()
privateKeyHandle, ephemeralKeyPEM, err := p.CryptoProvider.GenerateEphemeralKasKeys()
if err != nil {
failAllKaos(requests, results, fmt.Errorf("failed to generate keypair: %w", err))
failAllKaos(requests, results, err500("entropy failure"))
p.Logger.WarnContext(ctx, "failure in GenerateEphemeralKasKeys", "err", err)
return "", results
}
sessionKey, err := p.CryptoProvider.GenerateNanoTDFSessionKey(privateKeyHandle, []byte(clientPublicKey))
if err != nil {
p.Logger.DebugContext(ctx, "GenerateNanoTDFSessionKey", "err", err)
failAllKaos(requests, results, fmt.Errorf("failed to generate session key: %w", err))
p.Logger.WarnContext(ctx, "failure in GenerateNanoTDFSessionKey", "err", err)
failAllKaos(requests, results, err400("keypair mismatch"))
return "", results
}

Expand Down Expand Up @@ -653,7 +660,7 @@ func (p *Provider) nanoTDFRewrap(ctx context.Context, requests []*kaspb.Unsigned
p.Logger.Audit.RewrapSuccess(ctx, auditEventParams)
}
}
return string(publicKeyHandle), results
return string(ephemeralKeyPEM), results
}

func (p *Provider) verifyNanoRewrapRequests(ctx context.Context, req *kaspb.UnsignedRewrapRequest_WithPolicyRequest) (*Policy, map[string]kaoResult) {
Expand Down
Loading