Skip to content

Commit

Permalink
Update XLA pin to 10/16 (#8267)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored Oct 21, 2024
1 parent b8ee2c2 commit 6536639
Show file tree
Hide file tree
Showing 22 changed files with 28 additions and 27 deletions.
5 changes: 3 additions & 2 deletions .github/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ function run_torch_xla_cpp_tests() {
"test_aten_xla_tensor_2"
"test_aten_xla_tensor_3"
"test_aten_xla_tensor_4"
"pjrt_computation_client_test"
"ifrt_computation_client_test")
"pjrt_computation_client_test")
# Disable IFRT test as it currently crashes
#"ifrt_computation_client_test")
test_names2=("test_aten_xla_tensor_5"
"test_aten_xla_tensor_6"
"test_ir"
Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4'
xla_hash = 'eef7ee50d0980848436f0b4f402cec8c5bf86f21'

http_archive(
name = "xla",
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240913'
_date = '20241015'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl'
_jax_version = f'0.4.33.dev{_date}'
_jax_version = f'0.4.35.dev{_date}'


def _get_build_mode():
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ptxla_cc_test(
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:xla_data_proto_cc",
"@tsl//tsl/profiler/utils:session_manager",
"@xla//xla/tsl/profiler/utils:session_manager",
],
)

Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ ptxla_cc_library(
"@xla//xla/client/lib:slicing",
"@xla//xla/client/lib:sorting",
"@xla//xla/client/lib:svd",
"@xla//xla/hlo/pass:hlo_pass_pipeline",
"@xla//xla/stream_executor:dnn",
"@tsl//tsl/platform:errors",
"@tsl//tsl/profiler/lib:traceme",
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@
#include "torch_xla/csrc/xla_sharding_util.h"
#include "tsl/platform/env.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/python/profiler/internal/traceme_wrapper.h"
#include "xla/service/hlo_parser.h"

namespace torch_xla {
namespace {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "torch_xla/csrc/shape_helper.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/qr.h"
#include "xla/hlo/builder/lib/qr.h"
#include "xla/shape_util.h"
#include "xla/util.h"

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/embedding_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "torch_xla/csrc/xla_lower_util.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/logdet.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/logdet.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/ops_lower_fn.cpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/logdet.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/hlo/builder/lib/logdet.h"

namespace torch_xla {
torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/logdet.h"
#include "xla/hlo/builder/lib/logdet.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/qr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "torch_xla/csrc/lowering_context.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/qr.h"
#include "xla/hlo/builder/lib/qr.h"

namespace torch_xla {
namespace {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/randperm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/loops.h"
#include "xla/hlo/builder/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/arithmetic.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/pooling.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/loops.h"

namespace torch_xla {
namespace {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/prng.h"
#include "xla/hlo/builder/lib/prng.h"

namespace torch_xla {
namespace {
Expand Down
6 changes: 1 addition & 5 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,11 +560,7 @@ IfrtComputationClient::ExecuteReplicated(
counter.Wait();
}

xla::ExecuteOptions execute_options;
execute_options.untuple_result = options.explode_tuple;
execute_options.strict_shape_checking = true;
// TODO(yeounoh) currently only support single-slice execution
execute_options.multi_slice_config = nullptr;
xla::ifrt::ExecuteOptions execute_options;

TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
<< spmd_device_str;
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "xla/python/pjrt_ifrt/pjrt_array.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
#include "xla/python/pjrt_ifrt/xla_compiler.h"
#include "xla/shape.h"

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name,
args.function_name = fn_name.c_str();
args.function_name_size = fn_name.size();
args.api_version = 0;
args.custom_call_function = function_ptr;
args.handler_execute = function_ptr;
PJRT_Error* error =
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(&args);
if (error) {
Expand Down
4 changes: 3 additions & 1 deletion torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ ComputationClient* GetComputationClient() {

std::unique_ptr<ComputationClient> client;

static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
// Disable IFRT right now as it currently crashes.
// static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
static bool use_ifrt = false;
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
if (use_ifrt) {
client = std::make_unique<IfrtComputationClient>();
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#include "xla/client/lib/arithmetic.h"
#include "xla/client/lib/comparators.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/loops.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/dnn.h"
#include "xla/util.h"
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_op_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/tensor_util.h"
#include "xla/client/lib/logdet.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/pooling.h"
#include "xla/hlo/builder/lib/logdet.h"
#include "xla/primitive_util.h"
#include "xla/shape_util.h"

Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include "tsl/profiler/lib/traceme.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/protobuf_util.h"
#include "xla/service/hlo_parser.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/hlo_verifier.h"
#include "xla/service/sharding_propagation.h"
#include "xla/service/spmd/spmd_partitioner.h"
Expand Down

0 comments on commit 6536639

Please sign in to comment.