Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.trino.execution.QueryState;
import io.trino.server.DisconnectionAwareAsyncResponse;
import io.trino.server.ExternalUriInfo;
import io.trino.server.GoneException;
import io.trino.server.HttpRequestSessionContextFactory;
import io.trino.server.ServerConfig;
import io.trino.server.SessionContext;
Expand All @@ -48,21 +49,22 @@
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.BeanParam;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.Suspended;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;

import java.net.URI;
import java.util.Optional;
Expand Down Expand Up @@ -93,10 +95,6 @@
import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST;
import static jakarta.ws.rs.core.Response.Status.FORBIDDEN;
import static jakarta.ws.rs.core.Response.Status.NOT_FOUND;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
Expand Down Expand Up @@ -164,7 +162,7 @@ public Response postStatement(
@BeanParam ExternalUriInfo externalUriInfo)
{
if (isNullOrEmpty(statement)) {
throw badRequest(BAD_REQUEST, "SQL statement is empty");
throw new BadRequestException("SQL statement is empty");
}

Query query = registerQuery(statement, servletRequest, httpHeaders);
Expand All @@ -177,7 +175,7 @@ private Query registerQuery(String statement, HttpServletRequest servletRequest,
Optional<String> remoteAddress = Optional.ofNullable(servletRequest.getRemoteAddr());
Optional<Identity> identity = authenticatedIdentity(servletRequest);
if (identity.flatMap(Identity::getPrincipal).map(InternalPrincipal.class::isInstance).orElse(false)) {
throw badRequest(FORBIDDEN, "Internal communication can not be used to start a query");
throw new ForbiddenException("Internal communication can not be used to start a query");
}

MultivaluedMap<String, String> headers = httpHeaders.getRequestHeaders();
Expand Down Expand Up @@ -241,7 +239,7 @@ private Query getQuery(QueryId queryId, String slug, long token)
{
Query query = queryManager.getQuery(queryId);
if (query == null || !query.getSlug().isValid(QUEUED_QUERY, slug, token)) {
throw badRequest(NOT_FOUND, "Query not found");
throw new NotFoundException("Query not found");
}
return query;
}
Expand Down Expand Up @@ -296,15 +294,6 @@ private static QueryResults createQueryResults(
null);
}

private static WebApplicationException badRequest(Status status, String message)
{
throw new WebApplicationException(
Response.status(status)
.type(TEXT_PLAIN_TYPE)
.entity(message)
.build());
}

private static final class Query
{
private final String query;
Expand Down Expand Up @@ -387,7 +376,7 @@ public QueryResults getQueryResults(long token, ExternalUriInfo externalUriInfo)
long lastToken = this.lastToken.get();
// token should be the last token or the next token
if (token != lastToken && token != lastToken + 1) {
throw new WebApplicationException(Response.Status.GONE);
throw new GoneException("Invalid token");
}
// advance (or stay at) the token
this.lastToken.compareAndSet(lastToken, token);
Expand All @@ -402,9 +391,7 @@ public QueryResults getQueryResults(long token, ExternalUriInfo externalUriInfo)

DispatchInfo dispatchInfo = dispatchManager.getDispatchInfo(queryId)
// query should always be found, but it may have just been determined to be abandoned
.orElseThrow(() -> new WebApplicationException(Response
.status(NOT_FOUND)
.build()));
.orElseThrow(NotFoundException::new);

return createQueryResults(token + 1, externalUriInfo, dispatchInfo);
}
Expand Down
31 changes: 31 additions & 0 deletions core/trino-main/src/main/java/io/trino/server/GoneException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server;

import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.Response;

public class GoneException
extends WebApplicationException
{
public GoneException(String message)
{
super(message, Response.Status.GONE);
}

public GoneException()
{
super(Response.Status.GONE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,9 @@
import io.trino.sql.parser.SqlParser;
import io.trino.transaction.TransactionId;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;

import java.net.URLDecoder;
import java.util.Collection;
Expand Down Expand Up @@ -94,14 +91,13 @@ public SessionContext createSessionContext(
MultivaluedMap<String, String> headers,
Optional<String> remoteAddress,
Optional<Identity> authenticatedIdentity)
throws WebApplicationException
{
ProtocolHeaders protocolHeaders;
try {
protocolHeaders = detectProtocol(alternateHeaderName, headers.keySet());
}
catch (ProtocolDetectionException e) {
throw badRequest(e.getMessage());
throw new BadRequestException(e.getMessage());
}
Optional<String> catalog = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog())));
Optional<String> schema = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema())));
Expand Down Expand Up @@ -145,7 +141,7 @@ case ParsedSessionPropertyName(Optional<String> catalogName, String propertyName
// catalog session properties cannot be validated until the transaction has started
catalogSessionProperties.computeIfAbsent(catalogName.orElseThrow(), id -> new HashMap<>()).put(propertyName, propertyValue);
}
default -> throw badRequest(format("Invalid %s header", protocolHeaders.requestSession()));
default -> throw new BadRequestException(format("Invalid %s header", protocolHeaders.requestSession()));
}
}
requireNonNull(catalogSessionProperties, "catalogSessionProperties is null");
Expand Down Expand Up @@ -196,7 +192,7 @@ public Identity extractAuthorizedIdentity(Optional<Identity> optionalAuthenticat
protocolHeaders = detectProtocol(alternateHeaderName, headers.keySet());
}
catch (ProtocolDetectionException e) {
throw badRequest(e.getMessage());
throw new BadRequestException(e.getMessage());
}

Identity identity = buildSessionIdentity(optionalAuthenticatedIdentity, protocolHeaders, headers);
Expand Down Expand Up @@ -320,7 +316,7 @@ private static SelectedRole toSelectedRole(ProtocolHeaders protocolHeaders, Stri
role = SelectedRole.valueOf(value);
}
catch (IllegalArgumentException e) {
throw badRequest(format("Invalid %s header", protocolHeaders.requestRole()));
throw new BadRequestException(format("Invalid %s header", protocolHeaders.requestRole()));
}
return role;
}
Expand All @@ -340,7 +336,7 @@ private static Map<String, String> parseProperty(MultivaluedMap<String, String>
properties.put(nameValue.get(0), urlDecode(nameValue.get(1)));
}
catch (IllegalArgumentException e) {
throw badRequest(format("Invalid %s header: %s", headerName, e));
throw new BadRequestException(format("Invalid %s header: %s", headerName, e));
}
}
return properties;
Expand Down Expand Up @@ -374,10 +370,10 @@ private static ResourceEstimates parseResourceEstimate(ProtocolHeaders protocolH
builder.setPeakMemory(DataSize.valueOf(value));
return;
}
throw badRequest(format("Unsupported resource name %s", name));
throw new BadRequestException(format("Unsupported resource name %s", name));
}
catch (IllegalArgumentException e) {
throw badRequest(format("Unsupported format for resource estimate '%s': %s", value, e));
throw new BadRequestException(format("Unsupported format for resource estimate '%s': %s", value, e));
}
});

Expand All @@ -397,7 +393,7 @@ private static ParsedSessionPropertyName parseSessionPropertyName(String value)
private static void assertRequest(boolean expression, String format, Object... args)
{
if (!expression) {
throw badRequest(format(format, args));
throw new BadRequestException(format(format, args));
}
}

Expand All @@ -410,7 +406,7 @@ private Map<String, String> parsePreparedStatementsHeaders(ProtocolHeaders proto
statementName = urlDecode(key);
}
catch (IllegalArgumentException e) {
throw badRequest(format("Invalid %s header: %s", protocolHeaders.requestPreparedStatement(), e.getMessage()));
throw new BadRequestException(format("Invalid %s header: %s", protocolHeaders.requestPreparedStatement(), e.getMessage()));
}
String sqlString = preparedStatementEncoder.decodePreparedStatementFromHeader(value);

Expand All @@ -420,7 +416,7 @@ private Map<String, String> parsePreparedStatementsHeaders(ProtocolHeaders proto
sqlParser.createStatement(sqlString);
}
catch (ParsingException e) {
throw badRequest(format("Invalid %s header: %s", protocolHeaders.requestPreparedStatement(), e.getMessage()));
throw new BadRequestException(format("Invalid %s header: %s", protocolHeaders.requestPreparedStatement(), e.getMessage()));
}

preparedStatements.put(statementName, sqlString);
Expand All @@ -439,19 +435,10 @@ private static Optional<TransactionId> parseTransactionId(String transactionId)
return Optional.of(TransactionId.valueOf(transactionId));
}
catch (Exception e) {
throw badRequest(e.getMessage());
throw new BadRequestException(e.getMessage());
}
}

private static WebApplicationException badRequest(String message)
{
throw new WebApplicationException(message, Response
.status(Status.BAD_REQUEST)
.type(MediaType.TEXT_PLAIN)
.entity(message)
.build());
}

private static String trimEmptyToNull(String value)
{
return emptyToNull(nullToEmpty(value).trim());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public Response getQueryInfo(@PathParam("queryId") QueryId queryId, @QueryParam(
Optional<QueryInfo> queryInfo = dispatchManager.getFullQueryInfo(queryId)
.map(info -> pruned ? pruneQueryInfo(info, info.getVersion()) : info);
if (queryInfo.isEmpty()) {
return Response.status(Status.GONE).build();
throw new GoneException();
}
try {
checkCanViewQueryOwnedBy(sessionContextFactory.extractAuthorizedIdentity(servletRequest, httpHeaders), queryInfo.get().getSession().toIdentity(), accessControl);
Expand Down Expand Up @@ -165,7 +165,7 @@ private Response failQuery(QueryId queryId, TrinoException queryException, HttpS
throw new ForbiddenException();
}
catch (NoSuchElementException e) {
return Response.status(Status.GONE).build();
throw new GoneException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
Expand All @@ -46,7 +46,6 @@
import static io.trino.server.QueryStateInfo.createQueryStateInfo;
import static io.trino.server.QueryStateInfo.createQueuedQueryStateInfo;
import static io.trino.server.security.ResourceSecurity.AccessType.AUTHENTICATED_USER;
import static jakarta.ws.rs.core.Response.Status.NOT_FOUND;
import static java.util.Objects.requireNonNull;

@Path("/v1/queryState")
Expand Down Expand Up @@ -108,7 +107,6 @@ private QueryStateInfo getQueryStateInfo(BasicQueryInfo queryInfo)
@Path("{queryId}")
@Produces(MediaType.APPLICATION_JSON)
public QueryStateInfo getQueryStateInfo(@PathParam("queryId") String queryId, @Context HttpServletRequest servletRequest, @Context HttpHeaders httpHeaders)
throws WebApplicationException
{
try {
BasicQueryInfo queryInfo = dispatchManager.getQueryInfo(new QueryId(queryId));
Expand All @@ -119,7 +117,7 @@ public QueryStateInfo getQueryStateInfo(@PathParam("queryId") String queryId, @C
throw new ForbiddenException();
}
catch (NoSuchElementException e) {
throw new WebApplicationException(NOT_FOUND);
throw new NotFoundException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import io.trino.spi.resourcegroups.ResourceGroupId;
import jakarta.ws.rs.Encoded;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;

import java.net.URLDecoder;
Expand All @@ -31,7 +31,6 @@
import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.server.security.ResourceSecurity.AccessType.MANAGEMENT_READ;
import static jakarta.ws.rs.core.Response.Status.NOT_FOUND;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -59,9 +58,9 @@ public ResourceGroupInfo getQueryStateInfos(@PathParam("resourceGroupId") String
Arrays.stream(resourceGroupIdString.split("/"))
.map(ResourceGroupStateInfoResource::urlDecode)
.collect(toImmutableList())))
.orElseThrow(() -> new WebApplicationException(NOT_FOUND));
.orElseThrow(NotFoundException::new);
}
throw new WebApplicationException(NOT_FOUND);
throw new NotFoundException();
}

private static String urlDecode(String value)
Expand Down
Loading