Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ struct ProviderOptionsArray_Element : JSON::Element {
v.name = "QNN";
} else if (v.name == "webgpu") {
v.name = "WebGPU";
} else if (v.name == "dml") {
v.name = "DML";
}
}
}
Expand Down Expand Up @@ -768,7 +770,7 @@ bool IsGraphCaptureEnabled(Config::SessionOptions& session_options) {
throw std::runtime_error("Graph Capture is currently unsupported for CUDA");
}
}
} else if (provider_options.name == "dml") {
} else if (provider_options.name == "DML") {
return true;
} else if (provider_options.name == "NvTensorRtRtx") {
return true;
Expand Down
1 change: 1 addition & 0 deletions src/json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void TranslateException(std::string_view name) {
JSON::JSON(Element& element, std::string_view document) : begin_{document.data()}, end_{document.data() + document.size()} {
try {
Parse_Value(element, {});
element.OnComplete(false);
} catch (const std::exception& message) {
// Figure out line number of error by counting carriage returns seen from start to error location
int line = 1;
Expand Down
18 changes: 13 additions & 5 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,

Ort::ThrowOnError(Ort::api->UpdateROCMProviderOptions(&ort_provider_options, keys.data(), values.data(), keys.size()));
session_options.AppendExecutionProvider_ROCM(ort_provider_options);
} else if (provider_options.name == "DML") {
#if USE_DML
} else if (provider_options.name == "dml") {
if (!GetDmlInterface()) {
LUID device_luid{};
LUID* p_device_luid{};
Expand All @@ -338,6 +338,8 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,

if (is_primary_session_options)
p_device = GetDeviceInterface(DeviceType::DML); // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
#else
throw std::runtime_error("DML provider requested, but the installed GenAI has not been built with DML support");
#endif
} else {
// For providers that go through the extensible AppendExecutionProvider API:
Expand Down Expand Up @@ -407,7 +409,7 @@ void EnsureDeviceOrtInit(DeviceInterface& device) {
// This ensures memory allocated on-device for model inputs/outputs is valid for the lifetime of GenAI.

// Names for the device types used by 'SetProviderSessionOptions'
static const char* device_type_names[] = {"CPU (Not used, see above)", "cuda", "dml", "WebGPU", "QNN", "OpenVINO (Not used, see above)"};
static const char* device_type_names[] = {"CPU (Not used, see above)", "cuda", "DML", "WebGPU", "QNN", "OpenVINO (Not used, see above)"};
static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX));

// Create an OrtSessionOptions and set the options to use the DeviceType we're using here
Expand Down Expand Up @@ -737,9 +739,15 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
auto expanded = OrtValue::CreateTensor(p_device_inputs_->GetAllocator(), input_shape, element_type);
auto expanded_span = ByteWrapTensor(*p_device_inputs_, *expanded);

for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
expanded_span.subspan((i * num_beams + j) * data_size_bytes, data_size_bytes).CopyFrom(input_span.subspan(i * data_size_bytes, data_size_bytes));
// Detect fast & simple copy case
if (num_beams == 1) {
expanded_span.CopyFrom(input_span);
} else {
// TODO (RyanHill): To avoid cuda uninitialized memory warnings, we should copy input_span to device memory first
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
expanded_span.subspan((i * num_beams + j) * data_size_bytes, data_size_bytes).CopyFrom(input_span.subspan(i * data_size_bytes, data_size_bytes));
}
}
}
return expanded;
Expand Down
1 change: 1 addition & 0 deletions test/c_api_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ TEST(CAPITests, Config) {
config->SetProviderOption("brainium", "custom_field2", "hello2");
config->ClearProviders();
config->AppendProvider("cuda");
config->AppendProvider("dml");
#endif
}

Expand Down
Loading