Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,7 @@ class MetalWorkspace final : public DeviceAPI {
std::vector<id<MTLDevice>> devices;
// Warp size constant
std::vector<int> warp_size;
// Whether it is initialized.
bool initialized_{false};
// the mutex for initialization
std::mutex mutex;
MetalWorkspace();
// Destructor
~MetalWorkspace();
// Get device for given device
Expand All @@ -149,9 +146,6 @@ class MetalWorkspace final : public DeviceAPI {
<< "Invalid Metal device_id=" << dev.device_id;
return devices[dev.device_id];
}
// Initialize workspace
// Does nothing if the workspace has already been initialized.
void Init();
// override device API
void SetDevice(Device dev) final;
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
Expand All @@ -163,7 +157,16 @@ class MetalWorkspace final : public DeviceAPI {
void SetStream(Device dev, TVMStreamHandle stream) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
void ReinitializeStreams();
void ReinitializeDefaultStreams();

/**
* Cast stream to the right metal stream data structure.
* If stream is nullptr, return the default stream of device_id.
* \param stream the input stream handle
* \param device_id The device id of interest
* \returns The stream used in this function.
*/
Stream* CastStreamOrGetDefault(TVMStreamHandle stream, int device_id);

// get the global workspace
static MetalWorkspace* Global();
Expand All @@ -184,7 +187,7 @@ class MetalThreadEntry {
/*! \brief The current device */
Device device;
/*! \brief The current stream */
std::vector<Stream*> stream;
std::vector<TVMStreamHandle> stream;
/*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer>> temp_buffer_;
/*! \brief workspace pool */
Expand All @@ -193,6 +196,10 @@ class MetalThreadEntry {
MetalThreadEntry() : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
device.device_id = 0;
device.device_type = static_cast<DLDeviceType>(kDLMetal);
MetalWorkspace* global_ws = MetalWorkspace::Global();
// By default, set the stream to nullptr, which indicates
// that we are using the default stream.
this->stream.resize(global_ws->devices.size(), nullptr);
}
~MetalThreadEntry();
// Get temp buffer with at least size under dev.
Expand Down
40 changes: 12 additions & 28 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
AUTORELEASEPOOL {
this->Init();
size_t index = static_cast<size_t>(dev.device_id);
if (kind == kExist) {
*rv = int(index < devices.size());
Expand Down Expand Up @@ -142,29 +141,18 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

void MetalWorkspace::ReinitializeStreams() {
std::vector<Stream*>& threadStreams = MetalThreadEntry::ThreadLocal()->stream;
ICHECK_EQ(default_streams_.size(), threadStreams.size());
void MetalWorkspace::ReinitializeDefaultStreams() {
for (size_t i = 0; i < default_streams_.size(); ++i) {
if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i])
delete threadStreams[i];
delete default_streams_[i];
}
default_streams_.resize(devices.size());
threadStreams.resize(devices.size());
for (size_t i = 0; i < devices.size(); ++i) {
Stream* stream = new Stream(devices[i]);
default_streams_[i] = stream;
threadStreams[i] = stream;
}
}

void MetalWorkspace::Init() {
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mutex);
if (initialized_) return;
initialized_ = true;
if (devices.size() != 0) return;
MetalWorkspace::MetalWorkspace() {
#if TARGET_OS_IPHONE
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
Expand All @@ -178,7 +166,7 @@ int GetWarpSize(id<MTLDevice> dev) {
warp_size.push_back(GetWarpSize(d));
}
#endif
ReinitializeStreams();
this->ReinitializeDefaultStreams();
}

void MetalWorkspace::SetDevice(Device dev) {
Expand All @@ -189,7 +177,6 @@ int GetWarpSize(id<MTLDevice> dev) {
DLDataType type_hint) {
id<MTLBuffer> buf;
AUTORELEASEPOOL {
this->Init();
id<MTLDevice> dev = GetDevice(device);
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
Expand Down Expand Up @@ -220,20 +207,20 @@ int GetWarpSize(id<MTLDevice> dev) {
};
}

Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) {
Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int device_id) {
if (stream != nullptr) return static_cast<Stream*>(stream);
ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr);
return MetalThreadEntry::ThreadLocal()->stream[device_id];
ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
ICHECK(default_streams_[device_id] != nullptr);
return default_streams_[device_id];
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
AUTORELEASEPOOL {
this->Init();
Device dev = dev_from;
if (dev_from.device_type == kDLCPU) dev = dev_to;
Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
}
Expand Down Expand Up @@ -303,15 +290,12 @@ int GetWarpSize(id<MTLDevice> dev) {
void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
ICHECK(stream != nullptr);
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
Stream* s = static_cast<Stream*>(stream);
if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s)
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr;
delete s;
delete static_cast<Stream*>(stream);
}

void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
AUTORELEASEPOOL {
Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
Stream* s = CastStreamOrGetDefault(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
Expand All @@ -325,7 +309,7 @@ int GetWarpSize(id<MTLDevice> dev) {
void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
ICHECK(stream != nullptr);
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
}

void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
Expand Down Expand Up @@ -374,7 +358,7 @@ int GetWarpSize(id<MTLDevice> dev) {
});

TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() {
MetalWorkspace::Global()->ReinitializeStreams();
MetalWorkspace::Global()->ReinitializeDefaultStreams();
});

} // namespace metal
Expand Down
9 changes: 4 additions & 5 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
AUTORELEASEPOOL {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->device.device_id;
auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
// obtain the stream
auto stream =
metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id);
if (stream->HasErrorHappened()) return;
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
Expand Down Expand Up @@ -265,10 +267,7 @@ Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string fmt,
std::string source) {
ObjectPtr<Object> n;
AUTORELEASEPOOL {
metal::MetalWorkspace::Global()->Init();
n = make_object<MetalModuleNode>(smap, fmap, fmt, source);
};
AUTORELEASEPOOL { n = make_object<MetalModuleNode>(smap, fmap, fmt, source); };
return Module(n);
}

Expand Down