From 083a8d1bd271c10507cef94b47621da87e4bc4ed Mon Sep 17 00:00:00 2001 From: Kyle Date: Mon, 4 Nov 2024 16:25:04 -0500 Subject: [PATCH] merge branches, fix auth issue --- auth.go | 18 +++++++++++++++++- auth_test.go | 3 ++- go.mod | 2 +- go.sum | 2 ++ salesforce.go | 20 ++++++++++++++++---- salesforce_test.go | 2 +- 6 files changed, 39 insertions(+), 8 deletions(-) diff --git a/auth.go b/auth.go index 4295e59..2f86971 100644 --- a/auth.go +++ b/auth.go @@ -3,11 +3,15 @@ package salesforce import ( "encoding/json" "errors" + "fmt" "io" "net/http" "net/url" "strconv" "strings" + "time" + + "github.com/golang-jwt/jwt/v5" ) type authentication struct { @@ -29,13 +33,17 @@ type Creds struct { SecurityToken string ConsumerKey string ConsumerSecret string + ConsumerRSAPem string AccessToken string } +const JwtExpirationTime = 5 * time.Minute + const ( grantTypeUsernamePassword = "password" grantTypeClientCredentials = "client_credentials" grantTypeAccessToken = "access_token" + grantTypeJWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" ) func validateAuth(sf Salesforce) error { @@ -81,6 +89,14 @@ func refreshSession(auth *authentication) error { auth.creds.ConsumerKey, auth.creds.ConsumerSecret, ) + case grantTypeJWT: + refreshedAuth, err = jwtFlow( + auth.InstanceUrl, + auth.creds.Username, + auth.creds.ConsumerKey, + auth.creds.ConsumerRSAPem, + JwtExpirationTime, + ) default: return errors.New("invalid session, unable to refresh session") } @@ -170,7 +186,7 @@ func setAccessToken(domain string, accessToken string) (*authentication, error) func jwtFlow(domain string, username string, consumerKey string, consumerRSAPem string, expirationTime time.Duration) (*authentication, error) { audience := domain - if(strings.Contains(audience, "sandbox")) { + if strings.Contains(audience, "sandbox") { audience = "https://test.salesforce.com" } else { audience = "https://login.salesforce.com" diff --git a/auth_test.go b/auth_test.go index f0b099e..9e8b6d3 100644 --- a/auth_test.go +++ b/auth_test.go @@ -5,6 +5,7 @@ import ( "os" "reflect" "testing" + "time" ) func Test_validateAuth(t *testing.T) { @@ -269,7 +270,7 @@ func Test_refreshSession(t *testing.T) { serverJwt, sfAuthJwt := setupTestServer(refreshedAuth, http.StatusOK) sampleKey, _ := os.ReadFile("test/sample_key.pem") sfAuthJwt.creds = Creds{ - Domain: serverJwt.URL, + Domain: serverJwt.URL, Username: "u", ConsumerKey: "key", ConsumerRSAPem: string(sampleKey), diff --git a/go.mod b/go.mod index cec9d47..a48c040 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require github.com/mitchellh/mapstructure v1.5.0 require github.com/forcedotcom/go-soql v0.0.0-20220705175410-00f698360bee require ( - github.com/jszwec/csvutil v1.10.0 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/jszwec/csvutil v1.10.0 github.com/spf13/afero v1.11.0 k8s.io/apimachinery v0.31.1 ) diff --git a/go.sum b/go.sum index 63f2616..c39b7b0 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/forcedotcom/go-soql v0.0.0-20220705175410-00f698360bee/go.mod h1:bON1 github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/salesforce.go b/salesforce.go index e5c40dd..f37de87 100644 --- a/salesforce.go +++ b/salesforce.go @@ -202,12 +202,11 @@ func processSalesforceError(resp http.Response, auth *authentication, payload re func Init(creds Creds) (*Salesforce, error) { var auth *authentication var err error - if creds != (Creds{}) { + if creds == (Creds{}) { return nil, errors.New("creds is empty") } if creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" && creds.Username != "" && creds.Password != "" && creds.SecurityToken != "" { - auth, err = usernamePasswordFlow( creds.Domain, creds.Username, @@ -216,17 +215,26 @@ func Init(creds Creds) (*Salesforce, error) { creds.ConsumerKey, creds.ConsumerSecret, ) - } else if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" { + } else if creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" { auth, err = clientCredentialsFlow( creds.Domain, creds.ConsumerKey, creds.ConsumerSecret, ) - } else if creds != (Creds{}) && creds.AccessToken != "" { + } else if creds.AccessToken != "" { auth, err = setAccessToken( creds.Domain, creds.AccessToken, ) + } else if creds.Domain != "" && creds.Username != "" && + creds.ConsumerKey != "" && creds.ConsumerRSAPem != "" { + auth, err = jwtFlow( + creds.Domain, + creds.Username, + creds.ConsumerKey, + creds.ConsumerRSAPem, + JwtExpirationTime, + ) } if err != nil { @@ -429,6 +437,10 @@ func (sf *Salesforce) QueryStructBulkExport(soqlStruct any, filePath string) err } func (sf *Salesforce) QueryBulkIterator(query string) (IteratorJob, error) { + authErr := validateAuth(*sf) + if authErr != nil { + return nil, authErr + } queryJobReq := bulkQueryJobCreationRequest{ Operation: queryJobType, Query: query, diff --git a/salesforce_test.go b/salesforce_test.go index 4702d66..fad13c3 100644 --- a/salesforce_test.go +++ b/salesforce_test.go @@ -477,7 +477,7 @@ func TestInit(t *testing.T) { defer serverJwt.Close() sampleKey, _ := os.ReadFile("test/sample_key.pem") credsJwt := Creds{ - Domain: serverAccessToken.URL, + Domain: serverAccessToken.URL, Username: "u", ConsumerKey: "key", ConsumerRSAPem: string(sampleKey),