diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 44b95274e2e..28d81548188 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -573,6 +573,16 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { if (expr_evaluator_.isKnown(tv)) { return; } + + // Check the cache if enabled + if (params_.use_allocation_cache) { + auto it = allocation_cache_.find(allocate); + if (it != allocation_cache_.end()) { + expr_evaluator_.bind(tv, it->second); + return; + } + } + GlobalBufferInfo info = getBufferInfos(expr_evaluator_, PrimDataType::Int, {tv}).at(0); c10::Device device = @@ -584,6 +594,12 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { c10::nullopt, device, c10::nullopt); + + // Cache the allocation if enabled + if (params_.use_allocation_cache) { + allocation_cache_[allocate] = tensor; + } + if (allocate->zeroInit()) { tensor.zero_(); } diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 5747630a9e6..bba880a7ccb 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -37,6 +37,8 @@ struct HostIrEvaluatorParams { // number of additional cuda streams to use at runtime for comm+compute // pipelining int64_t number_of_streams = 4; + // Whether to use allocation cache for tensor allocations + bool use_allocation_cache = false; }; // A HostIrEvaluator evaluates a host programs represented through a @@ -135,6 +137,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { std::unordered_map> works_; const int64_t my_local_device_index_; IpcHandleCache ipc_handle_cache_; + // Allocation cache + std::unordered_map allocation_cache_; }; } // namespace hir diff --git a/python/python_direct/multidevice.cpp b/python/python_direct/multidevice.cpp index 33313e537df..0855f4c281f 100644 --- a/python/python_direct/multidevice.cpp +++ b/python/python_direct/multidevice.cpp @@ -143,17 +143,37 @@ If the distributed tensor is replicated on that parallel type, returns -1. 
} void bindMultiDeviceExecutor(py::module& nvfuser) { + // Bind params type under the multidevice submodule. We'll alias it to the + // top-level module in bindMultiDevice to allow direct imports. + py::class_<MultiDeviceExecutorParams>(nvfuser, "MultiDeviceExecutorParams") + .def(py::init<>()) + .def_property( + "use_allocation_cache", + [](const MultiDeviceExecutorParams& self) { + return self.executor.use_allocation_cache; + }, + [](MultiDeviceExecutorParams& self, bool value) { + self.executor.use_allocation_cache = value; + }) + .def_property( + "backend_type", + [](const MultiDeviceExecutorParams& self) { + return self.lower.communicator_backend; + }, + [](MultiDeviceExecutorParams& self, CommunicatorBackend value) { + self.lower.communicator_backend = value; + }); + py::class_<MultiDeviceExecutor> multi_device_executor( nvfuser, "MultiDeviceExecutor"); multi_device_executor.def( - py::init([](const Fusion& fusion, CommunicatorBackend backend) { - MultiDeviceExecutorParams params; - params.lower.communicator_backend = backend; - return std::make_unique<MultiDeviceExecutor>( - std::make_unique<Fusion>(fusion), - Communicator::getInstance(), - std::move(params)); - }), + py::init( + [](const Fusion& fusion, const MultiDeviceExecutorParams& params) { + return std::make_unique<MultiDeviceExecutor>( + std::make_unique<Fusion>(fusion), + Communicator::getInstance(), + params); + }), R"( Create a new MultiDeviceExecutor. @@ -161,16 +181,18 @@ Parameters ---------- fusion : Fusion The fusion to be executed. -backend : CommunicatorBackend - The backend to be used for the communicator. +params : MultiDeviceExecutorParams + Parameters configuring the executor and communicator backend. 
Examples -------- ->>> multi_device_executor = MultiDeviceExecutor(fusion, CommunicatorBackend.nccl) +>>> params = MultiDeviceExecutorParams() +>>> params.backend_type = CommunicatorBackend.nccl +>>> multi_device_executor = MultiDeviceExecutor(fusion, params) >>> outputs = multi_device_executor.run(inputs) )", py::arg("fusion"), - py::arg("backend")); + py::arg("params")); multi_device_executor.def( "__str__", [](MultiDeviceExecutor& self) { diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index f6b805430ce..44e4bc3d079 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -64,9 +64,9 @@ def multidevice_schedule(fd, tensors, num_devices) -> None: tensors = fusion_definition(fd, m, k, n, s, d) multidevice_schedule(fd, tensors, d) - multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor( - fd.fusion, backend_type - ) + params = nvfuser.multidevice.MultiDeviceExecutorParams() + params.backend_type = backend_type + multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor(fd.fusion, params) # warmup for _ in range(N_WARMUPS): @@ -133,9 +133,10 @@ def multidevice_schedule(fd, tensors, num_devices) -> None: tensors = fusion_definition(fd, m, k, n, d) multidevice_schedule(fd, tensors, d) - multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor( - fd.fusion, backend_type - ) + params = nvfuser.multidevice.MultiDeviceExecutorParams() + params.backend_type = backend_type + params.use_allocation_cache = True + multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor(fd.fusion, params) # warmup for _ in range(N_WARMUPS):