Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions onnxruntime/lora/adapter_format_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
#include "core/framework/allocator.h"
#include "core/common/common.h"
#include "core/common/endian.h"
#include "core/framework/endian_utils.h"
#include "core/common/safeint.h"
#include "core/common/span_utils.h"
#include "core/framework/endian_utils.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "core/framework/ort_value.h"
Expand Down Expand Up @@ -155,7 +156,14 @@ std::pair<std::string, OrtValue> CreateOrtValueOverLoraParameter(const Parameter
const auto data_type = param.data_type();
// Copying shape takes care of endianess using flatbuffers accessors
TensorShapeVector shape(param.dims()->begin(), param.dims()->end());
TensorShape tensor_shape(shape);
const auto elem_type = DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int32_t>(data_type))->GetElementType();
const size_t expected_raw_data_size = SafeInt<size_t>(tensor_shape.Size()) * elem_type->Size();
if (param.raw_data()->size() != expected_raw_data_size) {
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated
ORT_THROW("Lora Param:", param.name(),
"Raw data size does not match the expected size calculated from tensor shape and element type");
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated
}

static const OrtMemoryInfo cpu_meminfo(CPU, OrtAllocatorType::OrtDeviceAllocator);

if constexpr (endian::native == endian::big) {
Expand All @@ -166,15 +174,15 @@ std::pair<std::string, OrtValue> CreateOrtValueOverLoraParameter(const Parameter
// of raw data
// const_cast is necessary due to Tensor class API
Tensor::InitOrtValue(elem_type,
TensorShape(shape),
tensor_shape,
const_cast<uint8_t*>(param.raw_data()->data()),
cpu_meminfo,
result);
}
} else {
// const_cast is necessary due to Tensor class API
Tensor::InitOrtValue(elem_type,
TensorShape(shape),
tensor_shape,
const_cast<uint8_t*>(param.raw_data()->data()),
cpu_meminfo,
result);
Expand Down
143 changes: 143 additions & 0 deletions onnxruntime/test/lora/lora_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,149 @@ TEST(LoraAdapterTest, Load) {
}
}

TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ValidParam) {
// Build a valid adapter with a single float parameter, then call
// CreateOrtValueOverLoraParameter on the deserialized Parameter.
constexpr std::array<int64_t, 2> shape = {8, 4};
InlinedVector<float> data(32);
std::iota(data.begin(), data.end(), 0.f);

adapters::utils::AdapterFormatBuilder adapter_builder;
adapter_builder.AddParameter("valid_param", adapters::TensorDataType::FLOAT,
shape, ReinterpretAsSpan<const uint8_t>(gsl::make_span(data)));

auto buffer = adapter_builder.Finish(kAdapterVersion, kModelVersion);

const auto* adapter = adapters::utils::ValidateAndGetAdapterFromBytes(buffer);
ASSERT_NE(adapter, nullptr);
ASSERT_NE(adapter->parameters(), nullptr);
ASSERT_EQ(adapter->parameters()->size(), 1u);

const auto* param = adapter->parameters()->Get(0);
auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);

ASSERT_EQ(name, "valid_param");
ASSERT_TRUE(ort_value.IsTensor());

const auto& tensor = ort_value.Get<Tensor>();
ASSERT_EQ(tensor.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);

auto dims = tensor.Shape().GetDims();
ASSERT_EQ(dims.size(), 2u);
ASSERT_EQ(dims[0], 8);
ASSERT_EQ(dims[1], 4);

auto result_span = tensor.DataAsSpan<float>();
ASSERT_EQ(result_span.size(), 32u);
for (size_t i = 0; i < result_span.size(); ++i) {
ASSERT_EQ(static_cast<float>(i), result_span[i]);
}
}

TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_RawDataSizeMismatch) {
// Craft a flatbuffer Parameter where raw_data has fewer bytes than
// shape (8 x 4) * sizeof(float) = 128 bytes.
// We supply only 64 bytes (half the expected amount) so the validation
// inside CreateOrtValueOverLoraParameter must throw.
flatbuffers::FlatBufferBuilder fbb;

auto name_offset = fbb.CreateString("bad_param");
std::vector<int64_t> dims = {8, 4};
auto dims_offset = fbb.CreateVector(dims);

// 8 * 4 floats = 32 elements = 128 bytes expected.
// Provide only 64 bytes (16 floats worth) to trigger the mismatch.
std::vector<uint8_t> short_data(64, 0);
fbb.ForceVectorAlignment(short_data.size(), sizeof(uint8_t), 8);
auto data_offset = fbb.CreateVector(short_data);

auto param_offset = adapters::CreateParameter(
fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset);

// Wrap the single parameter inside an Adapter so the buffer is valid flatbuffers.
auto params_offset = fbb.CreateVector(&param_offset, 1);
auto adapter_offset = adapters::CreateAdapter(
fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset);
adapters::FinishAdapterBuffer(fbb, adapter_offset);

auto* buf = fbb.GetBufferPointer();
auto size = fbb.GetSize();
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated

// Retrieve the Parameter from the Adapter
const auto* adapter = adapters::GetAdapter(buf);
ASSERT_NE(adapter, nullptr);
ASSERT_NE(adapter->parameters(), nullptr);
ASSERT_EQ(adapter->parameters()->size(), 1u);

const auto* param = adapter->parameters()->Get(0);
ASSERT_NE(param, nullptr);

// The raw_data is 64 bytes but shape says 8x4 floats = 128 bytes.
// CreateOrtValueOverLoraParameter must throw.
ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
}

TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ExcessRawData) {
// Craft a flatbuffer Parameter where raw_data has MORE bytes than expected.
// Shape (2, 2) with float => 4 elements => 16 bytes expected, but we supply 32.
flatbuffers::FlatBufferBuilder fbb;

auto name_offset = fbb.CreateString("excess_param");
std::vector<int64_t> dims = {2, 2};
auto dims_offset = fbb.CreateVector(dims);

// 2 * 2 floats = 4 elements = 16 bytes expected. Supply 32.
std::vector<uint8_t> excess_data(32, 0);
fbb.ForceVectorAlignment(excess_data.size(), sizeof(uint8_t), 8);
auto data_offset = fbb.CreateVector(excess_data);

auto param_offset = adapters::CreateParameter(
fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset);

auto params_offset = fbb.CreateVector(&param_offset, 1);
auto adapter_offset = adapters::CreateAdapter(
fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset);
adapters::FinishAdapterBuffer(fbb, adapter_offset);

const auto* adapter = adapters::GetAdapter(fbb.GetBufferPointer());
ASSERT_NE(adapter, nullptr);

const auto* param = adapter->parameters()->Get(0);
ASSERT_NE(param, nullptr);

// Excess data should also trigger the mismatch throw.
ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
}

TEST(LoraAdapterTest, Load_RawDataSizeMismatch) {
// End-to-end: loading an adapter whose parameter has mismatched raw data
// should fail during LoraAdapter::Load.
Comment thread
yuslepukhin marked this conversation as resolved.
flatbuffers::FlatBufferBuilder fbb;

auto name_offset = fbb.CreateString("bad_param");
std::vector<int64_t> dims = {8, 4};
auto dims_offset = fbb.CreateVector(dims);

// Provide 64 bytes instead of the expected 128 for float [8, 4].
std::vector<uint8_t> short_data(64, 0);
fbb.ForceVectorAlignment(short_data.size(), sizeof(uint8_t), 8);
auto data_offset = fbb.CreateVector(short_data);

auto param_offset = adapters::CreateParameter(
fbb, name_offset, dims_offset, adapters::TensorDataType::FLOAT, data_offset);

auto params_offset = fbb.CreateVector(&param_offset, 1);
auto adapter_offset = adapters::CreateAdapter(
fbb, adapters::kAdapterFormatVersion, kAdapterVersion, kModelVersion, params_offset);
adapters::FinishAdapterBuffer(fbb, adapter_offset);

std::vector<uint8_t> buffer(fbb.GetBufferPointer(),
fbb.GetBufferPointer() + fbb.GetSize());

lora::LoraAdapter adapter;
ASSERT_THROW(adapter.Load(std::move(buffer)), OnnxRuntimeException);
}

#ifdef USE_CUDA
TEST(LoraAdapterTest, VerifyDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
Expand Down
Loading