Skip to content

Commit

Permalink
[remote/downloader] Migrate Downloader to take Credentials
Browse files Browse the repository at this point in the history
Progress on bazelbuild#15856
  • Loading branch information
Yannic committed Oct 28, 2022
1 parent e5a7389 commit 4c861fc
Show file tree
Hide file tree
Showing 16 changed files with 123 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ filegroup(
name = "srcs",
srcs = glob(["**"]) + [
"//src/main/java/com/google/devtools/build/lib/authandtls/credentialhelper:srcs",
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials:srcs",
],
visibility = ["//src:__subpackages__"],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("@rules_java//java:defs.bzl", "java_library")

package(default_visibility = ["//src:__subpackages__"])

filegroup(
name = "srcs",
srcs = glob(["**"]),
visibility = ["//src:__subpackages__"],
)

java_library(
name = "staticcredentials",
srcs = glob(["*.java"]),
deps = [
"//third_party:auth",
"//third_party:guava",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.google.devtools.build.lib.authandtls.staticcredentials;

import com.google.auth.Credentials;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;

/** Implementation of {@link Credentials} which provides a static set of credentials. */
public final class StaticCredentials extends Credentials {
private final ImmutableMap<URI, Map<String, List<String>>> credentials;

public StaticCredentials(Map<URI, Map<String, List<String>>> credentials) {
Preconditions.checkNotNull(credentials);

this.credentials = ImmutableMap.copyOf(credentials);
}

public Map<URI, Map<String, List<String>>> getMapForMigration() {
return credentials;
}

@Override
public String getAuthenticationType() {
return "static";
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
Preconditions.checkNotNull(uri);

return credentials.getOrDefault(uri, ImmutableMap.of());
}

@Override
public boolean hasRequestMetadata() {
return true;
}

@Override
public boolean hasRequestMetadataOnly() {
return true;
}

@Override
public void refresh() {
// Can't refresh static credentials.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ java_library(
deps = [
"//src/main/java/com/google/devtools/build/lib/analysis:blaze_version_info",
"//src/main/java/com/google/devtools/build/lib/authandtls",
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/cache",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/cache:events",
"//src/main/java/com/google/devtools/build/lib/buildeventstream",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
Expand Down Expand Up @@ -47,7 +48,7 @@ public void setDelegate(@Nullable Downloader delegate) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -60,6 +61,6 @@ public void download(
downloader = delegate;
}
downloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.authandtls.staticcredentials.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCacheHitEvent;
Expand Down Expand Up @@ -256,7 +257,7 @@ public Path download(
try {
downloader.download(
rewrittenUrls,
rewrittenAuthHeaders,
new StaticCredentials(rewrittenAuthHeaders),
checksum,
canonicalId,
destination,
Expand Down Expand Up @@ -337,7 +338,7 @@ public byte[] downloadAndReadOneUrl(
for (int attempt = 0; attempt <= retries; ++attempt) {
try {
return httpDownloader.downloadAndReadOneUrl(
rewrittenUrls.get(0), authHeaders, eventHandler, clientEnv);
rewrittenUrls.get(0), new StaticCredentials(authHeaders), eventHandler, clientEnv);
} catch (ContentLengthMismatchException e) {
if (attempt == retries) {
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
Expand All @@ -33,7 +34,7 @@ public interface Downloader {
* caller is responsible for cleaning up outputs of failed downloads.
*
* @param urls list of mirror URLs with identical content
* @param authHeaders map of authentication headers per URL
* @param credentials credentials to use when connecting to URLs
* @param checksum valid checksum which is checked, or absent to disable
* @param output path to the destination file to write
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
Expand All @@ -42,7 +43,7 @@ public interface Downloader {
*/
void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path output,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
import com.google.devtools.build.lib.authandtls.staticcredentials.StaticCredentials;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.EventHandler;
Expand Down Expand Up @@ -74,7 +76,7 @@ final class HttpConnectorMultiplexer {
}

public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOException {
return connect(url, checksum, ImmutableMap.of(), Optional.absent());
return connect(url, checksum, new StaticCredentials(ImmutableMap.of()), Optional.absent());
}

/**
Expand All @@ -87,7 +89,7 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
*
* @param url the URL to conenct to. can be: file, http, or https
* @param checksum checksum lazily checked on entire payload, or empty to disable
* @param authHeaders the authentication headers
* @param credentials the credentials
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
* @return an {@link InputStream} of response payload
* @throws IOException if all mirrors are down and contains suppressed exception of each attempt
Expand All @@ -97,15 +99,15 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
public HttpStream connect(
URL url,
Optional<Checksum> checksum,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<String> type)
throws IOException {
Preconditions.checkArgument(HttpUtils.isUrlSupportedByDownloader(url));
if (Thread.interrupted()) {
throw new InterruptedIOException();
}
Function<URL, ImmutableMap<String, List<String>>> headerFunction =
getHeaderFunction(REQUEST_HEADERS, authHeaders);
getHeaderFunction(REQUEST_HEADERS, credentials);
URLConnection connection = connector.connect(url, headerFunction);
return httpStreamFactory.create(
connection,
Expand All @@ -128,20 +130,20 @@ public HttpStream connect(
@VisibleForTesting
static Function<URL, ImmutableMap<String, List<String>>> getHeaderFunction(
Map<String, List<String>> baseHeaders,
Map<URI, Map<String, List<String>>> additionalHeaders) {
Credentials credentials) {
Preconditions.checkNotNull(baseHeaders);
Preconditions.checkNotNull(credentials);

return url -> {
ImmutableMap<String, List<String>> headers = ImmutableMap.copyOf(baseHeaders);
Map<String, List<String>> headers = new HashMap<>(baseHeaders);
try {
if (additionalHeaders.containsKey(url.toURI())) {
Map<String, List<String>> newHeaders = new HashMap<>(headers);
newHeaders.putAll(additionalHeaders.get(url.toURI()));
headers = ImmutableMap.copyOf(newHeaders);
}
} catch (URISyntaxException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), still try
// to do the connection, not adding authentication information as we cannot look it up.
headers.putAll(credentials.getRequestMetadata(url.toURI()));
} catch (URISyntaxException | IOException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), or fetching
// credentials fails for any other reason, still try to do the connection, not adding
// authentication information as we cannot look it up.
}
return headers;
return ImmutableMap.copyOf(headers);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -63,7 +64,7 @@ public void setTimeoutScaling(float timeoutScaling) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -82,8 +83,8 @@ public void download(
for (URL url : urls) {
SEMAPHORE.acquire();

try (HttpStream payload = multiplexer.connect(url, checksum, authHeaders, type);
OutputStream out = destination.getOutputStream()) {
try (HttpStream payload = multiplexer.connect(url, checksum, credentials, type);
OutputStream out = destination.getOutputStream()) {
try {
ByteStreams.copy(payload, out);
} catch (SocketTimeoutException e) {
Expand Down Expand Up @@ -132,7 +133,7 @@ public void download(
/** Downloads the contents of one URL and reads it into a byte array. */
public byte[] downloadAndReadOneUrl(
URL url,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
ExtendedEventHandler eventHandler,
Map<String, String> clientEnv)
throws IOException, InterruptedException {
Expand All @@ -141,7 +142,7 @@ public byte[] downloadAndReadOneUrl(
ByteArrayOutputStream out = new ByteArrayOutputStream();
SEMAPHORE.acquire();
try (HttpStream payload =
multiplexer.connect(url, Optional.absent(), authHeaders, Optional.absent())) {
multiplexer.connect(url, Optional.absent(), credentials, Optional.absent())) {
ByteStreams.copy(payload, out);
} catch (SocketTimeoutException e) {
// SocketTimeoutExceptions are InterruptedIOExceptions; however they do not signify
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ java_library(
name = "downloader",
srcs = glob(["*.java"]),
deps = [
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader",
"//src/main/java/com/google/devtools/build/lib/events",
"//src/main/java/com/google/devtools/build/lib/remote:ReferenceCountedChannel",
Expand All @@ -22,6 +23,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/remote/options",
"//src/main/java/com/google/devtools/build/lib/remote/util",
"//src/main/java/com/google/devtools/build/lib/vfs",
"//third_party:auth",
"//third_party:guava",
"//third_party:jsr305",
"//third_party/grpc-java:grpc-jar",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import build.bazel.remote.asset.v1.Qualifier;
import build.bazel.remote.execution.v2.Digest;
import build.bazel.remote.execution.v2.RequestMetadata;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum;
Expand All @@ -41,12 +42,10 @@
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -112,7 +111,7 @@ public void close() {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
com.google.common.base.Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down Expand Up @@ -156,7 +155,7 @@ public void download(
eventHandler.handle(
Event.warn("Remote Cache: " + Utils.grpcAwareErrorMessage(e, verboseFailures)));
fallbackDownloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ java_library(
srcs = glob(["*.java"]),
deps = [
"//src/main/java/com/google/devtools/build/lib/authandtls",
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/cache",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader",
"//src/main/java/com/google/devtools/build/lib/events",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.authandtls.staticcredentials.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector;
import com.google.devtools.build.lib.events.EventHandler;
Expand Down Expand Up @@ -163,7 +164,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap.of("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA==")));

Function<URL, ImmutableMap<String, List<String>>> headerFunction =
HttpConnectorMultiplexer.getHeaderFunction(baseHeaders, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
baseHeaders, new StaticCredentials(additionalHeaders));

// Unrelated URL
assertThat(headerFunction.apply(new URL("http://example.org/some/path/file.txt")))
Expand Down Expand Up @@ -215,7 +217,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap<String, List<String>> annonAuth =
ImmutableMap.of("Authentication", ImmutableList.of("YW5vbnltb3VzOmZvb0BleGFtcGxlLm9yZw=="));
Function<URL, ImmutableMap<String, List<String>>> combinedHeaders =
HttpConnectorMultiplexer.getHeaderFunction(annonAuth, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
annonAuth, new StaticCredentials(additionalHeaders));
assertThat(combinedHeaders.apply(new URL("http://hosting.example.com/user/foo/file.txt")))
.containsExactly("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA=="));
assertThat(combinedHeaders.apply(new URL("http://unreleated.example.org/user/foo/file.txt")))
Expand Down
Loading

0 comments on commit 4c861fc

Please sign in to comment.