Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.

#include "generators.h"
#include "sequences.h"
Expand Down
7 changes: 5 additions & 2 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.
#include "../generators.h"
#include "model.h"
#include "logits.h"
#include "../openvino/interface.h"
#include "../ryzenai/interface.h"

namespace Generators {

Expand All @@ -15,8 +18,8 @@ Logits::Logits(State& state)

input_sequence_lengths.resize(state_.params_->search.batch_size);

if (IsOpenVINOStatefulModel(state.model_) || state_.model_.p_device_->GetType() == DeviceType::RyzenAI) {
// In the case of OpenVINO stateful models or RyzenAI models, they are patched in a way so that they only return the
if (IsOpenVINOStatefulModel(state.model_) || IsRyzenAIPrunedModel(state_.model_)) {
// In the case of OpenVINO stateful models or RyzenAI pruned models, they are patched in a way so that they only return the
// sliced logits needed for sampling. For example, given 43 prompt tokens, instead of returning
// logits of the shape: [1,43,<vocab_size>]
// they will have shape: [1, 1,<vocab_size>].
Expand Down
24 changes: 23 additions & 1 deletion src/models/model.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// Modifications Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved.
// Modifications Copyright(C) 2024-2026 Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <climits>
#include <random>
Expand Down Expand Up @@ -900,6 +900,20 @@ std::vector<std::string> SessionInfo::GetInputNames() const {
return names;
}

// Looks up the model input called `name` and returns the dimensions reported by
// its tensor type-and-shape info.
// Throws std::runtime_error if no input with that name is registered.
std::vector<int64_t> SessionInfo::GetInputShape(const std::string& name) const {
  const auto entry = inputs_.find(name);
  if (entry != inputs_.end()) {
    return entry->second->GetTensorTypeAndShapeInfo().GetShape();
  }
  throw std::runtime_error("Model input was not found: " + name);
}

// Looks up the model output called `name` and returns the dimensions reported by
// its tensor type-and-shape info.
// Throws std::runtime_error if no output with that name is registered.
std::vector<int64_t> SessionInfo::GetOutputShape(const std::string& name) const {
  const auto entry = outputs_.find(name);
  if (entry != outputs_.end()) {
    return entry->second->GetTensorTypeAndShapeInfo().GetShape();
  }
  throw std::runtime_error("Model output was not found: " + name);
}

std::vector<const char*> SessionInfo::GetInputSymbolicShape(const std::string& name) const {
auto type_info = inputs_.find(name);
if (type_info == inputs_.end())
Expand Down Expand Up @@ -1150,6 +1164,14 @@ std::shared_ptr<MultiModalProcessor> Model::CreateMultiModalProcessor() const {
return std::make_shared<MultiModalProcessor>(*config_, session_info_);
}

// Reports whether the decoder's logits output is declared with a length of
// exactly 1 along its second dimension (index 1).
//
// Returns false when the session has no output with the configured logits name,
// or when that output is declared with fewer than two dimensions.
// NOTE(review): presumably dimension 1 is the sequence axis, so a value of 1
// means the model only emits the logits for the last token — confirm against
// the model export.
bool Model::IsPruned() const {
  const auto& logits_name = config_->model.decoder.outputs.logits;
  if (!session_info_.HasOutput(logits_name))
    return false;
  const auto logits_shape = session_info_.GetOutputShape(logits_name);
  // Check the rank before indexing: operator[] on a shape with fewer than two
  // entries would read out of bounds.
  return logits_shape.size() > 1 && logits_shape[1] == 1;
}

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) {
std::string config_overlay;
if (settings) {
Expand Down
7 changes: 7 additions & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "model_type.h"
#include "ortx_tokenizer.h"
Expand Down Expand Up @@ -130,6 +132,9 @@ struct SessionInfo {

std::vector<std::string> GetInputNames() const;

std::vector<int64_t> GetInputShape(const std::string& name) const;
std::vector<int64_t> GetOutputShape(const std::string& name) const;

std::vector<const char*> GetInputSymbolicShape(const std::string& name) const;
std::vector<const char*> GetOutputSymbolicShape(const std::string& name) const;

Expand All @@ -153,6 +158,8 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model>, External

std::unique_ptr<OrtSession> CreateSession(OrtEnv& ort_env, const std::string& model_filename, OrtSessionOptions* session_options);

bool IsPruned() const;

std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;
std::unique_ptr<OrtArenaCfg> arena_cfg_;
Expand Down
7 changes: 7 additions & 0 deletions src/ryzenai/interface.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.

#include "../generators.h"
#include "../search.h"
#include "../models/model.h"
#include "interface.h"
#include <filesystem>
#include <mutex>
Expand Down Expand Up @@ -207,4 +210,8 @@ RyzenAIInterface* GetRyzenAIInterface() {
return RyzenAI::interface_.get();
}

bool IsRyzenAIPrunedModel(const Model& model) {
Comment thread
kunal-vaishnavi marked this conversation as resolved.
return model.p_device_->GetType() == DeviceType::RyzenAI && model.IsPruned();
}

} // namespace Generators
5 changes: 5 additions & 0 deletions src/ryzenai/interface.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace Generators {
Expand All @@ -13,4 +15,7 @@ struct RyzenAIInterface : DeviceInterface {

RyzenAIInterface* GetRyzenAIInterface();

struct Model;
bool IsRyzenAIPrunedModel(const Model& model);

} // namespace Generators
2 changes: 2 additions & 0 deletions src/smartptrs.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <assert.h>
#include <atomic>
Expand Down
Loading