From 71ac32a32ffbf8afdec6cea7e21e0f4bed52e125 Mon Sep 17 00:00:00 2001 From: Kevin Fox Date: Sun, 19 Jan 2025 15:17:25 -0800 Subject: [PATCH] HA startup Use spire-ha federated entry to allow single agent startup. Signed-off-by: Kevin Fox --- cmd/spire-ha-agent/main.go | 152 +++++++++++++++++++++++++------------ 1 file changed, 105 insertions(+), 47 deletions(-) diff --git a/cmd/spire-ha-agent/main.go b/cmd/spire-ha-agent/main.go index 590c485..bf0c8b9 100644 --- a/cmd/spire-ha-agent/main.go +++ b/cmd/spire-ha-agent/main.go @@ -12,6 +12,7 @@ import ( "fmt" "crypto/x509" "reflect" + "slices" "sync" "strconv" "os" @@ -64,8 +65,10 @@ type clientSet struct { clientOK bool debugClient agentdebug.DebugClient delegatedClient agentdelegated.DelegatedIdentityClient - bundle *x509bundle.Set - jwtBundles map[string]jose.JSONWebKeySet + ourX509Bundle *x509bundle.Bundle + haX509Bundle *x509bundle.Bundle + ourJWTBundle *jose.JSONWebKeySet + haJWTBundle *jose.JSONWebKeySet } func ConcatRawCertsFromCerts(certs []*x509.Certificate) []byte { @@ -471,8 +474,6 @@ func setupClient(ls *server, clientName string, id int, adminSocketName string, log.Fatalf("Failed to dial context: %v", err) } - ls.x509BundleUpdate = make(chan x509BundleUpdated) - ls.jwtBundleUpdate = make(chan jwtBundleUpdated) cs.delegatedClient = agentdelegated.NewDelegatedIdentityClient(dconn) cs.debugClient = agentdebug.NewDebugClient(dconn) go func() { @@ -528,6 +529,7 @@ func setupClient(ls *server, clientName string, id int, adminSocketName string, } log.Printf("Pushing x509 bundle") ls.x509BundleUpdate <- x509BundleUpdated{id, bundles} + } } }() @@ -552,8 +554,8 @@ func setupClient(ls *server, clientName string, id int, adminSocketName string, bundles := resp.GetBundles() jwksBundles := make(map[string]jose.JSONWebKeySet) for td, bundle := range bundles { - log.Printf("jwt Bundle: %s %s", td, string(bundle)) - //log.Printf("jwt Bundle: %s %d", td, len(bundle)) + //log.Printf("jwt Bundle: %s %s", td, string(bundle)) + log.Printf("jwt Bundle: %s %d", td, len(bundle)) jwks := new(jose.JSONWebKeySet) if err := json.NewDecoder(bytes.NewReader(bundle)).Decode(jwks); err != nil { log.Printf("failed to decode key set: %v", err) @@ -611,7 +613,7 @@ func main() { ) apath := "unix:///var/run/spire/agent/sockets/a/private/admin.sock" - bpath := "unix:///var/run/spire/agent/sockets/a/private/admin.sock" + bpath := "unix:///var/run/spire/agent/sockets/b/private/admin.sock" aname := "SPIRE_HA_AGENT_SOCKET" if ls.multi { aname = "SPIRE_HA_AGENT_SOCKET_A" @@ -619,12 +621,14 @@ func main() { if os.Getenv(aname) != "" { apath = os.Getenv(aname) } - setupClient(ls, "clientA", 0, apath, &ls.clients[0]) - if !ls.multi { + ls.x509BundleUpdate = make(chan x509BundleUpdated) + ls.jwtBundleUpdate = make(chan jwtBundleUpdated) + go setupClient(ls, "clientA", 0, apath, &ls.clients[0]) + if ls.multi { if os.Getenv("SPIRE_HA_AGENT_SOCKET_B") != "" { bpath = os.Getenv("SPIRE_HA_AGENT_SOCKET_B") } - setupClient(ls, "clientB", 1, bpath, &ls.clients[1]) + go setupClient(ls, "clientB", 1, bpath, &ls.clients[1]) } go func() { @@ -636,25 +640,53 @@ func main() { }() go func() { + var ourTD *spiffeid.TrustDomain + haTD, _ := spiffeid.TrustDomainFromString("spiffe://spire-ha") log.Printf("Listening for x509 bundle updates\n") for u := range ls.x509BundleUpdate { log.Printf("Got update for %d\n", u.id) - ls.clients[u.id].bundle = u.bundle - if ls.clients[0].bundle != nil && ls.clients[1].bundle != nil { - log.Printf("We got two bundles\n") + bl := u.bundle.Len() + log.Printf("Bundle count on update: %d\n", bl) + if bl < 1 { + log.Printf("Bad bundle pushed by the spire-agent.\n") + os.Exit(1) + } + if bl > 2 { + log.Printf("Too many federated bundles in the trust bundle. Please reconfigure the spire-ha-agent entry.\n") + os.Exit(1) + } + if bl == 2 && !u.bundle.Has(haTD) { + log.Printf("spire-ha trust bundle not found. Please reconfigure the spire-ha-agent entry.\n") + os.Exit(1) + } + for _, bundle := range u.bundle.Bundles() { + td := bundle.TrustDomain() + if td.Name() == "spire-ha" { + ls.clients[u.id].haX509Bundle = bundle + continue + } + if ourTD == nil { + ourTD = &td + log.Printf("Our trust domain detected as: %s\n", ourTD.Name()) + } + ls.clients[u.id].ourX509Bundle = bundle + } + bundles := slices.DeleteFunc([]*x509bundle.Bundle{ls.clients[0].ourX509Bundle, ls.clients[0].haX509Bundle, ls.clients[1].ourX509Bundle, ls.clients[1].haX509Bundle}, func(b *x509bundle.Bundle) bool { + return b == nil + }) + totalBundles := len(bundles) + if totalBundles > 1 || !ls.multi { + log.Printf("We got %d x509 bundles\n", totalBundles) var rawBundles map[string][]byte = make(map[string][]byte) - for _, bundle := range ls.clients[0].bundle.Bundles() { - td := bundle.TrustDomain() - if tdb, ok := ls.clients[1].bundle.Get(td); ok { - for _, cert := range tdb.X509Authorities() { - if !bundle.HasX509Authority(cert) { - bundle.AddX509Authority(cert) - } - } + bundle := x509bundle.New(*ourTD) + for _, tb := range bundles { + for _, cert := range tb.X509Authorities() { + bundle.AddX509Authority(cert) } - rawBundles[td.String()] = ConcatRawCertsFromCerts(bundle.X509Authorities()) } + rawBundles[ourTD.String()] = ConcatRawCertsFromCerts(bundle.X509Authorities()) if initBundle { + log.Printf("x509 inited") wg.Done() initBundle = false } @@ -675,42 +707,68 @@ func main() { }() go func() { + var ourTD *spiffeid.TrustDomain + //haTD, _ := spiffeid.TrustDomainFromString("spiffe://spire-ha") log.Printf("Listening for jwt bundle updates\n") for u := range ls.jwtBundleUpdate { log.Printf("Got update for %d\n", u.id) - ls.clients[u.id].jwtBundles = u.bundle - if !ls.multi { - ls.clients[1].jwtBundles = u.bundle + bl := len(u.bundle) + log.Printf("JWT bundle count on update: %d\n", bl) + if bl < 1 { + log.Printf("Bad JWT bundle pushed by the spire-agent.\n") + os.Exit(1) + } + if bl > 2 { + log.Printf("Too many federated bundles in the JWT trust bundle. Please reconfigure the spire-ha-agent entry.\n") + os.Exit(1) } - if ls.clients[0].jwtBundles != nil && ls.clients[1].jwtBundles != nil { - log.Printf("We got two jwt bundles\n") - tmpBundles := make(map[string]jose.JSONWebKeySet) + if _, ok := u.bundle["spiffe://spire-ha"]; bl == 2 && !ok { + log.Printf("spire-ha trust bundle not found in JWT trust bundle. Please reconfigure the spire-ha-agent entry. %s\n", u.bundle) + os.Exit(1) + } + for tdSTR, bundle := range u.bundle { + td, err := spiffeid.TrustDomainFromString(tdSTR) + if err != nil { + log.Printf("Failed to parse JWT trust bundle string. This should not happen.\n") + os.Exit(1) + } + //td := bundle.TrustDomain() + if td.Name() == "spire-ha" { + ls.clients[u.id].haJWTBundle = &bundle + continue + } + if ourTD == nil { + ourTD = &td + log.Printf("Our trust domain detected as: %s\n", ourTD.Name()) + } + ls.clients[u.id].ourJWTBundle = &bundle + } + bundles := slices.DeleteFunc([]*jose.JSONWebKeySet{ls.clients[0].ourJWTBundle, ls.clients[0].haJWTBundle, ls.clients[1].ourJWTBundle, ls.clients[1].haJWTBundle}, func(b *jose.JSONWebKeySet) bool { + return b == nil + }) + totalBundles := len(bundles) + if totalBundles > 1 || !ls.multi { + log.Printf("We got %d jwt bundles\n", totalBundles) var rawBundles map[string][]byte = make(map[string][]byte) - for td, bundle := range ls.clients[0].jwtBundles { - kids := make(map[string]bool) - var set jose.JSONWebKeySet + kids := make(map[string]bool) + var set jose.JSONWebKeySet + for _, bundle := range bundles { for _, b := range bundle.Keys { - kids[b.KeyID] = true - set.Keys = append(set.Keys, b) - } - if tdb, ok := ls.clients[1].jwtBundles[td]; ok { - for _, b := range tdb.Keys { - if _, ok := kids[b.KeyID]; !ok { - set.Keys = append(set.Keys, b) - } + if _, ok := kids[b.KeyID]; !ok { + kids[b.KeyID] = true + set.Keys = append(set.Keys, b) } } - tmpBundles[td] = set -//FIXME td's in 1 but not 0. Maybe same with x509? - res, err := json.Marshal(tmpBundles[td]) - if err != nil { + } + res, err := json.Marshal(set) + if err != nil { //FIXME what is the best way to handle this - log.Printf("Failed to marchal. %v", err) - continue - } - rawBundles[td] = res + log.Printf("Failed to marshal. %v", err) + continue } + rawBundles[ourTD.Name()] = res if jwtInitBundle { + log.Printf("jwt inited") jwtWg.Done() jwtInitBundle = false }