Skip to content

Commit

Permalink
Let .bzl files record their usages of repo mapping
Browse files Browse the repository at this point in the history
In the same vein as #20742, we record all repo mapping entries used during the load of a .bzl file too, including any of its `load()` statements and calls to `Label()` that contain an apparent repo name.

See #20721 (comment) for a more detailed explanation for this change, and the test cases in this commit for more potential triggers.

Fixes #20721
  • Loading branch information
Wyverald committed Jan 10, 2024
1 parent 2d5af9c commit b59ba48
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,21 @@ private static boolean didRepoMappingsChange(
.map(RepositoryMappingValue::key)
.collect(toImmutableSet()));
if (env.valuesMissing()) {
// This shouldn't really happen, since the RepositoryMappingValues of any recorded repos
// should have already been requested by the time we load the .bzl for the extension. And this
// method is only called if the transitive .bzl digest hasn't changed.
// However, we pretend it could happen anyway because we're good citizens.
// This likely means that one of the 'source repos' in the recorded mapping entries is no
// longer there.
throw new NeedsSkyframeRestartException();
}
for (Table.Cell<RepositoryName, String, RepositoryName> cell : recordedRepoMappings.cellSet()) {
RepositoryMappingValue repoMappingValue =
(RepositoryMappingValue) result.get(RepositoryMappingValue.key(cell.getRowKey()));
if (repoMappingValue == null) {
// Again, this shouldn't happen. But anyway.
throw new NeedsSkyframeRestartException();
}
if (!cell.getValue()
// Very importantly, `repoMappingValue` here could be for a repo that's no longer existent in
// the dep graph. See
// bazel_lockfile_test.testExtensionRepoMappingChange_sourceRepoNoLongerExistent for a test
// case.
if (repoMappingValue.equals(RepositoryMappingValue.NOT_FOUND_VALUE) || !cell.getValue()
.equals(repoMappingValue.getRepositoryMapping().get(cell.getColumnKey()))) {
// Wee woo wee woo -- diff detected!
return true;
Expand Down Expand Up @@ -817,20 +818,19 @@ private RegularRunnableExtension loadRegularRunnableExtension(
if (envVars == null) {
return null;
}
return new RegularRunnableExtension(
BazelModuleContext.of(bzlLoadValue.getModule()), extension, envVars);
return new RegularRunnableExtension(bzlLoadValue, extension, envVars);
}

private final class RegularRunnableExtension implements RunnableExtension {
private final BazelModuleContext bazelModuleContext;
private final BzlLoadValue bzlLoadValue;
private final ModuleExtension extension;
private final ImmutableMap<String, String> envVars;

RegularRunnableExtension(
BazelModuleContext bazelModuleContext,
BzlLoadValue bzlLoadValue,
ModuleExtension extension,
ImmutableMap<String, String> envVars) {
this.bazelModuleContext = bazelModuleContext;
this.bzlLoadValue = bzlLoadValue;
this.extension = extension;
this.envVars = envVars;
}
Expand All @@ -849,7 +849,7 @@ public ImmutableMap<String, String> getEnvVars() {

@Override
public byte[] getBzlTransitiveDigest() {
return bazelModuleContext.bzlTransitiveDigest();
return BazelModuleContext.of(bzlLoadValue.getModule()).bzlTransitiveDigest();
}

@Nullable
Expand All @@ -864,12 +864,13 @@ public RunModuleExtensionResult run(
new ModuleExtensionEvalStarlarkThreadContext(
usagesValue.getExtensionUniqueName() + "~",
extensionId.getBzlFileLabel().getPackageIdentifier(),
bazelModuleContext.repoMapping(),
BazelModuleContext.of(bzlLoadValue.getModule()).repoMapping(),
directories,
env.getListener());
ModuleExtensionContext moduleContext;
Optional<ModuleExtensionMetadata> moduleExtensionMetadata;
var repoMappingRecorder = new Label.RepoMappingRecorder();
repoMappingRecorder.mergeEntries(bzlLoadValue.getRecordedRepoMappings());
try (Mutability mu =
Mutability.create("module extension", usagesValue.getExtensionUniqueName())) {
StarlarkThread thread = new StarlarkThread(mu, starlarkSemantics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ public static final class RepoMappingRecorder {
/** {@code <fromRepo, apparentRepoName, canonicalRepoName> } */
Table<RepositoryName, String, RepositoryName> entries = HashBasedTable.create();

public void mergeEntries(Table<RepositoryName, String, RepositoryName> entries) {
this.entries.putAll(entries);
}

public ImmutableTable<RepositoryName, String, RepositoryName> recordedEntries() {
return ImmutableTable.<RepositoryName, String, RepositoryName>builder()
.orderRowsBy(Comparator.comparing(RepositoryName::getName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.devtools.build.lib.cmdline.BazelModuleContext;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.Label.PackageContext;
import com.google.devtools.build.lib.cmdline.Label.RepoMappingRecorder;
import com.google.devtools.build.lib.cmdline.LabelConstants;
import com.google.devtools.build.lib.cmdline.LabelSyntaxException;
import com.google.devtools.build.lib.cmdline.PackageIdentifier;
Expand Down Expand Up @@ -762,6 +763,7 @@ private BzlLoadValue computeInternalWithCompiledBzl(
if (repoMapping == null) {
return null;
}
Label.RepoMappingRecorder repoMappingRecorder = new Label.RepoMappingRecorder();
ImmutableList<Pair<String, Location>> programLoads = getLoadsFromProgram(prog);
ImmutableList<Label> loadLabels =
getLoadLabels(
Expand All @@ -770,7 +772,8 @@ private BzlLoadValue computeInternalWithCompiledBzl(
pkg,
repoMapping,
key.isSclDialect(),
isSclFlagEnabled);
isSclFlagEnabled,
repoMappingRecorder);
if (loadLabels == null) {
throw new BzlLoadFailedException(
String.format(
Expand Down Expand Up @@ -821,6 +824,7 @@ private BzlLoadValue computeInternalWithCompiledBzl(
BzlLoadValue v = loadValues.get(i++);
loadMap.put(load.first, v.getModule()); // dups ok
fp.addBytes(v.getTransitiveDigest());
repoMappingRecorder.mergeEntries(v.getRecordedRepoMappings());
}

// Retrieve predeclared symbols and complete the digest computation.
Expand Down Expand Up @@ -862,7 +866,14 @@ private BzlLoadValue computeInternalWithCompiledBzl(
// caching BzlLoadValues. Note that executing the code mutates the Module and
// BzlInitThreadContext.
executeBzlFile(
prog, label, module, loadMap, context, builtins.starlarkSemantics, env.getListener());
prog,
label,
module,
loadMap,
context,
builtins.starlarkSemantics,
env.getListener(),
repoMappingRecorder);

BzlVisibility bzlVisibility = context.getBzlVisibility();
if (bzlVisibility == null) {
Expand All @@ -871,7 +882,8 @@ private BzlLoadValue computeInternalWithCompiledBzl(
// We save load visibility in the BzlLoadValue rather than the BazelModuleContext because
// visibility doesn't need to be introspected by any Starlark builtin methods, and because the
// alternative would mean mutating or overwriting the BazelModuleContext after evaluation.
return new BzlLoadValue(module, transitiveDigest, bzlVisibility);
return new BzlLoadValue(
module, transitiveDigest, bzlVisibility, repoMappingRecorder.recordedEntries());
}

@Nullable
Expand Down Expand Up @@ -1058,7 +1070,8 @@ private static ImmutableList<Label> getLoadLabels(
PackageIdentifier base,
RepositoryMapping repoMapping,
boolean withinSclDialect,
boolean isSclFlagEnabled) {
boolean isSclFlagEnabled,
@Nullable Label.RepoMappingRecorder repoMappingRecorder) {
boolean ok = true;

ImmutableList.Builder<Label> loadLabels = ImmutableList.builderWithExpectedSize(loads.size());
Expand All @@ -1071,7 +1084,8 @@ private static ImmutableList<Label> getLoadLabels(
throw new LabelSyntaxException("in .scl files, load labels must begin with \"//\"");
}
Label label =
Label.parseWithPackageContext(unparsedLabel, PackageContext.of(base, repoMapping));
Label.parseWithPackageContext(
unparsedLabel, PackageContext.of(base, repoMapping), repoMappingRecorder);
checkValidLoadLabel(
label,
/* fromBuiltinsRepo= */ StarlarkBuiltinsValue.isBuiltinsRepo(base.getRepository()),
Expand Down Expand Up @@ -1106,7 +1120,8 @@ static ImmutableList<Label> getLoadLabels(
repoMapping,
/* withinSclDialect= */ false,
/* isSclFlagEnabled= */ starlarkSemantics.getBool(
BuildLanguageOptions.EXPERIMENTAL_ENABLE_SCL_DIALECT));
BuildLanguageOptions.EXPERIMENTAL_ENABLE_SCL_DIALECT),
/* repoMappingRecorder= */ null);
}

/** Extracts load statements from compiled program (see {@link #getLoadLabels}). */
Expand Down Expand Up @@ -1337,11 +1352,15 @@ private static void executeBzlFile(
Map<String, Module> loadedModules,
BzlInitThreadContext context,
StarlarkSemantics starlarkSemantics,
ExtendedEventHandler skyframeEventHandler)
ExtendedEventHandler skyframeEventHandler,
Label.RepoMappingRecorder repoMappingRecorder)
throws BzlLoadFailedException, InterruptedException {
try (Mutability mu = Mutability.create("loading", label)) {
StarlarkThread thread = new StarlarkThread(mu, starlarkSemantics);
thread.setLoader(loadedModules::get);
// This is needed so that any calls to `Label()` will have its used repo mapping entries
// recorded. See #20721 for more details.
thread.setThreadLocal(Label.RepoMappingRecorder.class, repoMappingRecorder);

// Wrap the skyframe event handler to listen for starlark errors.
AtomicBoolean sawStarlarkError = new AtomicBoolean(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableTable;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.RepositoryName;
import com.google.devtools.build.lib.concurrent.ThreadSafety.Immutable;
Expand Down Expand Up @@ -52,12 +53,15 @@ public class BzlLoadValue implements SkyValue {
// from the Module as client data?
private final byte[] transitiveDigest; // of .bzl file and load dependencies
private final BzlVisibility bzlVisibility;
private final ImmutableTable<RepositoryName, String, RepositoryName> recordedRepoMappings;

@VisibleForTesting
public BzlLoadValue(Module module, byte[] transitiveDigest, BzlVisibility bzlVisibility) {
public BzlLoadValue(Module module, byte[] transitiveDigest, BzlVisibility bzlVisibility,
ImmutableTable<RepositoryName, String, RepositoryName> recordedRepoMappings) {
this.module = checkNotNull(module);
this.transitiveDigest = checkNotNull(transitiveDigest);
this.bzlVisibility = checkNotNull(bzlVisibility);
this.recordedRepoMappings = checkNotNull(recordedRepoMappings);
}

/** Returns the .bzl module. */
Expand All @@ -75,6 +79,14 @@ public BzlVisibility getBzlVisibility() {
return bzlVisibility;
}

/**
* Returns the repo mapping entries used to laod this bzl file. Stored for correctness across
* Bazel server restarts.
*/
public ImmutableTable<RepositoryName, String, RepositoryName> getRecordedRepoMappings() {
return recordedRepoMappings;
}

private static final SkyKeyInterner<Key> keyInterner = SkyKey.newInterner();

/** SkyKey for a Starlark load. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.devtools.build.skyframe.SkyFunctionException;
import com.google.devtools.build.skyframe.SkyKey;
import com.google.devtools.build.skyframe.SkyValue;
import java.io.IOException;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -178,7 +179,7 @@ public SkyValue compute(SkyKey skyKey, Environment env)
return computeFromWorkspace(repositoryName, externalPackageValue, rootModuleRepoMapping);
}

throw new RepositoryMappingFunctionException();
return RepositoryMappingValue.NOT_FOUND_VALUE;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.devtools.build.skyframe.SkyValue;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;

/**
* A value that represents the 'mappings' of an external Bazel workspace, as defined in the main
Expand Down Expand Up @@ -57,6 +58,9 @@ public abstract class RepositoryMappingValue implements SkyValue {
public static final RepositoryMappingValue VALUE_FOR_ROOT_MODULE_WITHOUT_REPOS =
RepositoryMappingValue.createForWorkspaceRepo(RepositoryMapping.ALWAYS_FALLBACK);

public static final RepositoryMappingValue NOT_FOUND_VALUE =
RepositoryMappingValue.createForWorkspaceRepo(null);

/**
* Returns a {@link RepositoryMappingValue} for a repo defined in MODULE.bazel, which has an
* associated module.
Expand All @@ -80,6 +84,8 @@ public static RepositoryMappingValue createForWorkspaceRepo(RepositoryMapping re
repositoryMapping, Optional.empty(), Optional.empty());
}

/** The actual repo mapping. Will be null if the requested repo doesn't exist. */
@Nullable
public abstract RepositoryMapping getRepositoryMapping();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Tables;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
import com.google.devtools.build.lib.clock.BlazeClock;
import com.google.devtools.build.lib.cmdline.BazelModuleContext;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.RepositoryName;
import com.google.devtools.build.lib.packages.RuleVisibility;
import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
import com.google.devtools.build.lib.pkgcache.PackageOptions;
Expand Down Expand Up @@ -1034,7 +1036,7 @@ public void testLoadBzlFileFromWorkspaceWithRemapping() throws Exception {

scratch.file("/y/WORKSPACE");
scratch.file("/y/BUILD");
scratch.file("/y/y.bzl", "y_symbol = 5");
scratch.file("/y/y.bzl", "l = Label('@z//:z')", "y_symbol = 5");

scratch.file("/a/WORKSPACE");
scratch.file("/a/BUILD");
Expand All @@ -1050,8 +1052,12 @@ public void testLoadBzlFileFromWorkspaceWithRemapping() throws Exception {
SkyframeExecutorTestUtils.evaluate(
getSkyframeExecutor(), skyKey, /*keepGoing=*/ false, reporter);

assertThat(result.get(skyKey).getModule().getGlobals())
.containsEntry("a_symbol", StarlarkInt.of(5));
var bzlLoadValue = result.get(skyKey);
assertThat(bzlLoadValue.getModule().getGlobals()).containsEntry("a_symbol", StarlarkInt.of(5));
assertThat(bzlLoadValue.getRecordedRepoMappings().cellSet()).containsExactly(
Tables.immutableCell(RepositoryName.create("a"), "x", RepositoryName.create("y")),
Tables.immutableCell(RepositoryName.create("y"), "z", RepositoryName.create("z")))
.inOrder();
}

@Test
Expand All @@ -1071,6 +1077,7 @@ public void testLoadBzlFileFromBzlmod() throws Exception {
fooDir.getRelative("test.bzl").getPathString(),
// Also test that bzlmod .bzl files can load .scl files.
"load('@bar_alias//:test.scl', 'haha')",
"l = Label('@foo//:whatever')",
"hoho = haha");
Path barDir = moduleRoot.getRelative("bar~2.0");
scratch.file(barDir.getRelative("WORKSPACE").getPathString());
Expand All @@ -1083,8 +1090,13 @@ public void testLoadBzlFileFromBzlmod() throws Exception {
getSkyframeExecutor(), skyKey, /*keepGoing=*/ false, reporter);

assertThatEvaluationResult(result).hasNoError();
assertThat(result.get(skyKey).getModule().getGlobals())
.containsEntry("hoho", StarlarkInt.of(5));
var bzlLoadValue = result.get(skyKey);
assertThat(bzlLoadValue.getModule().getGlobals()).containsEntry("hoho", StarlarkInt.of(5));
assertThat(bzlLoadValue.getRecordedRepoMappings().cellSet()).containsExactly(
Tables.immutableCell(RepositoryName.create("foo~1.0"), "bar_alias",
RepositoryName.create("bar~2.0")),
Tables.immutableCell(RepositoryName.create("foo~1.0"), "foo",
RepositoryName.create("foo~1.0"))).inOrder();
// Note that we're not testing the case of a non-registry override using @bazel_tools here, but
// that is incredibly hard to set up in a unit test. So we should just rely on integration tests
// for that.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import static java.nio.charset.StandardCharsets.ISO_8859_1;

import com.google.common.collect.ImmutableClassToInstanceMap;
import com.google.common.collect.ImmutableTable;
import com.google.devtools.build.lib.cmdline.RepositoryName;
import com.google.devtools.build.lib.packages.BzlVisibility;
import com.google.devtools.build.lib.skyframe.BzlLoadValue;
import com.google.devtools.build.lib.skyframe.serialization.testutils.SerializationTester;
Expand All @@ -29,6 +31,10 @@
/** Tests for {@link BzlLoadValue} serialization. */
@RunWith(JUnit4.class)
public class BzlLoadValueCodecTest {
private static final ImmutableTable<RepositoryName, String, RepositoryName> SOME_TABLE =
ImmutableTable.of(
RepositoryName.createUnvalidated("foo"), "bar", RepositoryName.createUnvalidated("quux"));

@Test
public void objectCodecTests() throws Exception {
Module module = Module.create();
Expand All @@ -37,13 +43,15 @@ public void objectCodecTests() throws Exception {
module.setGlobal("c", 3);
byte[] digest = "dummy".getBytes(ISO_8859_1);

new SerializationTester(new BzlLoadValue(module, digest, BzlVisibility.PUBLIC))
new SerializationTester(
new BzlLoadValue(module, digest, BzlVisibility.PUBLIC, SOME_TABLE))
.setVerificationFunction(
(SerializationTester.VerificationFunction<BzlLoadValue>)
(x, y) -> {
if (!java.util.Arrays.equals(x.getTransitiveDigest(), y.getTransitiveDigest())) {
throw new AssertionError("unequal digests after serialization");
}
assertThat(x.getRecordedRepoMappings()).isEqualTo(y.getRecordedRepoMappings());
})
.runTestsWithoutStableSerializationCheck();
}
Expand All @@ -64,6 +72,6 @@ private static BzlLoadValue makeBLV(String name, Object value) {
module.setGlobal(name, value);

byte[] digest = "dummy".getBytes(ISO_8859_1);
return new BzlLoadValue(module, digest, BzlVisibility.PUBLIC);
return new BzlLoadValue(module, digest, BzlVisibility.PUBLIC, SOME_TABLE);
}
}
Loading

0 comments on commit b59ba48

Please sign in to comment.