diff --git a/model/orbax/experimental/model/core/python/compile_options_util.py b/model/orbax/experimental/model/core/python/compile_options_util.py
index 2be9a5259..e32a10458 100644
--- a/model/orbax/experimental/model/core/python/compile_options_util.py
+++ b/model/orbax/experimental/model/core/python/compile_options_util.py
@@ -16,7 +16,10 @@
 
 from collections.abc import Mapping, Sequence
 import logging
+from typing import Any
 
+from absl import flags
+from google.protobuf import descriptor
 from google.protobuf import text_format
 import jax
 from orbax.experimental.model.core.protos import manifest_pb2
@@ -28,6 +31,13 @@
 from tensorflow.compiler.xla import xla_pb2
 from tensorflow.compiler.xla.pjrt.proto import compile_options_pb2
 
+# Maps each TpuCompilationEnvironment field name, which doubles as the XLA flag
+# name, to its protobuf field descriptor.
+_XLA_FLAG_TO_FIELD_MAP = {
+    field.name: field
+    for field in tpu_comp_env_pb2.TpuCompilationEnvironment.DESCRIPTOR.fields
+}
+
 
 def generate_tpu_compilation_env(
     xla_flags: Sequence[str] | None = None,
@@ -39,16 +49,26 @@ def generate_tpu_compilation_env(
       tpu_compilation_env_str
   )
   # Override with supplied XLA flags if any is provided.
-  if xla_flags is not None:
-    env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
-    xla_flags_str = '\n'.join(xla_flags)
-    try:
-      text_format.Parse(xla_flags_str, env_override)
-    except text_format.ParseError as e:
-      raise ValueError(
-          f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
-      ) from e
-    env.MergeFrom(env_override)
+  if xla_flags:
+    # A first entry without a leading '--' selects the legacy text-proto path.
+    if not xla_flags[0].startswith('--'):
+      merge_proto_formatted_flags_compile_option(xla_flags, env)
+    else:
+      parsed_flags = {}
+      for flag in xla_flags:
+        if not flag.startswith('--'):
+          raise ValueError(
+              f"Flag {flag} does not start with '--'. All flags must be in the"
+              ' format of --flag_name=flag_value.'
+          )
+        flag_name, separator, flag_value = flag[2:].partition('=')
+        if not separator:
+          raise ValueError(
+              f"Flag {flag} has no '=' separator. All flags must be in the"
+              ' format of --flag_name=flag_value.'
+          )
+        parsed_flags[flag_name] = flag_value
+      merge_flags_into_compile_options(parsed_flags, env)
 
   # Pack the TPU compilation environment into a compilation env proto.
   any_proto = any_pb2.Any()
@@ -109,8 +129,13 @@ def generate_xla_compile_options(
   """Sets the XLA compilation options.
 
   Args:
+    native_serialization_platforms: A sequence of platform names that the
+      compile options will be set for. If None, the compile options will be set
+      for TPU only.
     xla_flags_per_platform: A mapping from platform name to a list of xla flags
       which will be used to override the default XLA compilation flags.
+    jax_mesh: The JAX mesh used for sharding. If None, the compile options will
+      be set for a default single-replica.
 
   Returns:
     A `CompileOptionsProtoMap` containing the XLA compilation options per
@@ -156,3 +181,133 @@ def generate_xla_compile_options(
         generate_compilation_options(compile_environment, jax_mesh)
     )
   return compile_options_map
+
+
+def get_field_for_flag(flag_name: str) -> descriptor.FieldDescriptor:
+  """Gets the protobuf field descriptor for a given flag name.
+
+  Args:
+    flag_name: The flag name, which must equal the name of a
+      `TpuCompilationEnvironment` field.
+
+  Returns:
+    The descriptor of the `TpuCompilationEnvironment` field named `flag_name`.
+
+  Raises:
+    ValueError: If no `TpuCompilationEnvironment` field has that name.
+  """
+  if flag_name not in _XLA_FLAG_TO_FIELD_MAP:
+    raise ValueError(
+        f'No TpuCompilationEnvironment field matching flag {flag_name}'
+    )
+  return _XLA_FLAG_TO_FIELD_MAP[flag_name]
+
+
+def parse_flag_from_string(flag_name: str, value: str) -> Any:
+  """Parses a string value for a given flag and normalizes it for a proto field.
+
+  This is a Python implementation of the C++ function
+  TpuCompEnvReflection::ParseFlagFromString.
+
+  Args:
+    flag_name: The name of the flag.
+    value: The string value of the flag.
+
+  Returns:
+    The parsed and normalized value suitable for setting the corresponding field
+    in `TpuCompilationEnvironment`. This can be a primitive type (int, bool,
+    str), float, an enum's integer value, or a proto message instance.
+
+  Raises:
+    ValueError: If the flag is not registered with absl, if no
+      `TpuCompilationEnvironment` field matches it, if a proto message value
+      cannot be parsed, or if an enum value name is unknown.
+  """
+  try:
+    flag_holder = flags.FLAGS[flag_name]
+  except KeyError as e:
+    raise ValueError(f'Flag not found: {flag_name}') from e
+
+  parsed_value = flag_holder.parser.parse(value)
+  field = get_field_for_flag(flag_name)
+
+  if field.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+    message_instance = field.message_type._concrete_class()
+    try:
+      text_format.Parse(value, message_instance)
+    except text_format.ParseError as e:
+      raise ValueError(
+          f'Error parsing proto value for flag {flag_name}: {value}'
+      ) from e
+    return message_instance
+  if field.type == descriptor.FieldDescriptor.TYPE_ENUM:
+    if isinstance(parsed_value, str):
+      try:
+        return field.enum_type.values_by_name[parsed_value].number
+      except KeyError as e:
+        raise ValueError(
+            f'Unknown enum value for flag {flag_name}: {parsed_value}'
+        ) from e
+    # If it's already an int, assume it's the correct value.
+    return parsed_value
+  if field.type in (
+      descriptor.FieldDescriptor.TYPE_FLOAT,
+      descriptor.FieldDescriptor.TYPE_DOUBLE,
+  ):
+    return float(parsed_value)
+  return parsed_value
+
+
+def merge_flags_into_compile_options(
+    xla_flags: Mapping[str, str],
+    env: tpu_comp_env_pb2.TpuCompilationEnvironment,
+) -> None:
+  """Merges flags into a TpuCompilationEnvironment proto.
+
+  Args:
+    xla_flags: A mapping of XLA flag names to their string values. These flags
+      will be parsed and merged into the `env` proto.
+    env: The TpuCompilationEnvironment proto to merge the flags into. This
+      proto will be modified in place.
+
+  Raises:
+    ValueError: If a flag name or value is rejected while being parsed.
+  """
+  env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
+  for flag_name, value in xla_flags.items():
+    field_descriptor = get_field_for_flag(flag_name)
+    parsed_value = parse_flag_from_string(flag_name, value)
+    if field_descriptor.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+      # For message types, we need to copy the parsed message.
+      getattr(env_override, field_descriptor.name).CopyFrom(parsed_value)
+    else:
+      # For scalar types, we can set the attribute directly.
+      setattr(env_override, field_descriptor.name, parsed_value)
+  env.MergeFrom(env_override)
+
+
+# TODO(b/438187387): remove this path and only allow the "--flag=value" format.
+def merge_proto_formatted_flags_compile_option(
+    xla_flags: Sequence[str],
+    env: tpu_comp_env_pb2.TpuCompilationEnvironment,
+) -> None:
+  """Merges text-proto formatted flags into a TpuCompilationEnvironment proto.
+
+  Args:
+    xla_flags: Text-format `TpuCompilationEnvironment` field assignments, e.g.
+      `field_name: value`. They are joined with newlines and parsed together.
+    env: The TpuCompilationEnvironment proto to merge the flags into. This
+      proto will be modified in place.
+
+  Raises:
+    ValueError: If the joined flags are not valid text-format proto.
+  """
+  env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
+  xla_flags_str = '\n'.join(xla_flags)
+  try:
+    text_format.Parse(xla_flags_str, env_override)
+  except text_format.ParseError as e:
+    raise ValueError(
+        f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
+    ) from e
+  env.MergeFrom(env_override)
diff --git a/model/orbax/experimental/model/core/python/compile_options_util_test.py b/model/orbax/experimental/model/core/python/compile_options_util_test.py
new file mode 100644
index 000000000..3ca5fa8bf
--- /dev/null
+++ b/model/orbax/experimental/model/core/python/compile_options_util_test.py
@@ -0,0 +1,130 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from orbax.experimental.model.core.python import compile_options_util
+from .platforms.xla.service.jellyfish import tpu_compilation_environment_pb2 as tpu_comp_env_pb2
+
+
+class CompileOptionsUtilTest(parameterized.TestCase):
+
+  def test_parse_flag_from_string_bool(self):
+    result = compile_options_util.parse_flag_from_string(
+        'xla_sc_poison_buffers', 'false'
+    )
+    self.assertEqual(result, False)
+
+  def test_parse_flag_from_string_int(self):
+    result = compile_options_util.parse_flag_from_string(
+        'xla_jf_rematerialization_percent_shared_memory_limit', '99'
+    )
+    self.assertEqual(result, 99)
+
+  def test_parse_flag_from_string_float(self):
+    result = compile_options_util.parse_flag_from_string(
+        'xla_tpu_async_copy_bandwidth_scaling_factor', '0.19125064716453793'
+    )
+    self.assertEqual(result, 0.19125064716453793)
+
+  def test_parse_flag_from_string_string(self):
+    result = compile_options_util.parse_flag_from_string(
+        'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers',
+        'NO_SCALE',
+    )
+    self.assertEqual(result, 'NO_SCALE')
+
+  def test_parse_flag_from_string_proto(self):
+    result = compile_options_util.parse_flag_from_string(
+        'xla_tpu_memory_bound_loop_optimizer_options', 'enabled:true'
+    )
+    self.assertTrue(result.enabled)
+
+  def test_parse_flag_from_string_enum(self):
+    result = compile_options_util.parse_flag_from_string(
+        'xla_memory_scheduler', 'DFS'
+    )
+    self.assertEqual(result, tpu_comp_env_pb2.MemorySchedulerProto.DFS)
+
+  def test_parse_flag_from_string_nonexistent_flag(self):
+    with self.assertRaisesRegex(ValueError, 'Flag not found: nonexistent_flag'):
+      compile_options_util.parse_flag_from_string('nonexistent_flag', 'value')
+
+  @parameterized.named_parameters(
+      (
+          'dict_xla_flags',
+          {
+              'xla_jf_rematerialization_percent_shared_memory_limit': '99',
+              'xla_tpu_allocate_scoped_vmem_at_same_offset': 'false',
+              'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers': (
+                  'NO_SCALE'
+              ),
+              'xla_tpu_memory_bound_loop_optimizer_options': 'enabled:false',
+              'xla_tpu_async_copy_bandwidth_scaling_factor': (
+                  '0.19125064716453793'
+              ),
+          },
+          compile_options_util.merge_flags_into_compile_options,
+      ),
+      (
+          'proto_formatted_xla_flags',
+          [
+              'xla_jf_rematerialization_percent_shared_memory_limit: 99',
+              'xla_tpu_allocate_scoped_vmem_at_same_offset: false',
+              (
+                  'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers:'
+                  " 'NO_SCALE'"
+              ),
+              'xla_tpu_memory_bound_loop_optimizer_options: {enabled:false}',
+              (
+                  'xla_tpu_async_copy_bandwidth_scaling_factor:'
+                  ' 0.19125064716453793'
+              ),
+          ],
+          compile_options_util.merge_proto_formatted_flags_compile_option,
+      ),
+  )
+  def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
+    # Initialize the environment with some values.
+    env = tpu_comp_env_pb2.TpuCompilationEnvironment()
+    # Values that should be overridden.
+    env.xla_jf_rematerialization_percent_shared_memory_limit = 10
+    env.xla_tpu_memory_bound_loop_optimizer_options.enabled = True
+    # Value that should not be overridden.
+    env.xla_tpu_wait_n_cycles_before_program_termination = 1234
+
+    # Merge the flags into the environment.
+    merge_fn(xla_flags, env)
+    self.assertEqual(
+        env.xla_jf_rematerialization_percent_shared_memory_limit, 99
+    )
+    self.assertEqual(env.xla_tpu_allocate_scoped_vmem_at_same_offset, False)
+    self.assertEqual(
+        env.xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers,
+        'NO_SCALE',
+    )
+    self.assertEqual(
+        env.xla_tpu_memory_bound_loop_optimizer_options.enabled, False
+    )
+    self.assertAlmostEqual(
+        env.xla_tpu_async_copy_bandwidth_scaling_factor,
+        0.19125064716453793,
+    )
+
+    # Value that should not be overridden.
+    self.assertEqual(env.xla_tpu_wait_n_cycles_before_program_termination, 1234)
+
+
+if __name__ == '__main__':
+  absltest.main()