From cfff62bf520818537dac4c582289906ce505b479 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 06:44:56 +0000 Subject: [PATCH 1/8] Use torch::lazy::Data in DeviceContext --- torch_xla/csrc/xla_graph_executor.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 250141116180..cf4668e0a62e 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -76,7 +76,7 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) { class DeviceContextArena { struct DeviceContext { std::mutex lock; - std::map> tensors_data; + std::map> tensors_data; uint64_t seed = 101; uint64_t running_seed = 101; torch::lazy::Value seed_ir_value; @@ -88,14 +88,14 @@ class DeviceContextArena { return arena; } - void RegisterTensor(std::shared_ptr data) { + void RegisterTensor(std::shared_ptr data) { DeviceContext* devctx = GetDeviceContext(data->device); std::lock_guard lock(devctx->lock); devctx->tensors_data.emplace(data->unique_id, data); TORCH_LAZY_COUNTER("CreateXlaTensor", 1); } - void UnregisterTensor(XLATensor::Data* data) { + void UnregisterTensor(torch::lazy::LazyTensor::Data* data) { DeviceContext* devctx = GetDeviceContext(data->device); std::lock_guard lock(devctx->lock); devctx->tensors_data.erase(data->unique_id); @@ -108,7 +108,7 @@ class DeviceContextArena { auto fn = [&](DeviceContext* devctx) { std::lock_guard lock(devctx->lock); for (auto& uid_wptr : devctx->tensors_data) { - std::shared_ptr data = uid_wptr.second.lock(); + auto data = std::dynamic_pointer_cast(uid_wptr.second.lock()); if (data != nullptr) { tensors.push_back( c10::make_intrusive(XLATensor(std::move(data)))); From 8918a3d8f10d8c59d1c35170839842c8895b068e Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 07:31:30 +0000 Subject: [PATCH 2/8] Make DeviceContextArena protected --- torch_xla/csrc/xla_graph_executor.cpp | 44 
+++++++------------------ torch_xla/csrc/xla_graph_executor.h | 46 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 32 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index cf4668e0a62e..f7fd6e4bfa22 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -68,41 +68,26 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) { } // namespace -// The DeviceContextArena holds per device live information and statistics, -// among which the XLA tensors which are currently alive in the system. This is -// used to create XLA computation "barriers" in order to flush pending -// operations and ensure the same XLA computations are created during the -// training loops. -class DeviceContextArena { - struct DeviceContext { - std::mutex lock; - std::map> tensors_data; - uint64_t seed = 101; - uint64_t running_seed = 101; - torch::lazy::Value seed_ir_value; - }; - - public: - static DeviceContextArena* Get() { + auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* { static DeviceContextArena* arena = new DeviceContextArena(); return arena; } - void RegisterTensor(std::shared_ptr data) { + void XLAGraphExecutor::DeviceContextArena::RegisterTensor(std::shared_ptr data) { DeviceContext* devctx = GetDeviceContext(data->device); std::lock_guard lock(devctx->lock); devctx->tensors_data.emplace(data->unique_id, data); TORCH_LAZY_COUNTER("CreateXlaTensor", 1); } - void UnregisterTensor(torch::lazy::LazyTensor::Data* data) { + void XLAGraphExecutor::DeviceContextArena::UnregisterTensor(torch::lazy::LazyTensor::Data* data) { DeviceContext* devctx = GetDeviceContext(data->device); std::lock_guard lock(devctx->lock); devctx->tensors_data.erase(data->unique_id); TORCH_LAZY_COUNTER("DestroyXlaTensor", 1); } - std::vector GetLiveTensors( + std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( const torch::lazy::BackendDevice* device) { std::vector 
tensors; auto fn = [&](DeviceContext* devctx) { @@ -119,7 +104,7 @@ class DeviceContextArena { return tensors; } - torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device) { + torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed(const torch::lazy::BackendDevice& device) { static const at::ScalarType kSeedType = at::ScalarType::Long; static const uint64_t kSeedMul = 214013; static const uint64_t kSeedAdd = 2531011; @@ -142,7 +127,7 @@ class DeviceContextArena { return devctx->seed_ir_value; } - torch::lazy::BackendDataPtr GetBaseSeedData( + torch::lazy::BackendDataPtr XLAGraphExecutor::DeviceContextArena::GetBaseSeedData( const torch::lazy::BackendDevice& device) { static const at::ScalarType kSeedType = at::ScalarType::Long; DeviceContext* devctx = GetDeviceContext(device); @@ -156,13 +141,13 @@ class DeviceContextArena { ->data(); } - uint64_t GetRunningSeed(const torch::lazy::BackendDevice& device) { + uint64_t XLAGraphExecutor::DeviceContextArena::GetRunningSeed(const torch::lazy::BackendDevice& device) { DeviceContext* devctx = GetDeviceContext(device); std::lock_guard lock(devctx->lock); return devctx->running_seed; } - void SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed) { + void XLAGraphExecutor::DeviceContextArena::SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed) { DeviceContext* devctx = GetDeviceContext(device); std::lock_guard lock(devctx->lock); devctx->seed = seed; @@ -170,7 +155,7 @@ class DeviceContextArena { devctx->seed_ir_value = torch::lazy::Value(); } - void MarkStep(const torch::lazy::BackendDevice& device) { + void XLAGraphExecutor::DeviceContextArena::MarkStep(const torch::lazy::BackendDevice& device) { DeviceContext* devctx = GetDeviceContext(device); std::lock_guard lock(devctx->lock); devctx->seed = 1012031 + devctx->seed * 7012063; @@ -178,8 +163,7 @@ class DeviceContextArena { devctx->seed_ir_value = torch::lazy::Value(); } - private: - std::vector 
GetAllDeviceContexts() { + auto XLAGraphExecutor::DeviceContextArena::GetAllDeviceContexts() -> std::vector { std::vector all_device_contexts; std::lock_guard lock(lock_); all_device_contexts.reserve(device_contexts_.size()); @@ -189,7 +173,7 @@ class DeviceContextArena { return all_device_contexts; } - void ForAllDeviceContexts(const std::function& fn, + void XLAGraphExecutor::DeviceContextArena::ForAllDeviceContexts(const std::function& fn, const torch::lazy::BackendDevice* device) { if (device == nullptr) { for (auto devctx : GetAllDeviceContexts()) { @@ -200,7 +184,7 @@ class DeviceContextArena { } } - DeviceContext* GetDeviceContext(const torch::lazy::BackendDevice& device) { + auto XLAGraphExecutor::DeviceContextArena::GetDeviceContext(const torch::lazy::BackendDevice& device) -> DeviceContext* { std::lock_guard lock(lock_); auto it = device_contexts_.find(device); if (it == device_contexts_.end()) { @@ -209,10 +193,6 @@ class DeviceContextArena { return it->second; } - std::mutex lock_; - std::map device_contexts_; -}; - XLAGraphExecutor::Async::Async( SyncTensorCollection* coll, std::vector parameters_data, diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 46ba8503f2bd..a375c70fa0bd 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -175,6 +175,52 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::vector tensors_data; }; + // The DeviceContextArena holds per device live information and statistics, + // among which the XLA tensors which are currently alive in the system. This is + // used to create XLA computation "barriers" in order to flush pending + // operations and ensure the same XLA computations are created during the + // training loops. 
+ class DeviceContextArena { + struct DeviceContext { + std::mutex lock; + std::map> tensors_data; + uint64_t seed = 101; + uint64_t running_seed = 101; + torch::lazy::Value seed_ir_value; + }; + + public: + static DeviceContextArena* Get(); + + void RegisterTensor(std::shared_ptr data); + void UnregisterTensor(torch::lazy::LazyTensor::Data* data); + + std::vector GetLiveTensors( + const torch::lazy::BackendDevice* device); + + torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device); + + torch::lazy::BackendDataPtr GetBaseSeedData( + const torch::lazy::BackendDevice& device); + + uint64_t GetRunningSeed(const torch::lazy::BackendDevice& device); + + void SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed); + + void MarkStep(const torch::lazy::BackendDevice& device); + + private: + std::vector GetAllDeviceContexts(); + + void ForAllDeviceContexts(const std::function& fn, + const torch::lazy::BackendDevice* device); + + DeviceContext* GetDeviceContext(const torch::lazy::BackendDevice& device); + + std::mutex lock_; + std::map device_contexts_; + }; + XLAGraphExecutor() = default; SyncTensorCollection CollectSyncTensors( From 12f6c4be978b150b50f84e4f7a3e3cf34d865170 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 08:35:43 +0000 Subject: [PATCH 3/8] Inherits torch::lazy::LazyGraphExecutor::DeviceContextArena --- torch_xla/csrc/xla_graph_executor.cpp | 174 +++++++++----------------- torch_xla/csrc/xla_graph_executor.h | 42 ++----- 2 files changed, 67 insertions(+), 149 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index f7fd6e4bfa22..4f3bf79affb8 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -68,130 +68,72 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) { } // namespace - auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* { - static DeviceContextArena* arena = new 
DeviceContextArena(); - return arena; - } - - void XLAGraphExecutor::DeviceContextArena::RegisterTensor(std::shared_ptr data) { - DeviceContext* devctx = GetDeviceContext(data->device); - std::lock_guard lock(devctx->lock); - devctx->tensors_data.emplace(data->unique_id, data); - TORCH_LAZY_COUNTER("CreateXlaTensor", 1); - } +auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* { + static DeviceContextArena* arena = new DeviceContextArena(); + return arena; +} - void XLAGraphExecutor::DeviceContextArena::UnregisterTensor(torch::lazy::LazyTensor::Data* data) { - DeviceContext* devctx = GetDeviceContext(data->device); +std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( + const torch::lazy::BackendDevice* device) { + std::vector tensors; + auto fn = [&](DeviceContext* devctx) { std::lock_guard lock(devctx->lock); - devctx->tensors_data.erase(data->unique_id); - TORCH_LAZY_COUNTER("DestroyXlaTensor", 1); - } - - std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( - const torch::lazy::BackendDevice* device) { - std::vector tensors; - auto fn = [&](DeviceContext* devctx) { - std::lock_guard lock(devctx->lock); - for (auto& uid_wptr : devctx->tensors_data) { - auto data = std::dynamic_pointer_cast(uid_wptr.second.lock()); - if (data != nullptr) { - tensors.push_back( - c10::make_intrusive(XLATensor(std::move(data)))); - } + for (auto& uid_wptr : devctx->tensors_data) { + auto data = std::dynamic_pointer_cast(uid_wptr.second.lock()); + if (data != nullptr) { + tensors.push_back( + c10::make_intrusive(XLATensor(std::move(data)))); } - }; - ForAllDeviceContexts(fn, device); - return tensors; - } - - torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed(const torch::lazy::BackendDevice& device) { - static const at::ScalarType kSeedType = at::ScalarType::Long; - static const uint64_t kSeedMul = 214013; - static const uint64_t kSeedAdd = 2531011; - DeviceContext* devctx = GetDeviceContext(device); - std::lock_guard 
lock(devctx->lock); - if (!devctx->seed_ir_value) { - devctx->seed_ir_value = - IrValueFromScalar(MakeIntScalar(devctx->seed), kSeedType, device); } - // Keep the running seed as scalar as well, so we can return it directly - // without executing graphs. - devctx->running_seed = kSeedAdd + kSeedMul * devctx->running_seed; - // Compose new seeds from the root seed, to avoid creating too many XLA - // computation parameters which might overflow the TPU capacity. - torch::lazy::Value k = ScalarOp(MakeIntScalar(kSeedMul), - MakeXlaPrimitiveType(kSeedType, &device)); - torch::lazy::Value b = ScalarOp(MakeIntScalar(kSeedAdd), - MakeXlaPrimitiveType(kSeedType, &device)); - devctx->seed_ir_value = b + k * devctx->seed_ir_value; - return devctx->seed_ir_value; - } - - torch::lazy::BackendDataPtr XLAGraphExecutor::DeviceContextArena::GetBaseSeedData( - const torch::lazy::BackendDevice& device) { - static const at::ScalarType kSeedType = at::ScalarType::Long; - DeviceContext* devctx = GetDeviceContext(device); - std::lock_guard lock(devctx->lock); - at::Tensor tensor = at::scalar_tensor(MakeIntScalar(devctx->seed), - at::TensorOptions(kSeedType)); - torch::lazy::BackendDataPtr device_data = TensorToXlaData(tensor, device); - devctx->seed_ir_value = torch::lazy::MakeNode(device_data); - devctx->running_seed = devctx->seed; - return torch_xla::DeviceData::Cast(devctx->seed_ir_value.node.get()) - ->data(); - } - - uint64_t XLAGraphExecutor::DeviceContextArena::GetRunningSeed(const torch::lazy::BackendDevice& device) { - DeviceContext* devctx = GetDeviceContext(device); - std::lock_guard lock(devctx->lock); - return devctx->running_seed; - } - - void XLAGraphExecutor::DeviceContextArena::SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed) { - DeviceContext* devctx = GetDeviceContext(device); - std::lock_guard lock(devctx->lock); - devctx->seed = seed; - devctx->running_seed = devctx->seed; - devctx->seed_ir_value = torch::lazy::Value(); - } - - void 
XLAGraphExecutor::DeviceContextArena::MarkStep(const torch::lazy::BackendDevice& device) { - DeviceContext* devctx = GetDeviceContext(device); - std::lock_guard lock(devctx->lock); - devctx->seed = 1012031 + devctx->seed * 7012063; - devctx->running_seed = devctx->seed; - devctx->seed_ir_value = torch::lazy::Value(); - } + }; + ForAllDeviceContexts(fn, device); + return tensors; +} - auto XLAGraphExecutor::DeviceContextArena::GetAllDeviceContexts() -> std::vector { - std::vector all_device_contexts; - std::lock_guard lock(lock_); - all_device_contexts.reserve(device_contexts_.size()); - for (auto& device_contexts : device_contexts_) { - all_device_contexts.push_back(device_contexts.second); - } - return all_device_contexts; +torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed(const torch::lazy::BackendDevice& device) { + static const at::ScalarType kSeedType = at::ScalarType::Long; + static const uint64_t kSeedMul = 214013; + static const uint64_t kSeedAdd = 2531011; + DeviceContext* devctx = GetDeviceContext(device); + std::lock_guard lock(devctx->lock); + if (!devctx->seed_ir_value) { + devctx->seed_ir_value = + IrValueFromScalar(MakeIntScalar(devctx->seed), kSeedType, device); } + // Keep the running seed as scalar as well, so we can return it directly + // without executing graphs. + devctx->running_seed = kSeedAdd + kSeedMul * devctx->running_seed; + // Compose new seeds from the root seed, to avoid creating too many XLA + // computation parameters which might overflow the TPU capacity. 
+ torch::lazy::Value k = ScalarOp(MakeIntScalar(kSeedMul), + MakeXlaPrimitiveType(kSeedType, &device)); + torch::lazy::Value b = ScalarOp(MakeIntScalar(kSeedAdd), + MakeXlaPrimitiveType(kSeedType, &device)); + devctx->seed_ir_value = b + k * devctx->seed_ir_value; + return devctx->seed_ir_value; +} - void XLAGraphExecutor::DeviceContextArena::ForAllDeviceContexts(const std::function& fn, - const torch::lazy::BackendDevice* device) { - if (device == nullptr) { - for (auto devctx : GetAllDeviceContexts()) { - fn(devctx); - } - } else { - fn(GetDeviceContext(*device)); - } - } +torch::lazy::BackendDataPtr XLAGraphExecutor::DeviceContextArena::GetBaseSeedData( + const torch::lazy::BackendDevice& device) { + static const at::ScalarType kSeedType = at::ScalarType::Long; + DeviceContext* devctx = GetDeviceContext(device); + std::lock_guard lock(devctx->lock); + at::Tensor tensor = at::scalar_tensor(MakeIntScalar(devctx->seed), + at::TensorOptions(kSeedType)); + torch::lazy::BackendDataPtr device_data = TensorToXlaData(tensor, device); + devctx->seed_ir_value = torch::lazy::MakeNode(device_data); + devctx->running_seed = devctx->seed; + return torch_xla::DeviceData::Cast(devctx->seed_ir_value.node.get()) + ->data(); +} - auto XLAGraphExecutor::DeviceContextArena::GetDeviceContext(const torch::lazy::BackendDevice& device) -> DeviceContext* { - std::lock_guard lock(lock_); - auto it = device_contexts_.find(device); - if (it == device_contexts_.end()) { - it = device_contexts_.emplace(device, new DeviceContext()).first; - } - return it->second; - } +torch::lazy::Value XLAGraphExecutor::DeviceContextArena::IrValueFromScalar(const at::Scalar& value, + at::ScalarType scalar_type, + const torch::lazy::BackendDevice& device) { + at::Tensor tensor = at::scalar_tensor(value, at::TensorOptions(scalar_type)); + torch::lazy::BackendDataPtr device_data = TensorToXlaData(tensor, device); + return torch::lazy::MakeNode(std::move(device_data)); +} XLAGraphExecutor::Async::Async( 
SyncTensorCollection* coll, diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index a375c70fa0bd..d5d92d2f070c 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -175,50 +175,26 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::vector tensors_data; }; - // The DeviceContextArena holds per device live information and statistics, - // among which the XLA tensors which are currently alive in the system. This is - // used to create XLA computation "barriers" in order to flush pending - // operations and ensure the same XLA computations are created during the - // training loops. - class DeviceContextArena { - struct DeviceContext { - std::mutex lock; - std::map> tensors_data; - uint64_t seed = 101; - uint64_t running_seed = 101; - torch::lazy::Value seed_ir_value; - }; - + class DeviceContextArena : public torch::lazy::LazyGraphExecutor::DeviceContextArena { public: static DeviceContextArena* Get(); - void RegisterTensor(std::shared_ptr data); - void UnregisterTensor(torch::lazy::LazyTensor::Data* data); - + // This method returns XLATensorPtrs instead of LazyTensorPtrs. std::vector GetLiveTensors( const torch::lazy::BackendDevice* device); - torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device); + // We override this to use our own + and * for torch::lazy::Value. 
+ torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device) final; torch::lazy::BackendDataPtr GetBaseSeedData( const torch::lazy::BackendDevice& device); - uint64_t GetRunningSeed(const torch::lazy::BackendDevice& device); - - void SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed); - - void MarkStep(const torch::lazy::BackendDevice& device); - private: - std::vector GetAllDeviceContexts(); - - void ForAllDeviceContexts(const std::function& fn, - const torch::lazy::BackendDevice* device); - - DeviceContext* GetDeviceContext(const torch::lazy::BackendDevice& device); - - std::mutex lock_; - std::map device_contexts_; + // We override this to use TensorToXlaData(). + torch::lazy::Value IrValueFromScalar( + const at::Scalar& value, + at::ScalarType scalar_type, + const torch::lazy::BackendDevice& device) final; }; XLAGraphExecutor() = default; From 3e48ad6b2701acf5a548deb0c1f2e505358d56f2 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 09:05:05 +0000 Subject: [PATCH 4/8] Mark some overridden methods final --- torch_xla/csrc/xla_graph_executor.cpp | 4 ++-- torch_xla/csrc/xla_graph_executor.h | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 4f3bf79affb8..ec1931470909 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -175,11 +175,11 @@ XLAGraphExecutor* XLAGraphExecutor::Get() { return &arena; } -void XLAGraphExecutor::RegisterTensor(std::shared_ptr data) { +void XLAGraphExecutor::RegisterTensor(std::shared_ptr data) { DeviceContextArena::Get()->RegisterTensor(data); } -void XLAGraphExecutor::UnregisterTensor(XLATensor::Data* data) { +void XLAGraphExecutor::UnregisterTensor(torch::lazy::LazyTensor::Data* data) { DeviceContextArena::Get()->UnregisterTensor(data); } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 
d5d92d2f070c..fea52dbff53b 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -32,8 +32,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { public: static XLAGraphExecutor* Get(); - virtual void RegisterTensor(std::shared_ptr data); - virtual void UnregisterTensor(XLATensor::Data* data); + void RegisterTensor(std::shared_ptr data) final; + void UnregisterTensor(torch::lazy::LazyTensor::Data* data) final; // This method just syncs the tensors passed as argument. This method is // called at two places: @@ -79,9 +79,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { c10::optional logical_element_type, const torch::lazy::BackendDevice& device); - torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device); - void SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed); - uint64_t GetRunningSeed(const torch::lazy::BackendDevice& device); + torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device) final; + void SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed) final; + uint64_t GetRunningSeed(const torch::lazy::BackendDevice& device) final; torch::lazy::BackendDataPtr GetBaseSeedData( const torch::lazy::BackendDevice& device); @@ -93,6 +93,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // for the given device. If device is nullptr, the live tensors for all // devices will be returned. Returned tensors are sorted by device as primary // key, and by unique ID as secondary key. + // Unlike the base class, here we return XLATensorPtrs. std::vector GetLiveTensors( const torch::lazy::BackendDevice* device); @@ -113,7 +114,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // Marks an execution step, which allows the tensor framework to understand // the computation boundaries. 
- void MarkStep(const torch::lazy::BackendDevice& device); + void MarkStep(const torch::lazy::BackendDevice& device) final; // Waits for all the outstanding operations on all the supplied devices. // If devices is empty, the wait will happen for all local devices. From f4a550b6bfc3e05a64e46a10911158fb767f4fae Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 09:06:13 +0000 Subject: [PATCH 5/8] Fix linters --- torch_xla/csrc/xla_graph_executor.cpp | 21 ++++++++++++--------- torch_xla/csrc/xla_graph_executor.h | 21 ++++++++++++--------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index ec1931470909..7206a5e9ca34 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -79,7 +79,8 @@ std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( auto fn = [&](DeviceContext* devctx) { std::lock_guard lock(devctx->lock); for (auto& uid_wptr : devctx->tensors_data) { - auto data = std::dynamic_pointer_cast(uid_wptr.second.lock()); + auto data = + std::dynamic_pointer_cast(uid_wptr.second.lock()); if (data != nullptr) { tensors.push_back( c10::make_intrusive(XLATensor(std::move(data)))); @@ -90,7 +91,8 @@ std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( return tensors; } -torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed(const torch::lazy::BackendDevice& device) { +torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed( + const torch::lazy::BackendDevice& device) { static const at::ScalarType kSeedType = at::ScalarType::Long; static const uint64_t kSeedMul = 214013; static const uint64_t kSeedAdd = 2531011; @@ -113,7 +115,8 @@ torch::lazy::Value XLAGraphExecutor::DeviceContextArena::GetRngSeed(const torch: return devctx->seed_ir_value; } -torch::lazy::BackendDataPtr XLAGraphExecutor::DeviceContextArena::GetBaseSeedData( +torch::lazy::BackendDataPtr 
+XLAGraphExecutor::DeviceContextArena::GetBaseSeedData( const torch::lazy::BackendDevice& device) { static const at::ScalarType kSeedType = at::ScalarType::Long; DeviceContext* devctx = GetDeviceContext(device); @@ -123,13 +126,12 @@ torch::lazy::BackendDataPtr XLAGraphExecutor::DeviceContextArena::GetBaseSeedDat torch::lazy::BackendDataPtr device_data = TensorToXlaData(tensor, device); devctx->seed_ir_value = torch::lazy::MakeNode(device_data); devctx->running_seed = devctx->seed; - return torch_xla::DeviceData::Cast(devctx->seed_ir_value.node.get()) - ->data(); + return torch_xla::DeviceData::Cast(devctx->seed_ir_value.node.get())->data(); } -torch::lazy::Value XLAGraphExecutor::DeviceContextArena::IrValueFromScalar(const at::Scalar& value, - at::ScalarType scalar_type, - const torch::lazy::BackendDevice& device) { +torch::lazy::Value XLAGraphExecutor::DeviceContextArena::IrValueFromScalar( + const at::Scalar& value, at::ScalarType scalar_type, + const torch::lazy::BackendDevice& device) { at::Tensor tensor = at::scalar_tensor(value, at::TensorOptions(scalar_type)); torch::lazy::BackendDataPtr device_data = TensorToXlaData(tensor, device); return torch::lazy::MakeNode(std::move(device_data)); @@ -175,7 +177,8 @@ XLAGraphExecutor* XLAGraphExecutor::Get() { return &arena; } -void XLAGraphExecutor::RegisterTensor(std::shared_ptr data) { +void XLAGraphExecutor::RegisterTensor( + std::shared_ptr data) { DeviceContextArena::Get()->RegisterTensor(data); } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index fea52dbff53b..662912eedb7e 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -32,7 +32,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { public: static XLAGraphExecutor* Get(); - void RegisterTensor(std::shared_ptr data) final; + void RegisterTensor( + std::shared_ptr data) final; void UnregisterTensor(torch::lazy::LazyTensor::Data* data) final; // This method just 
syncs the tensors passed as argument. This method is @@ -80,7 +81,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { const torch::lazy::BackendDevice& device); torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device) final; - void SetRngSeed(const torch::lazy::BackendDevice& device, uint64_t seed) final; + void SetRngSeed(const torch::lazy::BackendDevice& device, + uint64_t seed) final; uint64_t GetRunningSeed(const torch::lazy::BackendDevice& device) final; torch::lazy::BackendDataPtr GetBaseSeedData( const torch::lazy::BackendDevice& device); @@ -176,8 +178,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::vector tensors_data; }; - class DeviceContextArena : public torch::lazy::LazyGraphExecutor::DeviceContextArena { - public: + class DeviceContextArena + : public torch::lazy::LazyGraphExecutor::DeviceContextArena { + public: static DeviceContextArena* Get(); // This method returns XLATensorPtrs instead of LazyTensorPtrs. @@ -185,16 +188,16 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { const torch::lazy::BackendDevice* device); // We override this to use our own + and * for torch::lazy::Value. - torch::lazy::Value GetRngSeed(const torch::lazy::BackendDevice& device) final; + torch::lazy::Value GetRngSeed( + const torch::lazy::BackendDevice& device) final; torch::lazy::BackendDataPtr GetBaseSeedData( const torch::lazy::BackendDevice& device); - private: - // We override this to use TensorToXlaData(). + private: + // We override this to use TensorToXlaData(). 
torch::lazy::Value IrValueFromScalar( - const at::Scalar& value, - at::ScalarType scalar_type, + const at::Scalar& value, at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) final; }; From 3db6a0cfdf92448161de20e9c7d31146b27339b7 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 09:06:59 +0000 Subject: [PATCH 6/8] Add .torch_pin --- torch_patches/.torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 000000000000..dc42e488d36c --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#90531 From ad8696a0da4cbf9eb56c28519e8cb6b14c8d4624 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 9 Dec 2022 19:57:49 +0000 Subject: [PATCH 7/8] Add back some XLA counters --- torch_xla/csrc/xla_graph_executor.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 7206a5e9ca34..dded15d05df4 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -180,10 +180,12 @@ XLAGraphExecutor* XLAGraphExecutor::Get() { void XLAGraphExecutor::RegisterTensor( std::shared_ptr data) { DeviceContextArena::Get()->RegisterTensor(data); + TORCH_LAZY_COUNTER("CreateXlaTensor", 1); } void XLAGraphExecutor::UnregisterTensor(torch::lazy::LazyTensor::Data* data) { DeviceContextArena::Get()->UnregisterTensor(data); + TORCH_LAZY_COUNTER("DestroyXlaTensor", 1); } void XLAGraphExecutor::ApplyEagerSync(std::vector& tensors) { From cfefbf18e1d8cfd9a02446325e6b4375a2eb37ad Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 13 Dec 2022 04:23:53 +0000 Subject: [PATCH 8/8] Remove .torch_pin --- torch_patches/.torch_pin | 1 - 1 file changed, 1 deletion(-) delete mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin deleted file mode 100644 index dc42e488d36c..000000000000 --- 
a/torch_patches/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#90531