Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 20 additions & 44 deletions onnxruntime/test/lora/lora_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,6 @@ struct TestDataType {
verify_load(lora_adapter);
}
};

// Helper that wraps a single Parameter offset into a finished Adapter flatbuffer
// and returns a pointer to the deserialized Parameter.
// The FlatBufferBuilder must outlive the returned pointer.
// Wraps a single Parameter offset into a completed Adapter flatbuffer and
// returns that Parameter as read back from the finished buffer.
// The returned pointer aliases fbb's internal storage, so the
// FlatBufferBuilder must outlive any use of the result.
const adapters::Parameter* BuildAdapterAndGetParam(flatbuffers::FlatBufferBuilder& fbb,
                                                   flatbuffers::Offset<adapters::Parameter> param_offset) {
  const auto parameters_vec = fbb.CreateVector(&param_offset, 1);
  const auto adapter_root = adapters::CreateAdapter(fbb, adapters::kAdapterFormatVersion,
                                                    kAdapterVersion, kModelVersion, parameters_vec);
  adapters::FinishAdapterBuffer(fbb, adapter_root);

  // Re-read the serialized buffer and hand back its only parameter entry.
  return adapters::GetAdapter(fbb.GetBufferPointer())->parameters()->Get(0);
}

} // namespace

TEST(LoraAdapterTest, Load) {
Expand Down Expand Up @@ -252,6 +237,24 @@ TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ValidParam) {
}
}

#ifndef ORT_NO_EXCEPTIONS

namespace {
// Builds a finished Adapter flatbuffer around one Parameter offset and
// returns the deserialized Parameter read back out of that buffer.
// The pointer points into fbb's buffer; keep fbb alive while using it.
const adapters::Parameter* BuildAdapterAndGetParam(flatbuffers::FlatBufferBuilder& fbb,
                                                   flatbuffers::Offset<adapters::Parameter> param_offset) {
  auto params_vec = fbb.CreateVector(&param_offset, 1);
  adapters::FinishAdapterBuffer(
      fbb, adapters::CreateAdapter(fbb, adapters::kAdapterFormatVersion,
                                   kAdapterVersion, kModelVersion, params_vec));

  const auto* adapter_root = adapters::GetAdapter(fbb.GetBufferPointer());
  return adapter_root->parameters()->Get(0);
}
}  // namespace

TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_RawDataSizeMismatch) {
// Craft a flatbuffer Parameter where raw_data has fewer bytes than
// shape (8 x 4) * sizeof(float) = 128 bytes.
Expand Down Expand Up @@ -326,35 +329,6 @@ TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_ExcessRawData) {
ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
}

TEST(LoraAdapterTest, Load_RawDataSizeMismatch) {
  // End-to-end check: an adapter whose parameter carries fewer raw bytes
  // than its shape requires must be rejected by LoraAdapter::Load.
  flatbuffers::FlatBufferBuilder builder;

  const auto name = builder.CreateString("bad_param");
  const std::vector<int64_t> shape = {8, 4};
  const auto shape_vec = builder.CreateVector(shape);

  // Supply only 64 bytes where float [8, 4] needs 128.
  const std::vector<uint8_t> truncated_bytes(64, 0);
  builder.ForceVectorAlignment(truncated_bytes.size(), sizeof(uint8_t), 8);
  const auto raw_vec = builder.CreateVector(truncated_bytes);

  const auto bad_param = adapters::CreateParameter(
      builder, name, shape_vec, adapters::TensorDataType::FLOAT, raw_vec);

  const auto params_vec = builder.CreateVector(&bad_param, 1);
  adapters::FinishAdapterBuffer(
      builder, adapters::CreateAdapter(builder, adapters::kAdapterFormatVersion,
                                       kAdapterVersion, kModelVersion, params_vec));

  std::vector<uint8_t> serialized(builder.GetBufferPointer(),
                                  builder.GetBufferPointer() + builder.GetSize());

  lora::LoraAdapter adapter;
  ASSERT_THROW(adapter.Load(std::move(serialized)), OnnxRuntimeException);
}

TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_MissingName) {
// Parameter with null name should throw gracefully.
flatbuffers::FlatBufferBuilder fbb;
Expand Down Expand Up @@ -442,6 +416,8 @@ TEST(LoraAdapterTest, CreateOrtValueOverLoraParameter_UndefinedDataType) {
ASSERT_THROW(adapters::utils::CreateOrtValueOverLoraParameter(*param), OnnxRuntimeException);
}

#endif // ORT_NO_EXCEPTIONS

#ifdef USE_CUDA
TEST(LoraAdapterTest, VerifyDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
Expand Down
Loading