Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh when Invalid Session error occurs #58

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type authentication struct {
Scope string `json:"scope"`
IssuedAt string `json:"issued_at"`
Signature string `json:"signature"`
grantType string
creds Creds
}

type Creds struct {
Expand All @@ -30,8 +32,9 @@ type Creds struct {
}

const (
grantTypePassword = "password"
grantTypeUsernamePassword = "password"
grantTypeClientCredentials = "client_credentials"
grantTypeAccessToken = "access_token"
)

func validateAuth(sf Salesforce) error {
Expand All @@ -45,14 +48,52 @@ func validateSession(auth authentication) error {
if err := validateAuth(Salesforce{auth: &auth}); err != nil {
return err
}
_, err := doRequest(http.MethodGet, "/limits", jsonType, auth, "")
_, err := doRequest(&auth, requestPayload{
method: http.MethodGet,
uri: "/limits",
content: jsonType,
})
if err != nil {
return err
}

return nil
}

func refreshSession(auth *authentication) error {
k-capehart marked this conversation as resolved.
Show resolved Hide resolved
var refreshedAuth *authentication
var err error

switch grantType := auth.grantType; grantType {
case grantTypeClientCredentials:
refreshedAuth, err = clientCredentialsFlow(
auth.creds.Domain,
auth.creds.ConsumerKey,
auth.creds.ConsumerSecret,
)
k-capehart marked this conversation as resolved.
Show resolved Hide resolved
case grantTypeUsernamePassword:
refreshedAuth, err = usernamePasswordFlow(
auth.creds.Domain,
auth.creds.Username,
auth.creds.Password,
auth.creds.SecurityToken,
auth.creds.ConsumerKey,
auth.creds.ConsumerSecret,
)
default:
return errors.New("invalid session, unable to refresh session")
}

if refreshedAuth != nil {
auth.AccessToken = refreshedAuth.AccessToken
auth.IssuedAt = refreshedAuth.IssuedAt
auth.Signature = refreshedAuth.Signature
auth.Id = refreshedAuth.Id
}

return err
k-capehart marked this conversation as resolved.
Show resolved Hide resolved
}

func doAuth(url string, body *strings.Reader) (*authentication, error) {
resp, err := http.Post(url, "application/x-www-form-urlencoded", body)
if err != nil {
Expand All @@ -79,7 +120,7 @@ func doAuth(url string, body *strings.Reader) (*authentication, error) {

func usernamePasswordFlow(domain string, username string, password string, securityToken string, consumerKey string, consumerSecret string) (*authentication, error) {
payload := url.Values{
"grant_type": {grantTypePassword},
"grant_type": {grantTypeUsernamePassword},
"client_id": {consumerKey},
"client_secret": {consumerSecret},
"username": {username},
Expand All @@ -91,6 +132,7 @@ func usernamePasswordFlow(domain string, username string, password string, secur
if err != nil {
return nil, err
}
auth.grantType = grantTypeUsernamePassword
return auth, nil
}

Expand All @@ -106,6 +148,7 @@ func clientCredentialsFlow(domain string, consumerKey string, consumerSecret str
if err != nil {
return nil, err
}
auth.grantType = grantTypeClientCredentials
return auth, nil
}

Expand All @@ -114,5 +157,6 @@ func setAccessToken(domain string, accessToken string) (*authentication, error)
if err := validateSession(*auth); err != nil {
return nil, err
}
auth.grantType = grantTypeAccessToken
return auth, nil
}
2 changes: 2 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func Test_usernamePasswordFlow(t *testing.T) {
Id: "123abc",
IssuedAt: "01/01/1970",
Signature: "signed",
grantType: grantTypeUsernamePassword,
}
server, _ := setupTestServer(auth, http.StatusOK)
defer server.Close()
Expand Down Expand Up @@ -117,6 +118,7 @@ func Test_clientCredentialsFlow(t *testing.T) {
Id: "123abc",
IssuedAt: "01/01/1970",
Signature: "signed",
grantType: grantTypeClientCredentials,
}
server, _ := setupTestServer(auth, http.StatusOK)
defer server.Close()
Expand Down
63 changes: 43 additions & 20 deletions bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,29 @@ const (

var appFs = afero.NewOsFs() // afero.Fs type is a wrapper around os functions, allowing us to mock it in tests

func updateJobState(job bulkJob, state string, auth authentication) error {
func updateJobState(job bulkJob, state string, auth *authentication) error {
job.State = state
body, _ := json.Marshal(job)
_, err := doRequest(http.MethodPatch, "/jobs/ingest/"+job.Id, jsonType, auth, string(body))
_, err := doRequest(auth, requestPayload{
method: http.MethodPatch,
uri: "/jobs/ingest/" + job.Id,
content: jsonType,
body: string(body),
})
if err != nil {
return err
}

return nil
}

func createBulkJob(auth authentication, jobType string, body []byte) (bulkJob, error) {
resp, err := doRequest(http.MethodPost, "/jobs/"+jobType, jsonType, auth, string(body))
func createBulkJob(auth *authentication, jobType string, body []byte) (bulkJob, error) {
resp, err := doRequest(auth, requestPayload{
method: http.MethodPost,
uri: "/jobs/" + jobType,
content: jsonType,
body: string(body),
})
if err != nil {
return bulkJob{}, err
}
Expand All @@ -96,8 +106,13 @@ func createBulkJob(auth authentication, jobType string, body []byte) (bulkJob, e
return *newJob, nil
}

func uploadJobData(auth authentication, data string, bulkJob bulkJob) error {
_, uploadDataErr := doRequest("PUT", "/jobs/ingest/"+bulkJob.Id+"/batches", csvType, auth, data)
func uploadJobData(auth *authentication, data string, bulkJob bulkJob) error {
_, uploadDataErr := doRequest(auth, requestPayload{
method: http.MethodPut,
uri: "/jobs/ingest/" + bulkJob.Id + "/batches",
content: csvType,
body: data,
})
if uploadDataErr != nil {
if err := updateJobState(bulkJob, jobStateAborted, auth); err != nil {
return err
Expand All @@ -112,8 +127,12 @@ func uploadJobData(auth authentication, data string, bulkJob bulkJob) error {
return nil
}

func getJobResults(auth authentication, jobType string, bulkJobId string) (BulkJobResults, error) {
resp, err := doRequest(http.MethodGet, "/jobs/"+jobType+"/"+bulkJobId, jsonType, auth, "")
func getJobResults(auth *authentication, jobType string, bulkJobId string) (BulkJobResults, error) {
resp, err := doRequest(auth, requestPayload{
method: http.MethodGet,
uri: "/jobs/" + jobType + "/" + bulkJobId,
content: jsonType,
})
if err != nil {
return BulkJobResults{}, err
}
Expand All @@ -132,7 +151,7 @@ func getJobResults(auth authentication, jobType string, bulkJobId string) (BulkJ
return *bulkJobResults, nil
}

func getJobRecordResults(auth authentication, bulkJobResults BulkJobResults) (BulkJobResults, error) {
func getJobRecordResults(auth *authentication, bulkJobResults BulkJobResults) (BulkJobResults, error) {
successfulRecords, err := getBulkJobRecords(auth, bulkJobResults.Id, successfulResults)
if err != nil {
return bulkJobResults, fmt.Errorf("failed to get SuccessfulRecords: %w", err)
Expand All @@ -146,8 +165,12 @@ func getJobRecordResults(auth authentication, bulkJobResults BulkJobResults) (Bu
return bulkJobResults, err
}

func getBulkJobRecords(auth authentication, bulkJobId string, resultType string) ([]map[string]any, error) {
resp, err := doRequest(http.MethodGet, "/jobs/ingest/"+bulkJobId+"/"+resultType, jsonType, auth, "")
func getBulkJobRecords(auth *authentication, bulkJobId string, resultType string) ([]map[string]any, error) {
resp, err := doRequest(auth, requestPayload{
method: http.MethodGet,
uri: "/jobs/ingest/" + bulkJobId + "/" + resultType,
content: jsonType,
})
if err != nil {
return nil, err
}
Expand All @@ -160,7 +183,7 @@ func getBulkJobRecords(auth authentication, bulkJobId string, resultType string)
return results, nil
}

func waitForJobResultsAsync(auth authentication, bulkJobId string, jobType string, interval time.Duration, c chan error) {
func waitForJobResultsAsync(auth *authentication, bulkJobId string, jobType string, interval time.Duration, c chan error) {
err := wait.PollUntilContextTimeout(context.Background(), interval, time.Minute, false, func(context.Context) (bool, error) {
bulkJob, reqErr := getJobResults(auth, jobType, bulkJobId)
if reqErr != nil {
Expand All @@ -171,7 +194,7 @@ func waitForJobResultsAsync(auth authentication, bulkJobId string, jobType strin
c <- err
}

func waitForJobResults(auth authentication, bulkJobId string, jobType string, interval time.Duration) error {
func waitForJobResults(auth *authentication, bulkJobId string, jobType string, interval time.Duration) error {
err := wait.PollUntilContextTimeout(context.Background(), interval, time.Minute, false, func(context.Context) (bool, error) {
bulkJob, reqErr := getJobResults(auth, jobType, bulkJobId)
if reqErr != nil {
Expand All @@ -195,12 +218,12 @@ func isBulkJobDone(bulkJob BulkJobResults) (bool, error) {
return false, nil
}

func getQueryJobResults(auth authentication, bulkJobId string, locator string) (bulkJobQueryResults, error) {
func getQueryJobResults(auth *authentication, bulkJobId string, locator string) (bulkJobQueryResults, error) {
uri := "/jobs/query/" + bulkJobId + "/results"
if locator != "" {
uri = uri + "/?locator=" + locator
}
resp, err := doRequest(http.MethodGet, uri, jsonType, auth, "")
resp, err := doRequest(auth, requestPayload{method: http.MethodGet, uri: uri, content: jsonType})
if err != nil {
return bulkJobQueryResults{}, err
}
Expand All @@ -225,7 +248,7 @@ func getQueryJobResults(auth authentication, bulkJobId string, locator string) (
return queryResults, nil
}

func collectQueryResults(auth authentication, bulkJobId string) ([][]string, error) {
func collectQueryResults(auth *authentication, bulkJobId string) ([][]string, error) {
queryResults, resultsErr := getQueryJobResults(auth, bulkJobId, "")
if resultsErr != nil {
return nil, resultsErr
Expand Down Expand Up @@ -336,7 +359,7 @@ func writeCSVFile(filePath string, data [][]string) error {
return nil
}

func constructBulkJobRequest(auth authentication, sObjectName string, operation string, fieldName string) (bulkJob, error) {
func constructBulkJobRequest(auth *authentication, sObjectName string, operation string, fieldName string) (bulkJob, error) {
jobReq := bulkJobCreationRequest{
Object: sObjectName,
Operation: operation,
Expand All @@ -356,7 +379,7 @@ func constructBulkJobRequest(auth authentication, sObjectName string, operation
return job, nil
}

func doBulkJob(auth authentication, sObjectName string, fieldName string, operation string, records any, batchSize int, waitForResults bool) ([]string, error) {
func doBulkJob(auth *authentication, sObjectName string, fieldName string, operation string, records any, batchSize int, waitForResults bool) ([]string, error) {
recordMap, err := convertToSliceOfMaps(records)
if err != nil {
return []string{}, err
Expand Down Expand Up @@ -402,7 +425,7 @@ func doBulkJob(auth authentication, sObjectName string, fieldName string, operat
return jobIds, jobErrors
}

func doBulkJobWithFile(auth authentication, sObjectName string, fieldName string, operation string, filePath string, batchSize int, waitForResults bool) ([]string, error) {
func doBulkJobWithFile(auth *authentication, sObjectName string, fieldName string, operation string, filePath string, batchSize int, waitForResults bool) ([]string, error) {
var jobErrors error
var jobIds []string

Expand Down Expand Up @@ -461,7 +484,7 @@ func doBulkJobWithFile(auth authentication, sObjectName string, fieldName string
return jobIds, jobErrors
}

func doQueryBulk(auth authentication, filePath string, query string) error {
func doQueryBulk(auth *authentication, filePath string, query string) error {
queryJobReq := bulkQueryJobCreationRequest{
Operation: queryJobType,
Query: query,
Expand Down
Loading
Loading