From ed53b9b469a7c85b3a909879f44686dcebe41c85 Mon Sep 17 00:00:00 2001 From: Oleksandr Kholodnyi Date: Thu, 8 Jan 2026 12:20:32 -0800 Subject: [PATCH 1/4] [RyzenAI] Non-pruned models backward compatibility --- src/models/logits.cpp | 5 +++-- src/models/model.cpp | 14 ++++++++++++++ src/models/model.h | 3 +++ src/ryzenai/interface.cpp | 15 +++++++++++++++ src/ryzenai/interface.h | 3 +++ 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 522ae711b0..b30fddaca0 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -4,6 +4,7 @@ #include "model.h" #include "logits.h" #include "../openvino/interface.h" +#include "../ryzenai/interface.h" namespace Generators { @@ -15,8 +16,8 @@ Logits::Logits(State& state) input_sequence_lengths.resize(state_.params_->search.batch_size); - if (IsOpenVINOStatefulModel(state.model_) || state_.model_.p_device_->GetType() == DeviceType::RyzenAI) { - // In the case of OpenVINO stateful models or RyzenAI models, they are patched in a way so that they only return the + if (IsOpenVINOStatefulModel(state.model_) || IsRyzenAIPrunedModel(state_.model_)) { + // In the case of OpenVINO stateful models or RyzenAI pruned models, they are patched in a way so that they only return the // sliced logits needed for sampling. For example, given 43 prompt tokens, instead of returning // logits of the shape: [1,43,] // they will have shape: [1, 1,]. diff --git a/src/models/model.cpp b/src/models/model.cpp index 517a9b9487..2a3454695f 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -900,6 +900,20 @@ std::vector SessionInfo::GetInputNames() const { return names; } +std::vector SessionInfo::GetInputShape(const std::string& name) const { + auto type_info = inputs_.find(name); + if (type_info == inputs_.end()) + throw std::runtime_error("Model input was not found: " + name); + return type_info->second->GetTensorTypeAndShapeInfo().GetShape(); +} + +std::vector SessionInfo::GetOutputShape(const std::string& name) const { + auto type_info = outputs_.find(name); + if (type_info == outputs_.end()) + throw std::runtime_error("Model output was not found: " + name); + return type_info->second->GetTensorTypeAndShapeInfo().GetShape(); +} + std::vector SessionInfo::GetInputSymbolicShape(const std::string& name) const { auto type_info = inputs_.find(name); if (type_info == inputs_.end()) diff --git a/src/models/model.h b/src/models/model.h index 7faa1000fe..08feceb1d9 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -130,6 +130,9 @@ struct SessionInfo { std::vector GetInputNames() const; + std::vector GetInputShape(const std::string& name) const; + std::vector GetOutputShape(const std::string& name) const; + std::vector GetInputSymbolicShape(const std::string& name) const; std::vector GetOutputSymbolicShape(const std::string& name) const; diff --git a/src/ryzenai/interface.cpp b/src/ryzenai/interface.cpp index 837f597428..4df5566977 100644 --- a/src/ryzenai/interface.cpp +++ b/src/ryzenai/interface.cpp @@ -1,5 +1,6 @@ #include "../generators.h" #include "../search.h" +#include "../models/model.h" #include "interface.h" #include #include @@ -207,4 +208,18 @@ RyzenAIInterface* GetRyzenAIInterface() { return RyzenAI::interface_.get(); } +bool IsRyzenAIPrunedModel(const Model& model) { + if (model.p_device_->GetType() != DeviceType::RyzenAI) + return false; + + const auto& logits_name = model.config_->model.decoder.outputs.logits; + + if (!model.session_info_.HasOutput(logits_name)) + return false; + + const auto logits_shape = model.session_info_.GetOutputShape(logits_name); + + return logits_shape[1] == 1; +} + } // namespace Generators diff --git a/src/ryzenai/interface.h b/src/ryzenai/interface.h index 3657ca9779..ce5f62604e 100644 --- a/src/ryzenai/interface.h +++ b/src/ryzenai/interface.h @@ -13,4 +13,7 @@ struct RyzenAIInterface : DeviceInterface { RyzenAIInterface* GetRyzenAIInterface(); +struct Model; +bool IsRyzenAIPrunedModel(const Model& model); + } // namespace Generators From a1d6ef18d7867f77e6d5812fba358b87dbde3a2a Mon Sep 17 00:00:00 2001 From: Oleksandr Kholodnyi Date: Thu, 8 Jan 2026 12:48:27 -0800 Subject: [PATCH 2/4] Added missing copyrights --- src/generators.cpp | 2 ++ src/models/logits.cpp | 2 ++ src/models/model.cpp | 2 +- src/models/model.h | 2 ++ src/ryzenai/interface.cpp | 2 ++ src/ryzenai/interface.h | 2 ++ src/smartptrs.h | 2 ++ 7 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/generators.cpp b/src/generators.cpp index aa337c1a10..6d1d8cc309 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// +// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #include "generators.h" #include "sequences.h" diff --git a/src/models/logits.cpp b/src/models/logits.cpp index b30fddaca0..69a85aa1f4 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// +// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #include "../generators.h" #include "model.h" #include "logits.h" diff --git a/src/models/model.cpp b/src/models/model.cpp index 2a3454695f..dda2cacb5d 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // -// Modifications Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. +// Modifications Copyright(C) 2024-2026 Advanced Micro Devices, Inc. All rights reserved. #include #include #include diff --git a/src/models/model.h b/src/models/model.h index 08feceb1d9..d39584068e 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// +// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "model_type.h" #include "ortx_tokenizer.h" diff --git a/src/ryzenai/interface.cpp b/src/ryzenai/interface.cpp index 4df5566977..02595d1f0b 100644 --- a/src/ryzenai/interface.cpp +++ b/src/ryzenai/interface.cpp @@ -1,3 +1,5 @@ +// Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. + #include "../generators.h" #include "../search.h" #include "../models/model.h" diff --git a/src/ryzenai/interface.h b/src/ryzenai/interface.h index ce5f62604e..880dd1b181 100644 --- a/src/ryzenai/interface.h +++ b/src/ryzenai/interface.h @@ -1,3 +1,5 @@ +// Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. + #pragma once namespace Generators { diff --git a/src/smartptrs.h b/src/smartptrs.h index c0a22b216f..33adb2d89e 100644 --- a/src/smartptrs.h +++ b/src/smartptrs.h @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// +// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include From 2f94b81458b41a7e7e0e9796453e7d0ab3985b9b Mon Sep 17 00:00:00 2001 From: Oleksandr Kholodnyi Date: Thu, 8 Jan 2026 12:58:39 -0800 Subject: [PATCH 3/4] Pruned logits detection code was made common (Model::IsPruned) --- src/models/model.cpp | 8 ++++++++ src/models/model.h | 2 ++ src/ryzenai/interface.cpp | 12 +----------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index dda2cacb5d..640cb1bf8d 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -1164,6 +1164,14 @@ std::shared_ptr Model::CreateMultiModalProcessor() const { return std::make_shared(*config_, session_info_); } +bool Model::IsPruned() const { + const auto& logits_name = config_->model.decoder.outputs.logits; + if (!session_info_.HasOutput(logits_name)) + return false; + const auto logits_shape = session_info_.GetOutputShape(logits_name); + return logits_shape[1] == 1; +} + std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) { std::string config_overlay; if (settings) { diff --git a/src/models/model.h b/src/models/model.h index d39584068e..f3f548916c 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -158,6 +158,8 @@ struct Model : std::enable_shared_from_this, LeakChecked, External std::unique_ptr CreateSession(OrtEnv& ort_env, const std::string& model_filename, OrtSessionOptions* session_options); + bool IsPruned() const; + std::unique_ptr config_; std::unique_ptr session_options_; std::unique_ptr arena_cfg_; diff --git a/src/ryzenai/interface.cpp b/src/ryzenai/interface.cpp index 02595d1f0b..3451bc0e1b 100644 --- a/src/ryzenai/interface.cpp +++ b/src/ryzenai/interface.cpp @@ -211,17 +211,7 @@ RyzenAIInterface* GetRyzenAIInterface() { } bool IsRyzenAIPrunedModel(const Model& model) { - if (model.p_device_->GetType() != DeviceType::RyzenAI) - return false; - - const auto& logits_name = model.config_->model.decoder.outputs.logits; - - if (!model.session_info_.HasOutput(logits_name)) - return false; - - const auto logits_shape = model.session_info_.GetOutputShape(logits_name); - - return logits_shape[1] == 1; + return model.p_device_->GetType() == DeviceType::RyzenAI && model.IsPruned(); } } // namespace Generators From aff73a03e284bdcdb5fd4739c2fd44da037f992d Mon Sep 17 00:00:00 2001 From: Oleksandr Kholodnyi Date: Thu, 8 Jan 2026 12:59:24 -0800 Subject: [PATCH 4/4] Lint --- src/generators.cpp | 2 +- src/models/logits.cpp | 2 +- src/models/model.h | 2 +- src/smartptrs.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 6d1d8cc309..d163eeacde 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// +// // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #include "generators.h" diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 69a85aa1f4..60cafd7a0a 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// +// // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #include "../generators.h" #include "model.h" diff --git a/src/models/model.h b/src/models/model.h index f3f548916c..906e7e2db2 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// +// // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "model_type.h" diff --git a/src/smartptrs.h b/src/smartptrs.h index 33adb2d89e..dfd769b234 100644 --- a/src/smartptrs.h +++ b/src/smartptrs.h @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// +// // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #pragma once #include