Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
sharadmv authored and jax authors committed Jul 21, 2023
1 parent dcad04d commit 3d556b7
Showing 30 changed files with 6,457 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
@@ -159,7 +159,7 @@ jobs:
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic/__init__.py --ignore=jax/experimental/mosaic/dialects.py
documentation_render:
36 changes: 36 additions & 0 deletions jax/BUILD
Expand Up @@ -21,6 +21,7 @@ load(
"jax_internal_packages",
"jax_test_util_visibility",
"jax_visibility",
"mosaic_internal_users",
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
@@ -59,6 +60,13 @@
] + jax_internal_packages,
)

package_group(
name = "mosaic_users",
packages = [
"//...",
] + mosaic_internal_users,
)

# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
@@ -556,6 +564,21 @@
deps = [":basearray"] + py_deps("numpy"),
)

pytype_strict_library(
name = "tpu_custom_call",
srcs = ["_src/tpu_custom_call.py"],
visibility = [":internal"],
deps = [
":core",
":jax",
"//jax/_src/lib",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:stablehlo_dialect",
] + py_deps("numpy") + py_deps("absl/flags"),
)

pytype_strict_library(
name = "util",
srcs = ["_src/util.py"],
@@ -733,6 +756,19 @@
],
)

pytype_library(
name = "mosaic",
srcs = [
"experimental/mosaic/__init__.py",
"experimental/mosaic/dialects.py",
],
visibility = [":mosaic_users"],
deps = [
":tpu_custom_call",
"//jax/_src/lib",
],
)

pytype_library(
name = "rnn",
srcs = ["experimental/rnn.py"],
6 changes: 6 additions & 0 deletions jax/_src/lib/__init__.py
@@ -112,6 +112,12 @@ def _xla_gc_callback(*args):

import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
try:
import jaxlib.tpu_mosaic as tpu_mosaic # pytype: disable=import-error
except ImportError:
  # TODO(sharadmv): Remove this once the minimum jaxlib version is >= 0.4.14;
  # older jaxlibs don't contain the Mosaic bindings.
tpu_mosaic = None # type: ignore

# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version
300 changes: 300 additions & 0 deletions jax/_src/tpu_custom_call.py
@@ -0,0 +1,300 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX bindings for Mosaic."""

# mypy: ignore-errors

from __future__ import annotations

import base64
import collections.abc
import dataclasses
import functools
import io
from typing import Any, Callable, Sequence

from absl import flags
import jax
from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
from jax._src.lib import tpu_mosaic
import numpy as np

# TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
if tpu_mosaic is None:
raise ValueError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout

_ALLOW_HLO = flags.DEFINE_bool(
"jax_mosaic_allow_hlo", False, "Allow hlo dialect in mosaic"
)

tpu_custom_call_p = core.Primitive("tpu_custom_call")
tpu_custom_call_p.def_impl(
functools.partial(xla.apply_primitive, tpu_custom_call_p))
tpu_custom_call_p.multiple_results = True


@dataclasses.dataclass(frozen=True)
class CustomCallBackendConfig:
"""Represents an unserialized backend config for custom calls."""
lowered_module_asm: bytes
has_communication: bool
collective_id: int | None

# We omit the body while printing, because primitive params get embedded
# in HLO metadata, and the body blows up its size.
def __repr__(self):
return "CustomCallBackendConfig(<omitted>)"

def to_json(self):
"""Serializes the backend config into JSON."""
# We format the JSON ourselves, because json.dumps seems to be overly slow.
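    # The serialized form looks like:
    #   {"custom_call_config": {"body": "<base64 module>", "collective_id": 0}}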
config = io.BytesIO()
config.write(b'{"custom_call_config": {"body": "')
config.write(base64.b64encode(self.lowered_module_asm))
config.write(b'"')
if self.has_communication:
config.write(b', "has_communication": ')
config.write(str(self.has_communication).lower().encode("ascii"))
if self.collective_id is not None:
config.write(b', "collective_id": ')
config.write(str(self.collective_id).encode("ascii"))
config.write(b"}}")
return config.getvalue()


@tpu_custom_call_p.def_abstract_eval
def _tpu_custom_call_abstract_eval(*_, out_avals, **__):
return out_avals


def _aval_to_layout(aval):
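  # Layouts are given minor-to-major; [ndim-1, ..., 1, 0] is the default
  # row-major layout.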
arange = np.arange(aval.ndim, dtype=np.dtype(np.int64))[::-1].copy()
return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get())


def _avals_to_layouts(avals):
return ir.ArrayAttr.get([_aval_to_layout(a) for a in avals])


def _tpu_custom_call_lowering(
ctx: mlir.LoweringRuleContext,
*in_nodes, # pylint: disable=missing-function-docstring
config: CustomCallBackendConfig,
out_avals: Any,
) -> ...:
i32_type = ir.IntegerType.get_signless(32)
multiple_results = len(out_avals) > 1
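  # With custom-call API version 1, multiple results come back as a single
  # tuple that is unpacked below via stablehlo.GetTupleElementOp.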
if multiple_results:
result_type = ir.TupleType.get_tuple(
[mlir.aval_to_ir_type(aval) for aval in out_avals]
)
else:
result_type = mlir.aval_to_ir_type(out_avals[0])
axis_context = ctx.module_context.axis_context
sharding_impls = jax._src.sharding_impls # pylint: disable=protected-access
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
raise NotImplementedError(
"Mosaic kernels cannot be automatically partitioned. Please wrap the"
" call in a shard_map or xmap."
)
elif isinstance(axis_context, sharding_impls.ShardingContext):
if len(axis_context.device_assignment) != 1:
raise NotImplementedError(
"Mosaic kernels cannot be automatically partitioned. Please wrap the"
" call in a shard_map or xmap."
)
elif config.has_communication:
raise NotImplementedError(
"Replica lowering for Mosaic kernels not implemented."
)
call = stablehlo.CustomCallOp(
[result_type],
in_nodes,
call_target_name=ir.StringAttr.get(b"tpu_custom_call"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(config.to_json()),
api_version=ir.IntegerAttr.get(i32_type, 1),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=_avals_to_layouts(ctx.avals_in),
result_layouts=_avals_to_layouts(ctx.avals_out),
output_operand_aliases=None,
)
if multiple_results:
results = [stablehlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
for i in range(len(out_avals))]
else:
results = call.results
return results


mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering,
platform="tpu")


def _lower_tpu_kernel(
    module: ir.Module, hardware_generation: int
) -> tuple[bytes, tuple[Any, ...]]:
  """Runs MLIR passes lowering the given module to serialized MLIR bytecode.

  Args:
    module: The MLIR module to lower.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    A pair of the lowered module, serialized as MLIR bytecode, and a tuple of
    additional constant arguments that should be appended to the kernel
    invocation.
  """
try:
module.operation.verify()
except ir.MLIRError as e:
raise ValueError("The compiled module fails MLIR verification") from e

with ir.Context() as ctx, ir.Location.unknown():
tpu.register_dialect(ctx)
mhlo.register_mhlo_dialect(ctx)
mhlo.register_mhlo_passes()
# We'll mutate the module, so clone it.
module = ir.Module.parse(
module.operation.get_asm(binary=True, enable_debug_info=True)
)

if _ALLOW_HLO.value:
# Run hlo dialect conversion: hlo -> linalg -> vector.
pipeline = [
"hlo-legalize-to-arithmetic",
"func.func(hlo-legalize-to-linalg)",
"func.func(linalg-vectorization)",
]
PassManager.parse(f"builtin.module({','.join(pipeline)})").run(
module.operation
)

infer_memref_layout.infer_module(module, hardware_generation)

pipeline = [
"canonicalize",
"cse",
"func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
module.operation.verify()

apply_vector_layout.apply(module, hardware_generation)
module.operation.verify()

PassManager.parse("builtin.module(canonicalize)").run(module.operation)

vector_constants = []
for f in module.body:
if "vector_constants" not in f.attributes:
continue
if f.name.value != "main":
raise NotImplementedError(
"Only the main function can have non-splat vector constants"
)
constant_attrs = ir.ArrayAttr(f.attributes["vector_constants"])
del f.attributes["vector_constants"]
for c in constant_attrs:
c = ir.DenseElementsAttr(c)
constant_type = ir.VectorType(c.type)
if constant_type.element_type == ir.IntegerType.get_signless(32):
dtype = np.int32
elif ir.F32Type.isinstance(constant_type.element_type):
dtype = np.float32
else:
raise NotImplementedError(constant_type.element_type)
if np.issubdtype(dtype, np.integer):
c = ir.DenseIntElementsAttr(c)
elif np.issubdtype(dtype, np.floating):
c = ir.DenseFPElementsAttr(c)
else:
raise NotImplementedError(dtype)
vector_constants.append(
np.asarray(c, dtype=dtype).reshape(constant_type.shape))
bytecode_buffer = io.BytesIO()
module.operation.write_bytecode(bytecode_buffer, desired_version=0)
return bytecode_buffer.getvalue(), tuple(vector_constants)


def as_tpu_kernel(
module: ir.Module,
out_type: Any,
*,
backend: str = "tpu",
) -> Callable[..., Any]:
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
# We use jax.jit to make sure we hit the fast compilation cache.
some_tpu = jax.devices(backend)[0]
device_kind = some_tpu.device_kind
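  # device_kind is a string such as "TPU v4", possibly with a " pod" or
  # " lite" suffix that we strip before parsing the generation number.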
if device_kind.endswith(" pod"):
device_kind = device_kind[:-len(" pod")]
if device_kind.endswith(" lite"):
device_kind = device_kind[:-len(" lite")]
assert device_kind[:-1] == "TPU v", device_kind
hardware_generation = int(device_kind[-1])
has_communication, has_custom_barrier = tpu.private_has_communication(
module.operation
)
lowered_module_asm, constants = _lower_tpu_kernel(module, hardware_generation)
return _lowered_as_tpu_kernel(
lowered_module_asm,
out_type,
constants,
has_communication=has_communication,
has_custom_barrier=has_custom_barrier,
)


def _lowered_as_tpu_kernel(
lowered_module_asm: bytes,
out_type: Any,
constants: Sequence[Any] = (),
*,
has_communication: bool = False,
has_custom_barrier: bool = False,
):
"""Turns a low-level MLIR Mosaic kernel into a JAX-compatible function."""
unpack = False
if not isinstance(out_type, collections.abc.Iterable):
out_type = (out_type,)
unpack = True
out_avals = tuple(core.ShapedArray(ty.shape, ty.dtype) for ty in out_type)
def apply_kernel(*args, collective_id: int | None = None):
if has_custom_barrier:
if collective_id is None:
raise ValueError(
"collective_id has to be specified when using a custom barrier"
)
elif collective_id is not None:
raise ValueError(
"collective_id has to be unspecified or None when not using a custom"
" barrier"
)
config = CustomCallBackendConfig(
lowered_module_asm, has_communication, collective_id
)
result = tpu_custom_call_p.bind(
*args, *constants, config=config, out_avals=out_avals)
return result[0] if unpack else result
return jax.jit(apply_kernel, static_argnames=["collective_id"])
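
For orientation, here is a minimal, hypothetical sketch (not part of this commit) of how the new entry point can be driven. KERNEL_ASM is a placeholder for the textual IR of a real Mosaic TPU-dialect kernel, and running it assumes a TPU backend is available:

import jax
import jax.numpy as jnp
from jaxlib.mlir import ir

from jax._src import tpu_custom_call
from jax._src.lib import tpu_mosaic

with ir.Context() as ctx, ir.Location.unknown():
  tpu_mosaic.tpu.register_dialect(ctx)  # so the TPU dialect ops parse
  module = ir.Module.parse(KERNEL_ASM)  # KERNEL_ASM: assumed kernel IR

out_type = jax.ShapeDtypeStruct((8, 128), jnp.float32)
kernel = tpu_custom_call.as_tpu_kernel(module, out_type)
y = kernel(jnp.ones((8, 128), jnp.float32))  # lowers to stablehlo.custom_call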