diff --git a/README.md b/README.md index fb353af..e6ca247 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ Slings store HTTP Request properties to simplify sending requests and decoding r * Encode structs into URL query parameters * Encode a form or JSON into the Request Body * Receive JSON success or failure responses +* Control a request's lifetime via context ## Install @@ -271,6 +272,29 @@ func (s *IssueService) ListByRepo(owner, repo string, params *IssueListParams) ( } ``` +### Controlling lifetime via context +All the above functionality of a sling can be made context aware. + +Getting a context aware request: +```go +ctx, cancel := context.WithTimeout(context.Background(),10*time.Second) +req, err := sling.New().Get("https://example.com").RequestWithContext(ctx) +``` +Receiving in a context aware manner +```go +success := &struct{}{} +failure := &struct{}{} +ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) +resp, err := sling.New().Path("https://example.com").Get("/foo").ReceiveWithContext(ctx,success,failure) +``` +After making the request you can first check whether request completed in time before proceeding with the response: +```go +if errors.Is(err, context.DeadlineExceeded) { + // Take action accordingly +} +``` +For more details about effectively using context please see: https://go.dev/blog/context + ## Example APIs using Sling * Digits [dghubble/go-digits](https://github.com/dghubble/go-digits) diff --git a/sling.go b/sling.go index 5492ef2..cb82f2b 100644 --- a/sling.go +++ b/sling.go @@ -1,6 +1,7 @@ package sling import ( + "context" "encoding/base64" "io" "net/http" @@ -277,6 +278,15 @@ func (s *Sling) BodyForm(bodyForm interface{}) *Sling { // Returns any errors parsing the rawURL, encoding query structs, encoding // the body, or creating the http.Request. func (s *Sling) Request() (*http.Request, error) { + return s.request(context.Background()) +} + +// RequestWithContext is similar to Request but allows you to pass a context. +func (s *Sling) RequestWithContext(ctx context.Context) (*http.Request, error) { + return s.request(ctx) +} + +func (s *Sling) request(ctx context.Context) (*http.Request, error) { reqURL, err := url.Parse(s.rawURL) if err != nil { return nil, err @@ -294,7 +304,7 @@ func (s *Sling) Request() (*http.Request, error) { return nil, err } } - req, err := http.NewRequest(s.method, reqURL.String(), body) + req, err := http.NewRequestWithContext(ctx, s.method, reqURL.String(), body) if err != nil { return nil, err } @@ -364,7 +374,16 @@ func (s *Sling) ReceiveSuccess(successV interface{}) (*http.Response, error) { // the response is returned. // Receive is shorthand for calling Request and Do. func (s *Sling) Receive(successV, failureV interface{}) (*http.Response, error) { - req, err := s.Request() + return s.receive(context.Background(), successV, failureV) +} + +// ReceiveWithContext is similar to Receive but allows you to pass a context. +func (s *Sling) ReceiveWithContext(ctx context.Context, successV, failureV interface{}) (*http.Response, error) { + return s.receive(ctx, successV, failureV) +} + +func (s *Sling) receive(ctx context.Context, successV, failureV interface{}) (*http.Response, error) { + req, err := s.RequestWithContext(ctx) if err != nil { return nil, err } diff --git a/sling_test.go b/sling_test.go index 370c920..8eab528 100644 --- a/sling_test.go +++ b/sling_test.go @@ -580,6 +580,20 @@ func TestRequest_headers(t *testing.T) { } } +func TestRequest_context(t *testing.T) { + ctx, fn := context.WithCancel(context.Background()) + defer fn() // cancel context eventually to release resources + + req, err := New().RequestWithContext(ctx) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if req.Context() != ctx { + t.Errorf("request.Context() is not same as context passed during request creation") + } +} + func TestAddQueryStructs(t *testing.T) { cases := []struct { rawurl string @@ -924,6 +938,26 @@ func TestReceive_errorCreatingRequest(t *testing.T) { } } +func TestReceive_context(t *testing.T) { + client, mux, server := testServer() + defer server.Close() + mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { + }) + + ctx, fn := context.WithCancel(context.Background()) + defer fn() + + endpoint := New().Client(client).Get("http://example.com/foo") + resp, err := endpoint.New().ReceiveWithContext(ctx, nil, nil) + if err != nil { + t.Errorf("expected nil, got %v", err) + } + + if resp.Request.Context() != ctx { + t.Error("request.Context() is not same as context passed during receive operation") + } +} + func TestReuseTcpConnections(t *testing.T) { var connCount int32