|
49 | 49 | import com.google.devtools.build.remote.worker.http.HttpCacheServerInitializer;
|
50 | 50 | import com.google.devtools.common.options.OptionsParser;
|
51 | 51 | import com.google.devtools.common.options.OptionsParsingException;
|
| 52 | +import io.grpc.Context; |
| 53 | +import io.grpc.Contexts; |
| 54 | +import io.grpc.Metadata; |
52 | 55 | import io.grpc.Server;
|
| 56 | +import io.grpc.ServerCall; |
| 57 | +import io.grpc.ServerCall.Listener; |
| 58 | +import io.grpc.ServerCallHandler; |
53 | 59 | import io.grpc.ServerInterceptor;
|
54 | 60 | import io.grpc.ServerInterceptors;
|
| 61 | +import io.grpc.Status; |
55 | 62 | import io.grpc.netty.GrpcSslContexts;
|
56 | 63 | import io.grpc.netty.NettyServerBuilder;
|
57 | 64 | import io.netty.bootstrap.ServerBootstrap;
|
|
71 | 78 | import java.io.OutputStreamWriter;
|
72 | 79 | import java.io.Writer;
|
73 | 80 | import java.nio.charset.StandardCharsets;
|
| 81 | +import java.util.ArrayList; |
| 82 | +import java.util.List; |
| 83 | +import java.util.Optional; |
74 | 84 | import java.util.concurrent.ConcurrentHashMap;
|
75 | 85 | import java.util.concurrent.Executors;
|
76 | 86 | import java.util.logging.Level;
|
@@ -107,6 +117,39 @@ static FileSystem getFileSystem() {
|
107 | 117 | return new JavaIoFileSystem(hashFunction);
|
108 | 118 | }
|
109 | 119 |
|
| 120 | + /** A {@link ServerInterceptor} that rejects requests unless an authorization token is present. */ |
| 121 | + private static class AuthorizationTokenInterceptor implements ServerInterceptor { |
| 122 | + private static final Metadata.Key<String> AUTHORIZATION_HEADER_KEY = |
| 123 | + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER); |
| 124 | + |
| 125 | + private static final String BEARER_PREFIX = "Bearer "; |
| 126 | + |
| 127 | + private final String expectedToken; |
| 128 | + |
| 129 | + AuthorizationTokenInterceptor(String expectedToken) { |
| 130 | + this.expectedToken = expectedToken; |
| 131 | + } |
| 132 | + |
| 133 | + private Optional<String> getTokenFromMetadata(Metadata headers) { |
| 134 | + String val = headers.get(AUTHORIZATION_HEADER_KEY); |
| 135 | + if (val != null && val.startsWith(BEARER_PREFIX)) { |
| 136 | + return Optional.of(val.substring(BEARER_PREFIX.length())); |
| 137 | + } |
| 138 | + return Optional.empty(); |
| 139 | + } |
| 140 | + |
| 141 | + @Override |
| 142 | + public <ReqT, RespT> Listener<ReqT> interceptCall( |
| 143 | + ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) { |
| 144 | + Optional<String> actualToken = getTokenFromMetadata(headers); |
| 145 | + if (!expectedToken.equals(actualToken.get())) { |
| 146 | + call.close(Status.PERMISSION_DENIED, new Metadata()); |
| 147 | + return new ServerCall.Listener<ReqT>() {}; |
| 148 | + } |
| 149 | + return Contexts.interceptCall(Context.current(), call, headers, next); |
| 150 | + } |
| 151 | + } |
| 152 | + |
110 | 153 | public RemoteWorker(
|
111 | 154 | FileSystem fs,
|
112 | 155 | RemoteWorkerOptions workerOptions,
|
@@ -149,20 +192,25 @@ public RemoteWorker(
|
149 | 192 | }
|
150 | 193 |
|
151 | 194 | public Server startServer() throws IOException {
|
152 |
| - ServerInterceptor headersInterceptor = new TracingMetadataUtils.ServerHeadersInterceptor(); |
| 195 | + List<ServerInterceptor> interceptors = new ArrayList<>(); |
| 196 | + interceptors.add(new TracingMetadataUtils.ServerHeadersInterceptor()); |
| 197 | + if (workerOptions.expectedAuthorizationToken != null) { |
| 198 | + interceptors.add(new AuthorizationTokenInterceptor(workerOptions.expectedAuthorizationToken)); |
| 199 | + } |
| 200 | + |
153 | 201 | NettyServerBuilder b =
|
154 | 202 | NettyServerBuilder.forPort(workerOptions.listenPort)
|
155 |
| - .addService(ServerInterceptors.intercept(actionCacheServer, headersInterceptor)) |
156 |
| - .addService(ServerInterceptors.intercept(bsServer, headersInterceptor)) |
157 |
| - .addService(ServerInterceptors.intercept(casServer, headersInterceptor)) |
158 |
| - .addService(ServerInterceptors.intercept(capabilitiesServer, headersInterceptor)); |
| 203 | + .addService(ServerInterceptors.intercept(actionCacheServer, interceptors)) |
| 204 | + .addService(ServerInterceptors.intercept(bsServer, interceptors)) |
| 205 | + .addService(ServerInterceptors.intercept(casServer, interceptors)) |
| 206 | + .addService(ServerInterceptors.intercept(capabilitiesServer, interceptors)); |
159 | 207 |
|
160 | 208 | if (workerOptions.tlsCertificate != null) {
|
161 | 209 | b.sslContext(getSslContextBuilder(workerOptions).build());
|
162 | 210 | }
|
163 | 211 |
|
164 | 212 | if (execServer != null) {
|
165 |
| - b.addService(ServerInterceptors.intercept(execServer, headersInterceptor)); |
| 213 | + b.addService(ServerInterceptors.intercept(execServer, interceptors)); |
166 | 214 | } else {
|
167 | 215 | logger.atInfo().log("Execution disabled, only serving cache requests");
|
168 | 216 | }
|
|
0 commit comments