@@ -18,13 +18,15 @@ package utils
18
18
19
19
import (
20
20
"encoding/json"
21
+ "fmt"
21
22
"io/ioutil"
22
23
"net/http"
23
24
"net/http/httptest"
24
25
"testing"
25
26
26
27
mock_audit "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/mocks"
27
28
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request"
29
+ "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response"
28
30
"github.com/golang/mock/gomock"
29
31
"github.com/stretchr/testify/assert"
30
32
"github.com/stretchr/testify/require"
@@ -83,6 +85,32 @@ func TestWriteJSONToResponse(t *testing.T) {
83
85
assert .Equal (t , `"Unable to get task arn from request"` , bodyString )
84
86
}
85
87
88
+ // Tests that WriteJSONResponse marshals the provided response to JSON and writes it to
89
+ // the response writer.
90
+ func TestWriteJSONResponse (t * testing.T ) {
91
+ recorder := httptest .NewRecorder ()
92
+ res := response.PortResponse {ContainerPort : 8080 , Protocol : "TCP" , HostPort : 80 , HostIp : "IP" }
93
+ WriteJSONResponse (recorder , http .StatusOK , res , RequestTypeTaskMetadata )
94
+
95
+ var actualResponse response.PortResponse
96
+ err := json .Unmarshal (recorder .Body .Bytes (), & actualResponse )
97
+ require .NoError (t , err )
98
+
99
+ assert .Equal (t , http .StatusOK , recorder .Code )
100
+ assert .Equal (t , res , actualResponse )
101
+ }
102
+
103
+ // Tests that an empty JSON response is written by WriteJSONResponse if the provided response
104
+ // is not convertible to JSON.
105
+ func TestWriteJSONResponseError (t * testing.T ) {
106
+ recorder := httptest .NewRecorder ()
107
+ res := func (k string ) string { return k }
108
+ WriteJSONResponse (recorder , http .StatusOK , res , RequestTypeTaskMetadata )
109
+
110
+ assert .Equal (t , http .StatusInternalServerError , recorder .Code )
111
+ assert .Equal (t , "{}" , recorder .Body .String ())
112
+ }
113
+
86
114
func TestValueFromRequest (t * testing.T ) {
87
115
r , _ := http .NewRequest ("GET" , "/v1/credentials?id=credid" , nil )
88
116
val , ok := ValueFromRequest (r , "id" )
@@ -108,3 +136,19 @@ func TestLimitReachHandler(t *testing.T) {
108
136
recorder := httptest .NewRecorder ()
109
137
handler .ServeHTTP (recorder , req )
110
138
}
139
+
140
+ func TestIs5XXStatus (t * testing.T ) {
141
+ yes := []int {500 , 501 , 550 , http .StatusInternalServerError , http .StatusServiceUnavailable , 580 , 599 }
142
+ for _ , y := range yes {
143
+ t .Run (fmt .Sprintf ("yes %d" , y ), func (t * testing.T ) {
144
+ assert .True (t , Is5XXStatus (y ))
145
+ })
146
+ }
147
+
148
+ no := []int {http .StatusTooEarly , http .StatusBadRequest , http .StatusTooManyRequests , 400 , 450 , 600 , 200 , 301 }
149
+ for _ , n := range no {
150
+ t .Run (fmt .Sprintf ("no %d" , n ), func (t * testing.T ) {
151
+ assert .False (t , Is5XXStatus (n ))
152
+ })
153
+ }
154
+ }
0 commit comments