From dd822392db96bb7bccdb673414a20c4b91e3dbc1 Mon Sep 17 00:00:00 2001 From: Fabian Meumertzheim Date: Fri, 31 Mar 2023 04:35:54 -0700 Subject: [PATCH] Canonicalize use_extension label Canonicalize the label by adding the current module's repo_name if the label doesn't specify a repository name. This is necessary as ModuleExtensionUsages are grouped by the string value of this label, but later mapped to their Label representation. If multiple strings map to the same Label, this would result in a crash. Also enforce that `module()` is called first (if at all). Closes #17920. PiperOrigin-RevId: 520890201 Change-Id: Ice8e2feb0da591e3ba953f4a85284766ba599ebf --- .../build/lib/bazel/bzlmod/Module.java | 2 + .../lib/bazel/bzlmod/ModuleFileGlobals.java | 38 ++++++++- .../bzlmod/ModuleExtensionResolutionTest.java | 82 +++++++++++++++++++ .../bazel/bzlmod/ModuleFileFunctionTest.java | 38 ++++++++- 4 files changed, 155 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java index 2952548e851bd5..cfc09237b415b0 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java @@ -242,6 +242,8 @@ public Builder addExtensionUsage(ModuleExtensionUsage value) { return this; } + abstract ModuleKey getKey(); + abstract String getName(); abstract Optional getRepoName(); diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java index 3b3f42c0ba0709..da7ba7318cb49e 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java @@ -61,6 +61,7 @@ public class ModuleFileGlobals { Pattern.compile("(>|<|-|<=|>=)(\\d+\\.){2}\\d+"); private boolean moduleCalled = false; + private boolean hadNonModuleCall = false; private final boolean ignoreDevDeps; private final Module.Builder module; private final Map deps = new LinkedHashMap<>(); @@ -208,6 +209,9 @@ public void module( if (moduleCalled) { throw Starlark.errorf("the module() directive can only be called once"); } + if (hadNonModuleCall) { + throw Starlark.errorf("if module() is called, it must be called before any other functions"); + } moduleCalled = true; if (!name.isEmpty()) { validateModuleName(name); @@ -298,6 +302,7 @@ private static ImmutableList checkAllCompatibilityVersions( public void bazelDep( String name, String version, String repoName, boolean devDependency, StarlarkThread thread) throws EvalException { + hadNonModuleCall = true; if (repoName.isEmpty()) { repoName = name; } @@ -330,6 +335,7 @@ public void bazelDep( allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)}, doc = "The labels of the platforms to register.")) public void registerExecutionPlatforms(Sequence platformLabels) throws EvalException { + hadNonModuleCall = true; module.addExecutionPlatformsToRegister( checkAllAbsolutePatterns(platformLabels, "register_execution_platforms")); } @@ -347,6 +353,7 @@ public void registerExecutionPlatforms(Sequence platformLabels) throws EvalEx allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)}, doc = "The labels of the toolchains to register.")) public void registerToolchains(Sequence toolchainLabels) throws EvalException { + hadNonModuleCall = true; module.addToolchainsToRegister( checkAllAbsolutePatterns(toolchainLabels, "register_toolchains")); } @@ -376,7 +383,14 @@ public void registerToolchains(Sequence toolchainLabels) throws EvalException }, useStarlarkThread = true) public ModuleExtensionProxy useExtension( - String extensionBzlFile, String extensionName, boolean devDependency, StarlarkThread thread) { + String rawExtensionBzlFile, + String extensionName, + boolean devDependency, + StarlarkThread thread) { + hadNonModuleCall = true; + + String extensionBzlFile = normalizeLabelString(rawExtensionBzlFile); + ModuleExtensionUsageBuilder newUsageBuilder = new ModuleExtensionUsageBuilder( extensionBzlFile, extensionName, thread.getCallerLocation()); @@ -399,6 +413,22 @@ public ModuleExtensionProxy useExtension( return newUsageBuilder.getProxy(devDependency); } + private String normalizeLabelString(String rawExtensionBzlFile) { + // Normalize the label by adding the current module's repo_name if the label doesn't specify a + // repository name. This is necessary as ModuleExtensionUsages are grouped by the string value + // of this label, but later mapped to their Label representation. If multiple strings map to the + // same Label, this would result in a crash. + // ownName can't change anymore as calling module() after this results in an error. + String ownName = module.getRepoName().orElse(module.getName()); + if (module.getKey().equals(ModuleKey.ROOT) && rawExtensionBzlFile.startsWith("@//")) { + return "@" + ownName + rawExtensionBzlFile.substring(1); + } else if (rawExtensionBzlFile.startsWith("//")) { + return "@" + ownName + rawExtensionBzlFile; + } else { + return rawExtensionBzlFile; + } + } + class ModuleExtensionUsageBuilder { private final String extensionBzlFile; private final String extensionName; @@ -516,6 +546,7 @@ public void useRepo( Dict kwargs, StarlarkThread thread) throws EvalException { + hadNonModuleCall = true; Location location = thread.getCallerLocation(); for (String arg : Sequence.cast(args, String.class, "args")) { extensionProxy.addImport(arg, arg, location); @@ -598,6 +629,7 @@ public void singleVersionOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; Version parsedVersion; try { parsedVersion = Version.parse(version); @@ -652,6 +684,7 @@ public void singleVersionOverride( }) public void multipleVersionOverride(String moduleName, Iterable versions, String registry) throws EvalException { + hadNonModuleCall = true; ImmutableList.Builder parsedVersionsBuilder = new ImmutableList.Builder<>(); try { for (String version : Sequence.cast(versions, String.class, "versions").getImmutableList()) { @@ -735,6 +768,7 @@ public void archiveOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; ImmutableList urlList = urls instanceof String ? ImmutableList.of((String) urls) @@ -806,6 +840,7 @@ public void gitOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; addOverride( moduleName, GitOverride.create( @@ -835,6 +870,7 @@ public void gitOverride( positional = false), }) public void localPathOverride(String moduleName, String path) throws EvalException { + hadNonModuleCall = true; addOverride(moduleName, LocalPathOverride.create(path)); } diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java index a6f3127a8d9b24..2b488324e115b0 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java @@ -329,6 +329,88 @@ public void simpleExtension() throws Exception { assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba"); } + @Test + public void simpleExtension_nonCanonicalLabel() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "module(name='my_module', version = '1.0')", + "bazel_dep(name='data_repo', version='1.0')", + "ext1 = use_extension('//:defs.bzl', 'ext')", + "ext1.tag(name='foo', data='fu')", + "use_repo(ext1, 'foo')", + "ext2 = use_extension('@my_module//:defs.bzl', 'ext')", + "ext2.tag(name='bar', data='ba')", + "use_repo(ext2, 'bar')", + "ext3 = use_extension('@//:defs.bzl', 'ext')", + "ext3.tag(name='quz', data='qu')", + "use_repo(ext3, 'quz')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " data_repo(name=tag.name,data=tag.data)", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "load('@quz//:data.bzl', quz_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu"); + } + + @Test + public void simpleExtension_nonCanonicalLabel_repoName() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "module(name='my_module', version = '1.0', repo_name='my_name')", + "bazel_dep(name='data_repo', version='1.0')", + "ext1 = use_extension('//:defs.bzl', 'ext')", + "ext1.tag(name='foo', data='fu')", + "use_repo(ext1, 'foo')", + "ext2 = use_extension('@my_name//:defs.bzl', 'ext')", + "ext2.tag(name='bar', data='ba')", + "use_repo(ext2, 'bar')", + "ext3 = use_extension('@//:defs.bzl', 'ext')", + "ext3.tag(name='quz', data='qu')", + "use_repo(ext3, 'quz')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " data_repo(name=tag.name,data=tag.data)", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "load('@quz//:data.bzl', quz_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu"); + } + @Test public void multipleModules() throws Exception { scratch.file( diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java index 2df3f7af45a1e2..1b8e52ac9cc8fe 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java @@ -473,7 +473,7 @@ public void testModuleExtensions_good() throws Exception { .setRegistry(registry) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext1") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 2, 23)) .setImports(ImmutableBiMap.of("repo1", "repo1")) @@ -491,7 +491,7 @@ public void testModuleExtensions_good() throws Exception { .build()) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext2") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23)) .setImports(ImmutableBiMap.of("other_repo1", "repo1", "repo2", "repo2")) @@ -582,7 +582,7 @@ public void testModuleExtensions_duplicateProxy_asRoot() throws Exception { .setKey(ModuleKey.ROOT) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@//:defs.bzl") .setExtensionName("myext") .setLocation(Location.fromFileLineColumn("/MODULE.bazel", 1, 23)) .setImports( @@ -672,7 +672,7 @@ public void testModuleExtensions_duplicateProxy_asDep() throws Exception { .setRegistry(registry) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23)) .setImports(ImmutableBiMap.of("beta", "beta", "delta", "delta")) @@ -956,4 +956,34 @@ public void moduleRepoName_conflict() throws Exception { assertContainsEvent("The repo name 'bbb' is already being used as the module's own repo name"); } + + @Test + public void module_calledTwice() throws Exception { + scratch.file( + rootDirectory.getRelative("MODULE.bazel").getPathString(), + "module(name='aaa',version='0.1',repo_name='bbb')", + "module(name='aaa',version='0.1',repo_name='bbb')"); + FakeRegistry registry = registryFactory.newFakeRegistry("/foo"); + ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl())); + + reporter.removeHandler(failFastHandler); // expect failures + evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext); + + assertContainsEvent("the module() directive can only be called once"); + } + + @Test + public void module_calledLate() throws Exception { + scratch.file( + rootDirectory.getRelative("MODULE.bazel").getPathString(), + "use_extension('//:extensions.bzl', 'my_ext')", + "module(name='aaa',version='0.1',repo_name='bbb')"); + FakeRegistry registry = registryFactory.newFakeRegistry("/foo"); + ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl())); + + reporter.removeHandler(failFastHandler); // expect failures + evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext); + + assertContainsEvent("if module() is called, it must be called before any other functions"); + } }