Skip to content

Commit

Permalink
merge branches, fix auth issue
Browse files Browse the repository at this point in the history
  • Loading branch information
k-capehart committed Nov 4, 2024
1 parent 48217a2 commit 083a8d1
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 8 deletions.
18 changes: 17 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package salesforce
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
)

type authentication struct {
Expand All @@ -29,13 +33,17 @@ type Creds struct {
SecurityToken string
ConsumerKey string
ConsumerSecret string
ConsumerRSAPem string
AccessToken string
}

const JwtExpirationTime = 5 * time.Minute

const (
grantTypeUsernamePassword = "password"
grantTypeClientCredentials = "client_credentials"
grantTypeAccessToken = "access_token"
grantTypeJWT = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)

func validateAuth(sf Salesforce) error {
Expand Down Expand Up @@ -81,6 +89,14 @@ func refreshSession(auth *authentication) error {
auth.creds.ConsumerKey,
auth.creds.ConsumerSecret,
)
case grantTypeJWT:
refreshedAuth, err = jwtFlow(
auth.InstanceUrl,
auth.creds.Username,
auth.creds.ConsumerKey,
auth.creds.ConsumerRSAPem,
JwtExpirationTime,
)
default:
return errors.New("invalid session, unable to refresh session")
}
Expand Down Expand Up @@ -170,7 +186,7 @@ func setAccessToken(domain string, accessToken string) (*authentication, error)

func jwtFlow(domain string, username string, consumerKey string, consumerRSAPem string, expirationTime time.Duration) (*authentication, error) {
audience := domain
if(strings.Contains(audience, "sandbox")) {
if strings.Contains(audience, "sandbox") {
audience = "https://test.salesforce.com"
} else {
audience = "https://login.salesforce.com"
Expand Down
3 changes: 2 additions & 1 deletion auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"reflect"
"testing"
"time"
)

func Test_validateAuth(t *testing.T) {
Expand Down Expand Up @@ -269,7 +270,7 @@ func Test_refreshSession(t *testing.T) {
serverJwt, sfAuthJwt := setupTestServer(refreshedAuth, http.StatusOK)
sampleKey, _ := os.ReadFile("test/sample_key.pem")
sfAuthJwt.creds = Creds{
Domain: serverJwt.URL,
Domain: serverJwt.URL,
Username: "u",
ConsumerKey: "key",
ConsumerRSAPem: string(sampleKey),
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ require github.com/mitchellh/mapstructure v1.5.0
require github.com/forcedotcom/go-soql v0.0.0-20220705175410-00f698360bee

require (
github.com/jszwec/csvutil v1.10.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/jszwec/csvutil v1.10.0
github.com/spf13/afero v1.11.0
k8s.io/apimachinery v0.31.1
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ github.com/forcedotcom/go-soql v0.0.0-20220705175410-00f698360bee/go.mod h1:bON1
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
Expand Down
20 changes: 16 additions & 4 deletions salesforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,11 @@ func processSalesforceError(resp http.Response, auth *authentication, payload re
func Init(creds Creds) (*Salesforce, error) {
var auth *authentication
var err error
if creds != (Creds{}) {
if creds == (Creds{}) {
return nil, errors.New("creds is empty")
}
if creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" &&
creds.Username != "" && creds.Password != "" && creds.SecurityToken != "" {

auth, err = usernamePasswordFlow(
creds.Domain,
creds.Username,
Expand All @@ -216,17 +215,26 @@ func Init(creds Creds) (*Salesforce, error) {
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" {
} else if creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" {
auth, err = clientCredentialsFlow(
creds.Domain,
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.AccessToken != "" {
} else if creds.AccessToken != "" {
auth, err = setAccessToken(
creds.Domain,
creds.AccessToken,
)
} else if creds.Domain != "" && creds.Username != "" &&
creds.ConsumerKey != "" && creds.ConsumerRSAPem != "" {
auth, err = jwtFlow(
creds.Domain,
creds.Username,
creds.ConsumerKey,
creds.ConsumerRSAPem,
JwtExpirationTime,
)
}

if err != nil {
Expand Down Expand Up @@ -429,6 +437,10 @@ func (sf *Salesforce) QueryStructBulkExport(soqlStruct any, filePath string) err
}

func (sf *Salesforce) QueryBulkIterator(query string) (IteratorJob, error) {
authErr := validateAuth(*sf)
if authErr != nil {
return nil, authErr
}

Check warning on line 443 in salesforce.go

View check run for this annotation

Codecov / codecov/patch

salesforce.go#L442-L443

Added lines #L442 - L443 were not covered by tests
queryJobReq := bulkQueryJobCreationRequest{
Operation: queryJobType,
Query: query,
Expand Down
2 changes: 1 addition & 1 deletion salesforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ func TestInit(t *testing.T) {
defer serverJwt.Close()
sampleKey, _ := os.ReadFile("test/sample_key.pem")
credsJwt := Creds{
Domain: serverAccessToken.URL,
Domain: serverAccessToken.URL,
Username: "u",
ConsumerKey: "key",
ConsumerRSAPem: string(sampleKey),
Expand Down

0 comments on commit 083a8d1

Please sign in to comment.