Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 129 additions & 10 deletions model/orbax/experimental/model/core/python/compile_options_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from collections.abc import Mapping, Sequence
import logging
from typing import Any

from absl import flags
from google.protobuf import descriptor
from google.protobuf import text_format
import jax
from orbax.experimental.model.core.protos import manifest_pb2
Expand All @@ -28,6 +31,12 @@
from tensorflow.compiler.xla import xla_pb2
from tensorflow.compiler.xla.pjrt.proto import compile_options_pb2

# A mapping between XLAflag names and protobuf field names.
_XLA_FLAG_TO_FIELD_MAP = {
field.name: field
for field in tpu_comp_env_pb2.TpuCompilationEnvironment.DESCRIPTOR.fields
}


def generate_tpu_compilation_env(
xla_flags: Sequence[str] | None = None,
Expand All @@ -39,16 +48,21 @@ def generate_tpu_compilation_env(
tpu_compilation_env_str
)
# Override with supplied XLA flags if any is provided.
if xla_flags is not None:
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
xla_flags_str = '\n'.join(xla_flags)
try:
text_format.Parse(xla_flags_str, env_override)
except text_format.ParseError as e:
raise ValueError(
f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
) from e
env.MergeFrom(env_override)
if xla_flags:
is_proto_formatted = False if xla_flags[0].startswith('--') else True
if is_proto_formatted:
merge_proto_formatted_flags_compile_option(xla_flags, env)
else:
parsed_flags = {}
for flag in xla_flags:
if not flag.startswith('--'):
raise ValueError(
f"Flag {flag} does not start with '--'. All flags must be in the"
' format of --flag_name=flag_value.'
)
flag_name, flag_value = flag[2:].split('=', 1)
parsed_flags[flag_name] = flag_value
merge_flags_into_compile_options(parsed_flags, env)

# Pack the TPU compilation environment into a compilation env proto.
any_proto = any_pb2.Any()
Expand Down Expand Up @@ -109,8 +123,13 @@ def generate_xla_compile_options(
"""Sets the XLA compilation options.

Args:
native_serialization_platforms: A sequence of platform names that the
compile options will be set for. If None, the compile options will be set
for TPU only.
xla_flags_per_platform: A mapping from platform name to a list of xla flags
which will be used to override the default XLA compilation flags.
jax_mesh: The JAX mesh used for sharding. If None, the compile options will
be set for a default single-replica.

Returns:
A `CompileOptionsProtoMap` containing the XLA compilation options per
Expand Down Expand Up @@ -156,3 +175,103 @@ def generate_xla_compile_options(
generate_compilation_options(compile_environment, jax_mesh)
)
return compile_options_map


def get_field_for_flag(flag_name: str) -> descriptor.FieldDescriptor:
"""Gets the protobuf field descriptor for a given flag name."""
if flag_name not in _XLA_FLAG_TO_FIELD_MAP:
raise ValueError(
f'No TpuCompilationEnvironment field matching flag {flag_name}'
)
return _XLA_FLAG_TO_FIELD_MAP[flag_name]


def parse_flag_from_string(flag_name: str, value: str) -> Any:
"""Parses a string value for a given flag and normalizes it for a proto field.

This is a Python implementation of the C++ function
TpuCompEnvReflection::ParseFlagFromString.

Args:
flag_name: The name of the flag.
value: The string value of the flag.

Returns:
The parsed and normalized value suitable for setting the corresponding field
in `TpuCompilationEnvironment`. This can be a primitive type (int, bool,
str), float, an enum's integer value, or a proto message instance.

Raises:
ValueError: If the flag is not found, or if a proto message value cannot
be parsed.
"""
try:
flag_holder = flags.FLAGS[flag_name]
except KeyError:
raise ValueError(f'Flag not found: {flag_name}')

parsed_value = flag_holder.parser.parse(value)
field = get_field_for_flag(flag_name)

if field.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
message_instance = field.message_type._concrete_class()
try:
text_format.Parse(value, message_instance)
return message_instance
except text_format.ParseError as e:
raise ValueError(
f'Error parsing proto value for flag {flag_name}: {value}'
) from e
if field.type == descriptor.FieldDescriptor.TYPE_ENUM:
if isinstance(parsed_value, str):
return field.enum_type.values_by_name[parsed_value].number
# If it's already an int, assume it's the correct value.
return parsed_value
if field.type in (
descriptor.FieldDescriptor.TYPE_FLOAT,
descriptor.FieldDescriptor.TYPE_DOUBLE,
):
return float(parsed_value)
return parsed_value


def merge_flags_into_compile_options(
xla_flags: Mapping[str, str],
env: tpu_comp_env_pb2.TpuCompilationEnvironment,
):
"""Merges flags into a TpuCompilationEnvironment proto.

Args:
xla_flags: A mapping of XLA flag names to their string values. These flags
will be parsed and merged into the `env` proto.
env: The TpuCompilationEnvironment proto to merge the flags into. This
proto will be modified in place.
"""
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
for flag_name, value in xla_flags.items():
field_descriptor = get_field_for_flag(flag_name)
parsed_value = parse_flag_from_string(flag_name, value)
if field_descriptor.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
# For message types, we need to copy the parsed message.
getattr(env_override, field_descriptor.name).CopyFrom(parsed_value)
else:
# For scalar types, we can set the attribute directly.
setattr(env_override, field_descriptor.name, parsed_value)
env.MergeFrom(env_override)


# TODO(b/438187387): remove this path and only allow the "--flag=value" format.
def merge_proto_formatted_flags_compile_option(
xla_flags: Sequence[str],
env: tpu_comp_env_pb2.TpuCompilationEnvironment,
):
"""Merges flags into a proto."""
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
xla_flags_str = '\n'.join(xla_flags)
try:
text_format.Parse(xla_flags_str, env_override)
except text_format.ParseError as e:
raise ValueError(
f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
) from e
env.MergeFrom(env_override)
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
from orbax.experimental.model.core.python import compile_options_util
from .platforms.xla.service.jellyfish import tpu_compilation_environment_pb2 as tpu_comp_env_pb2


class CompileOptionsUtilTest(parameterized.TestCase):

def test_parse_flag_from_string_bool(self):
result = compile_options_util.parse_flag_from_string(
'xla_sc_poison_buffers', 'false'
)
self.assertEqual(result, False)

def test_parse_flag_from_string_int(self):
result = compile_options_util.parse_flag_from_string(
'xla_jf_rematerialization_percent_shared_memory_limit', '99'
)
self.assertEqual(result, 99)

def test_parse_flag_from_string_float(self):
result = compile_options_util.parse_flag_from_string(
'xla_tpu_async_copy_bandwidth_scaling_factor', '0.19125064716453793'
)
self.assertEqual(result, 0.19125064716453793)

def test_parse_flag_from_string_string(self):
result = compile_options_util.parse_flag_from_string(
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers',
'NO_SCALE',
)
self.assertEqual(result, 'NO_SCALE')

def test_parse_flag_from_string_proto(self):
compile_options_util.parse_flag_from_string(
'xla_tpu_memory_bound_loop_optimizer_options', 'enabled:false'
)

def test_parse_flag_from_string_enum(self):
result = compile_options_util.parse_flag_from_string(
'xla_memory_scheduler', 'DFS'
)
expected = tpu_comp_env_pb2.MemorySchedulerProto.DFS
self.assertEqual(result, expected)

def test_parse_flag_from_string_nonexistent_flag(self):
with self.assertRaisesRegex(ValueError, 'Flag not found: nonexistent_flag'):
compile_options_util.parse_flag_from_string('nonexistent_flag', 'value')

@parameterized.named_parameters(
(
'dict_xla_flags',
{
'xla_jf_rematerialization_percent_shared_memory_limit': '99',
'xla_tpu_allocate_scoped_vmem_at_same_offset': 'false',
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers': (
'NO_SCALE'
),
'xla_tpu_memory_bound_loop_optimizer_options': 'enabled:false',
'xla_tpu_async_copy_bandwidth_scaling_factor': (
'0.19125064716453793'
),
},
compile_options_util.merge_flags_into_compile_options,
),
(
'proto_formatted_xla_flags',
[
'xla_jf_rematerialization_percent_shared_memory_limit: 99',
'xla_tpu_allocate_scoped_vmem_at_same_offset: false',
(
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers:'
" 'NO_SCALE'"
),
'xla_tpu_memory_bound_loop_optimizer_options: {enabled:false}',
(
'xla_tpu_async_copy_bandwidth_scaling_factor:'
' 0.19125064716453793'
),
],
compile_options_util.merge_proto_formatted_flags_compile_option,
),
)
def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
# Initialize the environment with some values.
env = tpu_comp_env_pb2.TpuCompilationEnvironment()
# Values that should be overridden.
env.xla_jf_rematerialization_percent_shared_memory_limit = 10
env.xla_tpu_memory_bound_loop_optimizer_options.enabled = True
# Value that should not be overridden.
env.xla_tpu_wait_n_cycles_before_program_termination = 1234

# Merge the flags into the environment.
merge_fn(xla_flags, env)
self.assertEqual(
env.xla_jf_rematerialization_percent_shared_memory_limit, 99
)
self.assertEqual(env.xla_tpu_allocate_scoped_vmem_at_same_offset, False)
self.assertEqual(
env.xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers,
'NO_SCALE',
)
self.assertEqual(
env.xla_tpu_memory_bound_loop_optimizer_options.enabled, False
)
self.assertAlmostEqual(
env.xla_tpu_async_copy_bandwidth_scaling_factor,
0.19125064716453793,
)

# Value that should not be overridden.
self.assertEqual(env.xla_tpu_wait_n_cycles_before_program_termination, 1234)


if __name__ == '__main__':
absltest.main()
Loading