Fix segfault for multiple GPU run (regression) (#15823)
### Fix segfault for multiple GPU run

#15618 introduced
`GetOrtDeviceByMemType`. The intention was to handle the CPU device
differently in the if branch, but by mistake the default non-CPU device id
was passed to the CPU `OrtDevice`:


```
OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
  if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
    return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id());
  }
  return default_device_;
}
```
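
For contrast, the corrected version (as applied to the CUDA, ROCM, CANN, MIGraphX, and TensorRT providers in the diffs below) always uses device id 0 for the CPU-side pinned device:

```
OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
  if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
    // The returned device is a CPU (pinned) device, so its id must be 0 rather than
    // the id of the default non-CPU device (which is 1 on the second GPU).
    return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/);
  }
  return default_device_;
}
```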

We observed a segmentation fault when running multi-GPU training:

```
CUDA_LAUNCH_BLOCKING=1 python -m torch.distributed.launch \
  --nproc_per_node=2 \
  examples/onnxruntime/training/language-modeling/run_mlm.py \
  --model_name_or_path distilbert-base-uncased --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10 \
  --per_device_train_batch_size 8 --per_device_eval_batch_size 8 \
  --do_train --do_eval --overwrite_output_dir --output_dir ./outputs222/ \
  --seed 1137 --fp16 --report_to none --optim adamw_ort_fused \
  --max_steps 400 --logging_steps 1
```

We found that GPU0 works fine while GPU1 throws a segmentation fault. Looking further,
a Shape node trying to allocate its output tensor fetches the corresponding allocator
with OrtDevice(Device:[DeviceType:0 MemoryType:1 DeviceId:1]); the CPU device does not
have device id = 1, so no allocator is returned. When `AsStreamBasedAllocator` is then
called on that null allocator, the segmentation fault happens because no null check was
done there.
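
Besides correcting the device id, this change also adds explicit checks so that a failed allocator lookup produces a clear error instead of a crash; a condensed excerpt of the checks added in `onnxruntime/core/framework/execution_frame.cc` (see the full diff below):

```
static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) {
  // Fail with a clear message instead of dereferencing a null allocator.
  ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr");
  if (allocator->Info().alloc_type == OrtArenaAllocator) {
    BFCArena* arena_ptr = static_cast<BFCArena*>(allocator.get());
    return StreamAwareArena::FromBFCArena(*arena_ptr);
  }
  return nullptr;
}

// ... and in ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper:
// no memory pattern, or the pattern is not correct.
if (!alloc) alloc = GetAllocator(location);
ORT_ENFORCE(alloc && alloc.get() != nullptr, "Failed to get allocator for ", location.ToString());
```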



### Motivation and Context
This fixes a regression introduced by #15618 that causes segmentation faults in multi-GPU training runs.
pengwa authored and snnn committed May 19, 2023
1 parent cd998ee commit 6b2013c
Showing 7 changed files with 39 additions and 38 deletions.
20 changes: 9 additions & 11 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -1745,9 +1745,8 @@ class PlannerImpl {

#else

void
PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
const PathString& partition_config_file) {
void PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
const PathString& partition_config_file) {
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
@@ -1760,7 +1759,7 @@ class PlannerImpl {
num_logic_streams_ = stream_nodes_.size();
}

// build each logic streams
// Build each logic streams
Status BuildExecutionPlan(const ExecutionProviders& execution_providers,
const IStreamCommandHandleRegistry& stream_handle_registry) {
// 1. create logic stream instance
@@ -1780,12 +1779,12 @@ class PlannerImpl {
execution_plan.emplace_back(nullptr);
}
}
// 2. determing following things:
// a. which node need to generate notification
// b. which node need to trigger downstream
// 2. Determining following things:
// a. which node needs to generate the notification
// b. which node needs to trigger downstream
#ifdef ENABLE_TRAINING
// We will leverage the topological order for the training scenario.
// The nodes before yieldOp in topo order will be executed in RunForward() and nodes after will be executed in RunBackward()
// The nodes before yieldOp in topo-order will be executed in RunForward() and nodes after will be executed in RunBackward()
// This partition may not be exactly the same as forward model/gradient model, for example, some nodes in gradient model are
// before yieldOp thus will be executed in RunForward()
// But the final result is still correct, as long as all the nodes will be executed in either RunForward() or RunBackward()
@@ -1820,7 +1819,7 @@ class PlannerImpl {
if (node_stream_map_[it->Index()] != i
#ifdef ENABLE_TRAINING
// Do not insert Barrier/TriggerDownStream step if the producer and consumer are in different sides of yieldOp
// As in this case producer will surely be ready before consumer is running.
// As in this case producer will surely be ready before the consumer is running.
&& !AreNodesSeparatedByYield(node_index, it->Index())
#endif
) {
@@ -2048,8 +2047,7 @@ class PlannerImpl {
}
#endif

static bool
IsNonTensor(const onnxruntime::NodeArg& nodearg) {
static bool IsNonTensor(const onnxruntime::NodeArg& nodearg) {
// TODO: unclear why we should go through a string-representation of type
auto ptype = nodearg.Type();
auto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(ptype);
47 changes: 25 additions & 22 deletions onnxruntime/core/framework/execution_frame.cc
@@ -28,6 +28,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
#ifdef ORT_ENABLE_STREAM
static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) {
ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr");
if (allocator->Info().alloc_type == OrtArenaAllocator) {
BFCArena* arena_ptr = static_cast<BFCArena*>(allocator.get());
return StreamAwareArena::FromBFCArena(*arena_ptr);
@@ -137,7 +138,7 @@ Status IExecutionFrame::GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std:

#endif

// Return nullptr if index map to an value that is an unused optional input/output
// Return nullptr if index map to a value that is an unused optional input/output
const OrtValue* IExecutionFrame::GetNodeInputOrOutputMLValue(int index) const {
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
return ort_value_idx != NodeIndexInfo::kInvalidEntry ? &(all_values_[ort_value_idx]) : nullptr;
@@ -147,9 +148,9 @@ OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) {
return const_cast<OrtValue*>(GetNodeInputOrOutputMLValue(index));
}

// TO DO: make it thread safe
// This method is not thread safe!
// Return S_OK and nullptr if index map to an value that is an unused optional input/output
// TO DO: make it thread-safe
// This method is not thread-safe!
// Return S_OK and nullptr if index map to a value that is an unused optional input/output

Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int output_arg_index,
const TensorShape* shape, OrtValue*& p_ort_value,
@@ -191,7 +192,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int
}

bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const {
// By default, there is not information about inferred shape, so this default
// By default, there is no information about inferred shape, so this default
// implementation always returns false. The derived class of IExecutionFrame
// can override this function to provide, for example, activations' shape information.
return false;
@@ -213,7 +214,7 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
}

int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const {
// the validity of index is checked by GetMLValueIndex
// The validity of the index is checked by GetMLValueIndex
int ort_value_idx = node_index_info_.GetMLValueIndex(index);
return ort_value_idx;
}
@@ -241,7 +242,7 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<con
}
}

// 3. handle the weights.
// 3. Handle the weights.
// We do this after the fetches to handle an edge case where an initializer is an output.
// e.g. A Constant node gets lifted to an initializer so there's no Node producing the value as an output during
// Graph execution (i.e. Graph execution won't write the value to all_values_).
@@ -251,16 +252,16 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<con
for (const auto& entry : initializers) {
int ort_value_index = entry.first;

// if the initializer is an output we need to allocate or use a provided fetch buffer and copy the data
// so it can be returned to the caller.
// If the initializer is an output we need to allocate or use a provided fetch buffer and copy the data
// so it can be returned to the caller.
//
// The alternative to handling this as a special case would be to disallow an initializer providing a graph output.
// There's nothing in the ONNX spec that says a graph output must come from a node output though.
// If we took that approach we'd need to:
// - reject a model with an initializer or Constant node (as we convert those to initializers in Graph::Graph)
// that produces a graph output even though it conforms to the ONNX spec
// - update optimizers to not convert something to an initializer that is a graph output
// (e.g. constant folding)
// The alternative to handling this as a special case would be to disallow an initializer providing a graph output.
// There's nothing in the ONNX spec that says a graph output must come from a node output though.
// If we took that approach we'd need to:
// - reject a model with an initializer or Constant node (as we convert those to initializers in Graph::Graph)
// that produces a graph output even though it conforms to the ONNX spec
// - update optimizers to not convert something to an initializer that is a graph output
// (e.g. constant folding)
if (IsOutput(ort_value_index)) {
std::string name;
ORT_THROW_IF_ERROR(ort_value_idx_map_.GetName(ort_value_index, name));
@@ -288,7 +289,7 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<con
#endif // !defined(DISABLE_SPARSE_TENSORS)
if (!dest.IsAllocated()) {
// NOTE: This doesn't need to support ExecutionFrame custom allocators as they only come into play
// for a subgraph with an output of unknown shape that needs to be accumulated by the control flow node.
// for a subgraph with an output of unknown shape that needs to be accumulated by the control-flow node.
// If the initializer is providing the output, the shape is known.
AllocatorPtr allocator = GetAllocator(src.Location().device);
Tensor::InitOrtValue(src.DataType(), src.Shape(), std::move(allocator), dest);
@@ -535,7 +536,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
AllocatorPtr alloc = nullptr;

// if we have pre-calculated memory pattern, and the ort_value is not output mlvalue
// try to allocated on pre-allocated big chunk.
// try to allocate on pre-allocated big chunk.
const auto& per_alloc_plan = GetAllocationPlan(ort_value_index);

if (mem_patterns_ && per_alloc_plan.alloc_kind != AllocKind::kAllocateOutput &&
@@ -557,11 +558,11 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
} else {
// the block size may vary especially if the model has NonZero ops, or different sequence lengths are
// fed in, so use VERBOSE as the log level as it's expected.
// TODO: Should we re-use the block if the size is large enough? Would probably need to allow it
// TODO: Should we reuse the block if the size is large enough? Would probably need to allow it
// to be freed if the size difference was too large so our memory usage doesn't stick at a high water mark
LOGS(session_state_.Logger(), VERBOSE) << "For ort_value with index: " << ort_value_index
<< ", block in memory pattern size is: " << block->size_
<< " but the actually size is: " << size
<< " but the actual size is: " << size
<< ", fall back to default allocation behavior";
}
}
@@ -572,6 +573,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va

// no memory pattern, or the pattern is not correct.
if (!alloc) alloc = GetAllocator(location);
ORT_ENFORCE(alloc && alloc.get() != nullptr, "Failed to get allocator for ", location.ToString());

Stream* current_stream = GetValueStream(ort_value_index);
if (current_stream) {
#ifdef ORT_ENABLE_STREAM
@@ -825,7 +828,7 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
}

// This method is not thread safe!
// Return S_OK and nullptr if index map to an value that is an unused optional input/output
// Return S_OK and nullptr if index map to a value that is an unused optional input/output
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
}
@@ -930,7 +933,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
}

// Search for inferred shape.
// If inferred shape is found, it's assigned to "shape" so that caller can use it.
// If the inferred shape is found, it's assigned to "shape" so that caller can use it.
if (inferred_shapes_ != nullptr) {
auto it = inferred_shapes_->find(ort_value_idx);
if (it != inferred_shapes_->end()) {
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1511,7 +1511,7 @@ void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&

OrtDevice CANNExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, default_device_.Id());
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, 0 /*CPU device id always be 0*/);
}
return default_device_;
}
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2533,7 +2533,7 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&

OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id());
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/);
}
return default_device_;
}
@@ -1210,7 +1210,7 @@ void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis

OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, default_device_.Id());
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/);
}
return default_device_;
}
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -2338,7 +2338,7 @@ void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&

OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, default_device_.Id());
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/);
}
return default_device_;
}
@@ -2844,7 +2844,7 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis

OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id());
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/);
}
return default_device_;
}