Add functions to emit a custom call that places a buffer on host or device.
This is used for host-offloading.

For example, this JAX code:
```python
def policy(prim, *avals, **params) -> Offloadable:
  return Offloadable(src='device', dst='pinned_host')

@functools.partial(jax.remat, policy=policy)
def f(x):
  x = jnp.sin(x)
  x = jnp.sin(x)
  return jnp.sum(x)
```

becomes:
```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<16xf32>
    %1 = stablehlo.cosine %arg0 : tensor<16xf32>
    %2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %3 = stablehlo.cosine %0 : tensor<16xf32>
    %4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5:3 = stablehlo.optimization_barrier %2, %4, %cst : tensor<16xf32>, tensor<16xf32>, tensor<f32>
    %6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %7 = stablehlo.custom_call @annotate_device_placement(%5#1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %8 = stablehlo.broadcast_in_dim %5#2, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %9 = stablehlo.multiply %8, %7 : tensor<16xf32>
    %10 = stablehlo.multiply %9, %6 : tensor<16xf32>
    return %10 : tensor<16xf32>
  }
}
```
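
On the PyTorch/XLA side, the new `place_to_host` / `place_to_device` helpers emit the same `annotate_device_placement` custom calls. A minimal usage sketch, mirroring the test added in this commit (assumes an XLA device is available):
```python
# Sketch: annotate buffer placement from PyTorch/XLA (mirrors the new test).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.stablehlo_custom_call import (place_to_host,
                                                          place_to_device)

dev = xm.xla_device()
a = torch.ones(10, device=dev)

# Ask XLA to keep this buffer in pinned host memory.
b = place_to_host(a)
# The emitted StableHLO contains a custom call
#   @annotate_device_placement ... {_xla_buffer_placement = "pinned_host"}
print(xm.get_stablehlo([b]))

# Ask XLA to move a buffer (back) onto the device.
c = place_to_device(torch.ones(10, device=dev))
print(xm.get_stablehlo([c]))
```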
qihqi committed Nov 1, 2024
1 parent 7c7ad4e commit fc408f2
Showing 7 changed files with 73 additions and 14 deletions.
19 changes: 18 additions & 1 deletion test/stablehlo/test_stablehlo_custom_call.py
@@ -6,7 +6,8 @@
import torch_xla.core.xla_model as xm
import torch_xla.experimental.stablehlo_custom_call
from torch.library import Library, impl, impl_abstract
from torch_xla.experimental.stablehlo_custom_call import stablehlo_custom_call
from torch_xla.experimental.stablehlo_custom_call import (
    stablehlo_custom_call, place_to_host, place_to_device)
from torch_xla.stablehlo import (StableHLOExportOptions,
exported_program_to_stablehlo)

@@ -115,6 +116,22 @@ def forward(self, x):
# self.assertTrue("api_version = 1" in shlo_text)


  def test_place_to_host_device(self):
    dev = xm.xla_device()
    a = torch.ones(10, device=dev)
    b = place_to_host(a)
    shlo_text = xm.get_stablehlo([b])
    self.assertTrue("has_side_effect = true" in shlo_text)
    self.assertTrue("mhlo.frontend_attributes = {_xla_buffer_placement = \"pinned_host\"}}" in shlo_text)

    a = torch.ones(10, device=dev)
    b = place_to_device(a)
    shlo_text = xm.get_stablehlo([b])
    self.assertTrue("has_side_effect = true" in shlo_text)
    self.assertTrue("mhlo.frontend_attributes = {_xla_buffer_placement = \"device\"}}" in shlo_text)



if __name__ == "__main__":

  test = unittest.main()
8 changes: 5 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2555,8 +2555,9 @@ void InitXlaModuleBindings(py::module m) {
[](const std::vector<at::Tensor>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes, bool has_side_effect,
const std::string& backend_config,
const int api_version) -> std::vector<at::Tensor> {
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>&
frontend_attributes) -> std::vector<at::Tensor> {
std::vector<at::ScalarType> dtypes;
dtypes.reserve(output_dtypes.size());
for (auto& dtype : output_dtypes) {
@@ -2566,7 +2567,8 @@

auto xtensors = tensor_methods::custom_call(
bridge::GetXlaTensors(inputs), target, output_shapes, dtypes,
has_side_effect, backend_config, api_version);
has_side_effect, backend_config, api_version,
frontend_attributes);
return bridge::AtenFromXlaTensors(std::move(xtensors));
});
m.def("_xla_tpu_custom_call",
26 changes: 21 additions & 5 deletions torch_xla/csrc/ops/custom_call.cpp
@@ -8,17 +8,27 @@

namespace torch_xla {

CustomCall::CustomCall(torch::lazy::OpList inputs,
const std::string& call_target, xla::Shape output_shape,
bool has_side_effect, const std::string& backend_config,
const int api_version)
CustomCall::CustomCall(
torch::lazy::OpList inputs, const std::string& call_target,
xla::Shape output_shape, bool has_side_effect,
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes)
: XlaNode(xla_custom_call, inputs, std::move(output_shape),
/*num_outputs=*/output_shape.tuple_shapes_size(),
torch::lazy::MHash(call_target)),
call_target_(call_target),
has_side_effect_(has_side_effect),
backend_config_(backend_config),
api_version_(api_version) {}
api_version_(api_version),
frontend_attributes_(frontend_attributes) {}

CustomCall::CustomCall(torch::lazy::OpList inputs,
const std::string& call_target, xla::Shape output_shape,
bool has_side_effect, const std::string& backend_config,
const int api_version)
: CustomCall(inputs, call_target, output_shape, has_side_effect,
backend_config, api_version,
std::unordered_map<std::string, std::string>()) {}

torch::lazy::NodePtr CustomCall::Clone(torch::lazy::OpList operands) const {
return torch_xla::MakeNode<CustomCall>(operands, call_target_,
@@ -38,6 +48,12 @@ XlaOpVector CustomCall::Lower(LoweringContext* loctx) const {
output_shape = output_shape.tuple_shapes(0);
}
XLA_CHECK(api_version_ >= 0 && api_version_ < 5);

xla::FrontendAttributes feattr;
feattr.mutable_map()->insert(frontend_attributes_.begin(),
frontend_attributes_.end());
xla::XlaScopedFrontendAttributesAssignment feattr_assign(inputs[0].builder(),
feattr);
xla::XlaOp output = xla::CustomCall(
inputs[0].builder(), call_target_, inputs, output_shape,
/*opaque=*/backend_config_,
8 changes: 8 additions & 0 deletions torch_xla/csrc/ops/custom_call.h
@@ -1,6 +1,8 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_
#define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_CALL_H_

#include <unordered_map>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
@@ -10,6 +12,11 @@ class CustomCall : public XlaNode {
CustomCall(torch::lazy::OpList inputs, const std::string& call_target,
xla::Shape output_shape, bool has_side_effect,
const std::string& backend_config, const int api_version);
CustomCall(
torch::lazy::OpList inputs, const std::string& call_target,
xla::Shape output_shape, bool has_side_effect,
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes);

std::string ToString() const override;

Expand All @@ -22,6 +29,7 @@ class CustomCall : public XlaNode {
bool has_side_effect_;
std::string backend_config_;
int api_version_;
std::unordered_map<std::string, std::string> frontend_attributes_;
};

} // namespace torch_xla
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
@@ -564,7 +564,8 @@ std::vector<XLATensorPtr> custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
const std::string& backend_config, const int api_version) {
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes) {
XLA_CHECK(inputs.size() > 0) << "inputs are empty";

std::vector<torch::lazy::Value> values;
@@ -584,7 +585,7 @@

auto node = torch_xla::MakeNode<CustomCall>(
values, target, xla::ShapeUtil::MakeTupleShape(output_xla_shapes),
has_side_effect, backend_config, api_version);
has_side_effect, backend_config, api_version, frontend_attributes);

std::vector<XLATensorPtr> outputs;
outputs.reserve(output_shapes.size());
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor_methods.h
@@ -92,7 +92,8 @@ std::vector<XLATensorPtr> custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<at::ScalarType>& output_dtypes, bool has_side_effect,
const std::string& backend_config, const int api_version);
const std::string& backend_config, const int api_version,
const std::unordered_map<std::string, std::string>& frontend_attributes);

void custom_sharding_(
const XLATensorPtr& input,
18 changes: 16 additions & 2 deletions torch_xla/experimental/stablehlo_custom_call.py
@@ -10,10 +10,13 @@ def stablehlo_custom_call(args,
                          output_dtypes,
                          has_side_effect=False,
                          backend_config="",
                          api_version=0):
                          api_version=0,
                          frontend_attributes=None):
  frontend_attributes = frontend_attributes or {}
  res = torch_xla._XLAC._xla_custom_call(args, call_target, output_shapes,
                                         output_dtypes, has_side_effect,
                                         backend_config, api_version)
                                         backend_config, api_version,
                                         frontend_attributes)
  if len(output_shapes) == 1:
    return res[0]
  return res
@@ -29,3 +32,14 @@ def extract_custom_call_outputs_shape_dtype(n: torch.fx.Node):
  assert None not in output_shape_dtype
  output_shape, output_dtype = zip(*output_shape_dtype)
  return output_shape, output_dtype


def place_to_host(a: torch.Tensor):
  return stablehlo_custom_call([a], "annotate_device_placement",
                               [a.shape], [a.dtype], has_side_effect=True,
                               frontend_attributes={"_xla_buffer_placement": "pinned_host"})


def place_to_device(a: torch.Tensor):
  return stablehlo_custom_call([a], "annotate_device_placement",
                               [a.shape], [a.dtype], has_side_effect=True,
                               frontend_attributes={"_xla_buffer_placement": "device"})
