Skip to content

Commit

Permalink
Fix HttpFields.Mutable.Wrapper.computeField() (#11688)
Browse files Browse the repository at this point in the history
* #11687 make HttpFields.Mutable.Wrapper.computeField() call onRemoveField() and remove the field when null is returned by computeFn
* #11687 replace IAE with NPE
* #11687 replace collect(Collectors.toList()) with toList()

---------

Signed-off-by: Ludovic Orban <[email protected]>
  • Loading branch information
lorban authored Apr 23, 2024
1 parent 4cc9384 commit bb633b8
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -896,8 +896,8 @@ interface Mutable extends HttpFields
*/
default Mutable add(String name, String value)
{
if (value == null)
throw new IllegalArgumentException("null value");
Objects.requireNonNull(name);
Objects.requireNonNull(value);
return add(new HttpField(name, value));
}

Expand All @@ -912,6 +912,7 @@ default Mutable add(String name, String value)
*/
default Mutable add(String name, long value)
{
Objects.requireNonNull(name);
return add(new HttpField.LongValueHttpField(name, value));
}

Expand All @@ -926,6 +927,8 @@ default Mutable add(String name, long value)
*/
default Mutable add(HttpHeader header, HttpHeaderValue value)
{
Objects.requireNonNull(header);
Objects.requireNonNull(value);
return add(header, value.toString());
}

Expand All @@ -940,8 +943,8 @@ default Mutable add(HttpHeader header, HttpHeaderValue value)
*/
default Mutable add(HttpHeader header, String value)
{
if (value == null)
throw new IllegalArgumentException("null value");
Objects.requireNonNull(header);
Objects.requireNonNull(value);
return add(new HttpField(header, value));
}

Expand All @@ -956,6 +959,7 @@ default Mutable add(HttpHeader header, String value)
*/
default Mutable add(HttpHeader header, long value)
{
Objects.requireNonNull(header);
return add(new HttpField.LongValueHttpField(header, value));
}

Expand All @@ -967,6 +971,7 @@ default Mutable add(HttpHeader header, long value)
*/
default Mutable add(HttpField field)
{
Objects.requireNonNull(field);
ListIterator<HttpField> i = listIterator(size());
i.add(field);
return this;
Expand Down Expand Up @@ -998,8 +1003,7 @@ default Mutable add(HttpFields fields)
default Mutable add(String name, List<String> list)
{
Objects.requireNonNull(name);
if (list == null)
throw new IllegalArgumentException("null list");
Objects.requireNonNull(list);
if (list.isEmpty())
return this;
if (list.size() == 1)
Expand All @@ -1022,6 +1026,7 @@ default Mutable add(String name, List<String> list)
*/
default Mutable addCSV(HttpHeader header, String... values)
{
Objects.requireNonNull(header);
QuotedCSV existing = null;
for (HttpField f : this)
{
Expand Down Expand Up @@ -1049,6 +1054,7 @@ default Mutable addCSV(HttpHeader header, String... values)
*/
default Mutable addCSV(String name, String... values)
{
Objects.requireNonNull(name);
QuotedCSV existing = null;
for (HttpField f : this)
{
Expand Down Expand Up @@ -1076,6 +1082,7 @@ default Mutable addCSV(String name, String... values)
*/
default Mutable addDateField(String name, long date)
{
Objects.requireNonNull(name);
add(name, DateGenerator.formatDate(date));
return this;
}
Expand Down Expand Up @@ -1105,6 +1112,7 @@ default Mutable clear()
*/
default void ensureField(HttpField field)
{
Objects.requireNonNull(field);
HttpHeader header = field.getHeader();
// Is the field value multi valued?
if (field.getValue().indexOf(',') < 0)
Expand Down Expand Up @@ -1136,6 +1144,7 @@ default void ensureField(HttpField field)
*/
default Mutable put(HttpField field)
{
Objects.requireNonNull(field);
boolean put = false;
ListIterator<HttpField> i = listIterator();
while (i.hasNext())
Expand Down Expand Up @@ -1170,6 +1179,7 @@ default Mutable put(HttpField field)
*/
default Mutable put(String name, String value)
{
Objects.requireNonNull(name);
if (value == null)
return remove(name);
return put(new HttpField(name, value));
Expand All @@ -1186,6 +1196,7 @@ default Mutable put(String name, String value)
*/
default Mutable put(HttpHeader header, HttpHeaderValue value)
{
Objects.requireNonNull(header);
if (value == null)
return remove(header);
return put(new HttpField(header, value.toString()));
Expand All @@ -1202,6 +1213,7 @@ default Mutable put(HttpHeader header, HttpHeaderValue value)
*/
default Mutable put(HttpHeader header, String value)
{
Objects.requireNonNull(header);
if (value == null)
return remove(header);
return put(new HttpField(header, value));
Expand Down Expand Up @@ -1241,6 +1253,7 @@ default Mutable put(String name, List<String> list)
*/
default Mutable putDate(HttpHeader name, long date)
{
Objects.requireNonNull(name);
return put(name, DateGenerator.formatDate(date));
}

Expand All @@ -1256,6 +1269,7 @@ default Mutable putDate(HttpHeader name, long date)
*/
default Mutable putDate(String name, long date)
{
Objects.requireNonNull(name);
return put(name, DateGenerator.formatDate(date));
}

Expand All @@ -1269,6 +1283,7 @@ default Mutable putDate(String name, long date)
*/
default Mutable put(HttpHeader header, long value)
{
Objects.requireNonNull(header);
if (value == 0 && header == HttpHeader.CONTENT_LENGTH)
return put(HttpFields.CONTENT_LENGTH_0);
return put(new HttpField.LongValueHttpField(header, value));
Expand All @@ -1284,6 +1299,7 @@ default Mutable put(HttpHeader header, long value)
*/
default Mutable put(String name, long value)
{
Objects.requireNonNull(name);
if (value == 0 && HttpHeader.CONTENT_LENGTH.is(name))
return put(HttpFields.CONTENT_LENGTH_0);
return put(new HttpField.LongValueHttpField(name, value));
Expand Down Expand Up @@ -1367,7 +1383,9 @@ default Mutable put(String name, long value)
*/
default Mutable computeField(HttpHeader header, BiFunction<HttpHeader, List<HttpField>, HttpField> computeFn)
{
return put(computeFn.apply(header, stream().filter(f -> f.getHeader() == header).collect(Collectors.toList())));
Objects.requireNonNull(header);
HttpField result = computeFn.apply(header, stream().filter(f -> f.getHeader() == header).toList());
return result != null ? put(result) : remove(header);
}

/**
Expand All @@ -1380,7 +1398,9 @@ default Mutable computeField(HttpHeader header, BiFunction<HttpHeader, List<Http
*/
default Mutable computeField(String name, BiFunction<String, List<HttpField>, HttpField> computeFn)
{
return put(computeFn.apply(name, stream().filter(f -> f.is(name)).collect(Collectors.toList())));
Objects.requireNonNull(name);
HttpField result = computeFn.apply(name, stream().filter(f -> f.is(name)).toList());
return result != null ? put(result) : remove(name);
}

/**
Expand All @@ -1391,6 +1411,7 @@ default Mutable computeField(String name, BiFunction<String, List<HttpField>, Ht
*/
default Mutable remove(HttpHeader header)
{
Objects.requireNonNull(header);
Iterator<HttpField> i = iterator();
while (i.hasNext())
{
Expand Down Expand Up @@ -1428,6 +1449,7 @@ default Mutable remove(EnumSet<HttpHeader> headers)
*/
default Mutable remove(String name)
{
Objects.requireNonNull(name);
for (ListIterator<HttpField> i = listIterator(); i.hasNext(); )
{
HttpField f = i.next();
Expand Down Expand Up @@ -1656,6 +1678,7 @@ public Mutable add(HttpField field)
@Override
public Mutable put(HttpField field)
{
Objects.requireNonNull(field);
// rewrite put to ensure that removes are called before replace
int put = -1;
ListIterator<HttpField> i = _fields.listIterator();
Expand All @@ -1675,7 +1698,7 @@ else if (onRemoveField(f))
{
field = onAddField(field);
if (field != null)
add(field);
_fields.add(field);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ public void testPutNullName()
public void testAddNullValueList()
{
HttpFields.Mutable fields = HttpFields.build();
assertThrows(IllegalArgumentException.class, () -> fields.add("name", (List<String>)null));
assertThrows(NullPointerException.class, () -> fields.add("name", (List<String>)null));
assertThat(fields.size(), is(0));
List<String> list = new ArrayList<>();
fields.add("name", list);
Expand Down Expand Up @@ -1374,4 +1374,56 @@ public void testEnsureStringMultiValue()
fields.ensureField(new HttpField("Test", "three, four"));
assertThat(fields.stream().map(HttpField::toString).collect(Collectors.toList()), contains("Test: one, two, three, four"));
}

@Test
public void testWrapperComputeFieldCallingOnField()
{
var wrapper = new HttpFields.Mutable.Wrapper(HttpFields.build())
{
final List<String> actions = new ArrayList<>();

@Override
public HttpField onAddField(HttpField field)
{
actions.add("onAddField");
return super.onAddField(field);
}

@Override
public boolean onRemoveField(HttpField field)
{
actions.add("onRemoveField");
return super.onRemoveField(field);
}

@Override
public HttpField onReplaceField(HttpField oldField, HttpField newField)
{
actions.add("onReplaceField");
return super.onReplaceField(oldField, newField);
}
};

wrapper.computeField("non-existent", (name, httpFields) -> null);
assertThat(wrapper.size(), is(0));
assertThat(wrapper.actions, is(List.of()));

wrapper.computeField("non-existent", (name, httpFields) -> new HttpField("non-existent", "a"));
wrapper.computeField("non-existent", (name, httpFields) -> new HttpField("non-existent", "b"));
wrapper.computeField("non-existent", (name, httpFields) -> null);
assertThat(wrapper.size(), is(0));
assertThat(wrapper.actions, is(List.of("onAddField", "onReplaceField", "onRemoveField")));
wrapper.actions.clear();

wrapper.computeField(HttpHeader.VARY, (name, httpFields) -> null);
assertThat(wrapper.size(), is(0));
assertThat(wrapper.actions, is(List.of()));

wrapper.computeField(HttpHeader.VARY, (name, httpFields) -> new HttpField(HttpHeader.VARY, "a"));
wrapper.computeField(HttpHeader.VARY, (name, httpFields) -> new HttpField(HttpHeader.VARY, "b"));
wrapper.computeField(HttpHeader.VARY, (name, httpFields) -> null);
assertThat(wrapper.size(), is(0));
assertThat(wrapper.actions, is(List.of("onAddField", "onReplaceField", "onRemoveField")));
wrapper.actions.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public void testEmptyHeaders() throws Exception

HttpFields.Mutable fields = HttpFields.build();
fields.add("Host", "something");
assertThrows(IllegalArgumentException.class, () -> fields.add("Null", (String)null));
assertThrows(IllegalArgumentException.class, () -> fields.add("Null", (List<String>)null));
assertThrows(NullPointerException.class, () -> fields.add("Null", (String)null));
assertThrows(NullPointerException.class, () -> fields.add("Null", (List<String>)null));
fields.add("Empty", "");
RequestInfo info = new RequestInfo("GET", "/index.html", fields);
assertFalse(gen.isChunking());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import org.eclipse.jetty.server.LocalConnector;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.ResponseUtils;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.server.handler.ContextHandlerCollection;
Expand All @@ -103,6 +104,7 @@
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
Expand Down Expand Up @@ -692,6 +694,31 @@ public void destroyServer() throws Exception
_server.join();
}

@Test
public void testEnsureNotPersistent() throws Exception
{
ServletContextHandler root = new ServletContextHandler("/", ServletContextHandler.SESSIONS);
root.setContextPath("/");
root.addServlet(new ServletHolder(new HttpServlet()
{
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp)
{
Request request = ((ServletApiRequest)req).getRequest();
Response response = ((ServletApiResponse)resp).getResponse();

ResponseUtils.ensureNotPersistent(request, response);
}
}), "/ensureNotPersistent");
_server.setHandler(root);

_server.start();

String rawResponse = _connector.getResponse("GET /ensureNotPersistent HTTP/1.0\r\n\r\n");
HttpTester.Response response = HttpTester.parseResponse(rawResponse);
assertThat(response.getStatus(), is(200));
}

@Test
public void testInitParams() throws Exception
{
Expand Down

0 comments on commit bb633b8

Please sign in to comment.