Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,16 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
if (expr_evaluator_.isKnown(tv)) {
return;
}

// Check the cache if enabled
if (params_.use_allocation_cache) {
auto it = allocation_cache_.find(allocate);
if (it != allocation_cache_.end()) {
expr_evaluator_.bind(tv, it->second);
return;
}
}

GlobalBufferInfo info =
getBufferInfos(expr_evaluator_, PrimDataType::Int, {tv}).at(0);
c10::Device device =
Expand All @@ -584,6 +594,12 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
c10::nullopt,
device,
c10::nullopt);

// Cache the allocation if enabled
if (params_.use_allocation_cache) {
allocation_cache_[allocate] = tensor;
}

if (allocate->zeroInit()) {
tensor.zero_();
}
Expand Down
4 changes: 4 additions & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct HostIrEvaluatorParams {
// number of additional cuda streams to use at runtime for comm+compute
// pipelining
int64_t number_of_streams = 4;
// Whether to use allocation cache for tensor allocations
bool use_allocation_cache = false;
};

// A HostIrEvaluator evaluates a host programs represented through a
Expand Down Expand Up @@ -135,6 +137,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
const int64_t my_local_device_index_;
IpcHandleCache ipc_handle_cache_;
// Allocation cache
std::unordered_map<kir::Allocate*, at::Tensor> allocation_cache_;
};

} // namespace hir
Expand Down
46 changes: 34 additions & 12 deletions python/python_direct/multidevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,34 +143,56 @@ If the distributed tensor is replicated on that parallel type, returns -1.
}

void bindMultiDeviceExecutor(py::module& nvfuser) {
// Bind params type under the multidevice submodule. We'll alias it to the
// top-level module in bindMultiDevice to allow direct imports.
py::class_<MultiDeviceExecutorParams>(nvfuser, "MultiDeviceExecutorParams")
.def(py::init<>())
.def_property(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try .def_readwrite

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to use that with nested structures.

Otherwise I can also change the behavior and let the user "readwrite" the whole params, and not only those two knobs:

.def_readwrite("executor", &MultiDeviceExecutorParams::executor)
.def_readwrite("lower", &MultiDeviceExecutorParams::lower)

"use_allocation_cache",
[](const MultiDeviceExecutorParams& self) {
return self.executor.use_allocation_cache;
},
[](MultiDeviceExecutorParams& self, bool value) {
self.executor.use_allocation_cache = value;
})
.def_property(
"backend_type",
[](const MultiDeviceExecutorParams& self) {
return self.lower.communicator_backend;
},
[](MultiDeviceExecutorParams& self, CommunicatorBackend value) {
self.lower.communicator_backend = value;
});

py::class_<MultiDeviceExecutor> multi_device_executor(
nvfuser, "MultiDeviceExecutor");
multi_device_executor.def(
py::init([](const Fusion& fusion, CommunicatorBackend backend) {
MultiDeviceExecutorParams params;
params.lower.communicator_backend = backend;
return std::make_unique<MultiDeviceExecutor>(
std::make_unique<Fusion>(fusion),
Communicator::getInstance(),
std::move(params));
}),
py::init(
[](const Fusion& fusion, const MultiDeviceExecutorParams& params) {
return std::make_unique<MultiDeviceExecutor>(
std::make_unique<Fusion>(fusion),
Communicator::getInstance(),
params);
}),
R"(
Create a new MultiDeviceExecutor.

Parameters
----------
fusion : Fusion
The fusion to be executed.
backend : CommunicatorBackend
The backend to be used for the communicator.
params : MultiDeviceExecutorParams
Parameters configuring the executor and communicator backend.

Examples
--------
>>> multi_device_executor = MultiDeviceExecutor(fusion, CommunicatorBackend.nccl)
>>> params = MultiDeviceExecutorParams()
>>> params.backend_type = CommunicatorBackend.nccl
>>> multi_device_executor = MultiDeviceExecutor(fusion, params)
>>> outputs = multi_device_executor.run(inputs)
)",
py::arg("fusion"),
py::arg("backend"));
py::arg("params"));
multi_device_executor.def(
"__str__",
[](MultiDeviceExecutor& self) {
Expand Down
13 changes: 7 additions & 6 deletions tests/python/multidevice/test_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def multidevice_schedule(fd, tensors, num_devices) -> None:
tensors = fusion_definition(fd, m, k, n, s, d)
multidevice_schedule(fd, tensors, d)

multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor(
fd.fusion, backend_type
)
params = nvfuser.multidevice.MultiDeviceExecutorParams()
params.backend_type = backend_type
multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor(fd.fusion, params)

# warmup
for _ in range(N_WARMUPS):
Expand Down Expand Up @@ -133,9 +133,10 @@ def multidevice_schedule(fd, tensors, num_devices) -> None:
tensors = fusion_definition(fd, m, k, n, d)
multidevice_schedule(fd, tensors, d)

multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor(
fd.fusion, backend_type
)
params = nvfuser.multidevice.MultiDeviceExecutorParams()
params.backend_type = backend_type
params.use_allocation_cache = True
multidevice_executor = nvfuser.multidevice.MultiDeviceExecutor(fd.fusion, params)

# warmup
for _ in range(N_WARMUPS):
Expand Down