Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ class TVM_DLL DeviceAPI {
* \param stream The stream to be set.
*/
virtual void SetStream(Device dev, TVMStreamHandle stream) {}
/*!
* \brief Get the current stream
* \param dev The device to get stream.
* \return The current stream of the device.
*/
virtual TVMStreamHandle GetCurrentStream(Device dev);
/*!
* \brief Synchronize 2 streams of execution.
*
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }

void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}

TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; }

void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
}

Expand Down
12 changes: 6 additions & 6 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class CUDADeviceAPI final : public DeviceAPI {
TVMStreamHandle CreateStream(Device dev) {
CUDA_CALL(cudaSetDevice(dev.device_id));
cudaStream_t retval;
CUDA_CALL(cudaStreamCreate(&retval));
CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
return static_cast<TVMStreamHandle>(retval);
}

Expand Down Expand Up @@ -225,6 +225,10 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
}

TVMStreamHandle GetCurrentStream(Device dev) final {
return static_cast<TVMStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
}

void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand All @@ -243,11 +247,7 @@ class CUDADeviceAPI final : public DeviceAPI {
private:
static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind,
cudaStream_t stream) {
if (stream != nullptr) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
} else {
CUDA_CALL(cudaMemcpy(to, from, size, kind));
}
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
}
};

Expand Down
1 change: 1 addition & 0 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class MetalWorkspace final : public DeviceAPI {
void FreeStream(Device dev, TVMStreamHandle stream) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
TVMStreamHandle GetCurrentStream(Device dev) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
void ReinitializeDefaultStreams();
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ int GetWarpSize(id<MTLDevice> dev) {
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
}

TVMStreamHandle MetalWorkspace::GetCurrentStream(Device dev) {
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
return MetalThreadEntry::ThreadLocal()->stream[dev.device_id];
}

void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ enum class RPCCode : int {
kDevCreateStream,
kDevFreeStream,
kDevSetStream,
kDevGetCurrentStream,
};

/*!
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
}

TVMStreamHandle GetCurrentStream(Device dev) final {
return static_cast<TVMStreamHandle>(ROCMThreadEntry::ThreadLocal()->stream);
}

void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand Down
7 changes: 6 additions & 1 deletion src/runtime/rpc/rpc_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,16 @@ class RPCDeviceAPI final : public DeviceAPI {
GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream);
}

void SetStream(Device dev, TVMStreamHandle stream) {
void SetStream(Device dev, TVMStreamHandle stream) final {
auto remote_dev = RemoveRPCSessionMask(dev);
GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream);
}

TVMStreamHandle GetCurrentStream(Device dev) final {
auto remote_dev = RemoveRPCSessionMask(dev);
return GetSess(dev)->GetDeviceAPI(remote_dev)->GetCurrentStream(remote_dev);
}

protected:
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint,
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,11 @@ void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
handler->GetDeviceAPI(dev)->SetStream(dev, stream);
}

void RPCDevGetCurrentStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
Device dev = args[0];
*rv = handler->GetDeviceAPI(dev)->GetCurrentStream(dev);
}

void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
// Event handler sit at clean state at this point.
switch (code) {
Expand Down Expand Up @@ -1043,6 +1048,9 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
case RPCCode::kDevSetStream:
SysCallHandler(RPCDevSetStream);
break;
case RPCCode::kDevGetCurrentStream:
SysCallHandler(RPCDevGetCurrentStream);
break;
case RPCCode::kCopyAmongRemote:
SysCallHandler(RPCCopyAmongRemote);
break;
Expand Down Expand Up @@ -1188,6 +1196,10 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream);
}

TVMStreamHandle GetCurrentStream(Device dev) final {
return endpoint_->SysCallRemote(RPCCode::kDevGetCurrentStream, dev);
}

DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; }

bool IsLocalSession() const final { return false; }
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
ICHECK_EQ(stream, static_cast<void*>(nullptr));
}

TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) { return nullptr; }

void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/vulkan/vulkan_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
TVMStreamHandle GetCurrentStream(Device dev) final;

protected:
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
Expand Down
2 changes: 2 additions & 0 deletions web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class WebGPUDeviceAPI : public DeviceAPI {

void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; }

TVMStreamHandle GetCurrentStream(Device dev) final { LOG(FATAL) << "Not implemented"; }

void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand Down