diff --git a/src/config.cpp b/src/config.cpp index c0afd1ad32..225a080bab 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -65,6 +65,8 @@ struct ProviderOptionsArray_Element : JSON::Element { v.name = "QNN"; } else if (v.name == "webgpu") { v.name = "WebGPU"; + } else if (v.name == "dml") { + v.name = "DML"; } } } @@ -768,7 +770,7 @@ bool IsGraphCaptureEnabled(Config::SessionOptions& session_options) { throw std::runtime_error("Graph Capture is currently unsupported for CUDA"); } } - } else if (provider_options.name == "dml") { + } else if (provider_options.name == "DML") { return true; } else if (provider_options.name == "NvTensorRtRtx") { return true; diff --git a/src/json.cpp b/src/json.cpp index 73449eda12..6907b7f265 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -50,6 +50,7 @@ void TranslateException(std::string_view name) { JSON::JSON(Element& element, std::string_view document) : begin_{document.data()}, end_{document.data() + document.size()} { try { Parse_Value(element, {}); + element.OnComplete(false); } catch (const std::exception& message) { // Figure out line number of error by counting carriage returns seen from start to error location int line = 1; diff --git a/src/models/model.cpp b/src/models/model.cpp index 82d78d82c7..4532ce7c74 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -311,8 +311,8 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options, Ort::ThrowOnError(Ort::api->UpdateROCMProviderOptions(&ort_provider_options, keys.data(), values.data(), keys.size())); session_options.AppendExecutionProvider_ROCM(ort_provider_options); + } else if (provider_options.name == "DML") { #if USE_DML - } else if (provider_options.name == "dml") { if (!GetDmlInterface()) { LUID device_luid{}; LUID* p_device_luid{}; @@ -338,6 +338,8 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options, if (is_primary_session_options) p_device = GetDeviceInterface(DeviceType::DML); // We use a DML 
allocator for input/output caches, but other tensors will use CPU tensors +#else + throw std::runtime_error("DML provider requested, but the installed GenAI has not been built with DML support"); #endif } else { // For providers that go through the extensible AppendExecutionProvider API: @@ -407,7 +409,7 @@ void EnsureDeviceOrtInit(DeviceInterface& device) { // This ensures memory allocated on-device for model inputs/outputs is valid for the lifetime of GenAI. // Names for the device types used by 'SetProviderSessionOptions' - static const char* device_type_names[] = {"CPU (Not used, see above)", "cuda", "dml", "WebGPU", "QNN", "OpenVINO (Not used, see above)"}; + static const char* device_type_names[] = {"CPU (Not used, see above)", "cuda", "DML", "WebGPU", "QNN", "OpenVINO (Not used, see above)"}; static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX)); // Create an OrtSessionOptions and set the options to use the DeviceType we're using here @@ -737,9 +739,15 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input, auto expanded = OrtValue::CreateTensor(p_device_inputs_->GetAllocator(), input_shape, element_type); auto expanded_span = ByteWrapTensor(*p_device_inputs_, *expanded); - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < num_beams; j++) { - expanded_span.subspan((i * num_beams + j) * data_size_bytes, data_size_bytes).CopyFrom(input_span.subspan(i * data_size_bytes, data_size_bytes)); + // Detect fast & simple copy case + if (num_beams == 1) { + expanded_span.CopyFrom(input_span); + } else { + // TODO (RyanHill): To avoid cuda uninitialized memory warnings, we should copy input_span to device memory first + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < num_beams; j++) { + expanded_span.subspan((i * num_beams + j) * data_size_bytes, data_size_bytes).CopyFrom(input_span.subspan(i * data_size_bytes, data_size_bytes)); + } } } return expanded; diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 
f4dbc80610..bfe622c62f 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -37,6 +37,7 @@ TEST(CAPITests, Config) { config->SetProviderOption("brainium", "custom_field2", "hello2"); config->ClearProviders(); config->AppendProvider("cuda"); + config->AppendProvider("dml"); #endif }