Skip to content

Commit

Permalink
[Neuron] Fix the XLADevice Neuron mappings for SPMD downcasts (#8335)
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws authored Oct 31, 2024
1 parent 1bac062 commit 7c7ad4e
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3667,7 +3667,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward(
// our XLA lowering.
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(grad_output_tensor->GetDevice().type());
if (!CheckTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON) {
if (!CheckTpuDevice(hw_type) && !CheckNeuronDevice(hw_type)) {
return at::native::call_fallback_fn<
&xla_fallback, ATEN_OP(upsample_nearest2d_backward)>::call(grad_output,
output_size,
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/data_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ bool IsSparseGather(const xla::Shape& input_shape,
// to avoid gather on a single float on TPU.
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
if (CheckTpuDevice(hw_type) || hw_type == XlaDeviceType::NEURON) {
if (CheckTpuDevice(hw_type) || CheckNeuronDevice(hw_type)) {
// XLA_DENSE_GATHER_FACTOR can be used to finely control the
// sparsity check.
static int dense_gather_factor =
Expand Down
12 changes: 12 additions & 0 deletions torch_xla/csrc/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,16 @@ bool CheckTpuDevice(XlaDeviceType hw_type) {
return false;
}

// Returns true if the device is NEURON: either the physical device type is
// NEURON directly, or the device is the SPMD virtual device and the
// underlying PJRT runtime is NEURON (per the PJRT_DEVICE environment
// variable).
bool CheckNeuronDevice(XlaDeviceType hw_type) {
  if (hw_type == XlaDeviceType::NEURON) {
    return true;
  }
  // The SPMD virtual device masks the physical device type, so consult the
  // environment only on this path — avoids an unnecessary getenv/string
  // construction for every other device type.
  if (hw_type == XlaDeviceType::SPMD) {
    return runtime::sys_util::GetEnvString("PJRT_DEVICE", "") == "NEURON";
  }
  return false;
}

} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ bool GetLockSpmdConfig();
// TODO(yeounoh) - see if we need to check for AOT compilation device type.
bool CheckTpuDevice(XlaDeviceType hw_type);

// Returns true if the device is NEURON — either the physical device type is
// NEURON, or the SPMD virtual device when the underlying PJRT device
// (PJRT_DEVICE) is NEURON.
bool CheckNeuronDevice(XlaDeviceType hw_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_DEVICE_H_
10 changes: 5 additions & 5 deletions torch_xla/csrc/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,19 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
if (UseBF16()) {
return xla::PrimitiveType::BF16;
}
if (DowncastBF16() || hw_type == XlaDeviceType::NEURON) {
if (DowncastBF16() || CheckNeuronDevice(hw_type)) {
return xla::PrimitiveType::F32;
}
return xla::PrimitiveType::F64;
case xla::PrimitiveType::F32:
return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
: xla::PrimitiveType::F32;
case xla::PrimitiveType::U16:
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
: xla::PrimitiveType::U32;
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
: xla::PrimitiveType::U16;
case xla::PrimitiveType::S16:
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S16;
case xla::PrimitiveType::S64:
return xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/resize_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input,

XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
if (CheckTpuDevice(hw_type) || hw_type == XlaDeviceType::NEURON) {
if (CheckTpuDevice(hw_type) || CheckNeuronDevice(hw_type)) {
// TPU uses custom call implementation
resized =
xla::CustomCall(input.builder(), target, {tinput}, resized_shape,
Expand Down

0 comments on commit 7c7ad4e

Please sign in to comment.