diff --git a/test/main.cpp b/test/main.cpp index 374eaf84b3..0d49dee908 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -2,18 +2,34 @@ // Licensed under the MIT License. #include +#include #include #include "ort_genai.h" +// Global variable to store custom model base path +std::string g_custom_model_path; + int main(int argc, char** argv) { std::cout << "Generators Utility Library" << std::endl; + std::cout << "Initializing OnnxRuntime... "; std::cout.flush(); try { std::cout << "done" << std::endl; ::testing::InitGoogleTest(&argc, argv); + + // Parse custom model path argument after InitGoogleTest + for (int i = 1; i < argc; ++i) { + const std::string arg = argv[i]; + if (arg == "--model_path" && i + 1 < argc) { + g_custom_model_path = argv[++i]; + std::cout << "Using custom model path: " << g_custom_model_path << std::endl; + break; + } + } + int result = RUN_ALL_TESTS(); std::cout << "Shutting down OnnxRuntime... "; OgaShutdown(); diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 56ce5696e0..70dd8c2ea4 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "span.h" @@ -12,6 +13,9 @@ #include #include +// External global variable from main.cpp for custom model path +extern std::string g_custom_model_path; + #ifndef MODEL_PATH #define MODEL_PATH "../../test/test_models/" #endif @@ -274,9 +278,12 @@ static const std::pair c_phi3_nvtrt_model_paths[] = { }; void Test_GreedySearch_Phi3_NvTensorRtRtx(const char* model_path, const char* model_label) { + // Use custom path if provided, otherwise use default + std::string resolved_path = g_custom_model_path.empty() ? model_path : g_custom_model_path; + // Skip test if NvTensorRT model is not available - if (!std::filesystem::exists(model_path)) { - GTEST_SKIP() << "NvTensorRT model not available at: " << model_path; + if (!std::filesystem::exists(resolved_path)) { + GTEST_SKIP() << "NvTensorRT model not available at: " << resolved_path; } const std::vector input_ids_shape{1, 19}; const std::vector input_ids{32006, 887, 526, 263, 8444, 29871, 23869, 20255, 29889, 32007, 32010, 6324, 29892, 1128, 526, 366, 29973, 32007, 32001}; @@ -285,7 +292,7 @@ void Test_GreedySearch_Phi3_NvTensorRtRtx(const char* model_path, const char* mo const std::vector expected_output{ 32006, 887, 526, 263, 8444, 29871, 23869, 20255, 29889, 32007, 32010, 6324, 29892, 1128, 526, 366, 29973, 32007, 32001, // Input tokens (19) 15043, 29991, 306, 29915, 29885, 2599}; - auto config = OgaConfig::Create(model_path); + auto config = OgaConfig::Create(resolved_path.c_str()); config->ClearProviders(); config->AppendProvider("NvTensorRtRtx"); auto model = OgaModel::Create(*config); @@ -321,9 +328,12 @@ TEST(ModelTests, GreedySearchPhi3NvTensorRtRtx) { } void Test_OutOfPlaceKvCache_Phi3_NvTensorRtRtx(const char* model_path, const char* model_label) { + // Use custom path if provided, otherwise use default + std::string resolved_path = g_custom_model_path.empty() ? model_path : g_custom_model_path; + // Skip test if NvTensorRT model is not available - if (!std::filesystem::exists(model_path)) { - GTEST_SKIP() << "NvTensorRT model not available at: " << model_path; + if (!std::filesystem::exists(resolved_path)) { + GTEST_SKIP() << "NvTensorRT model not available at: " << resolved_path; } const std::vector input_ids_shape{1, 19}; @@ -336,7 +346,7 @@ void Test_OutOfPlaceKvCache_Phi3_NvTensorRtRtx(const char* model_path, const cha 32006, 887, 526, 263, 8444, 29871, 23869, 20255, 29889, 32007, 32010, 6324, 29892, 1128, 526, 366, 29973, 32007, 32001, // Input tokens (19) 15043, 1554, 13, 16271, 29892, 8733}; - auto config = OgaConfig::Create(model_path); + auto config = OgaConfig::Create(resolved_path.c_str()); config->ClearProviders(); config->AppendProvider("NvTensorRtRtx"); auto model = OgaModel::Create(*config); diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index 83090d10b5..e1ec1643f3 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -25,6 +25,9 @@ #define MODEL_PATH "../../test/test_models/" #endif +// External global variable from main.cpp for custom model path +extern std::string g_custom_model_path; + // Defined in sampling_tests.cpp void CreateRandomLogits(float* logits, int num_large, int vocab_size, int batch_size, std::mt19937& engine); @@ -109,11 +112,11 @@ void PrintSummary(const std::vector& results) { } BenchmarkResult RunBenchmark(const BenchmarkParams& params) { - const char* model_path = MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"; + std::string model_path = MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"; if (strcmp(params.device_type, "NvTensorRtRtx") == 0) { - model_path = MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt"; + model_path = g_custom_model_path.empty() ? MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt" : g_custom_model_path; } - auto config = OgaConfig::Create(model_path); + auto config = OgaConfig::Create(model_path.c_str()); std::string overlay = R"({ "model": { "vocab_size" : )" + std::to_string(params.vocab_size) + R"( } })"; config->Overlay(overlay.c_str()); config->ClearProviders(); @@ -184,7 +187,8 @@ TEST(SamplingBenchmarks, PerformanceTests) { device_types.push_back("cuda"); #endif // Add NvTensorRtRtx if model is available - if (std::filesystem::exists(MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt")) { + std::string resolved_nvtrt_path = g_custom_model_path.empty() ? MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt" : g_custom_model_path; + if (std::filesystem::exists(resolved_nvtrt_path)) { device_types.push_back("NvTensorRtRtx"); } diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 6fced8ed1d..e5de94fab6 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -19,6 +19,9 @@ #define MODEL_PATH "../../test/test_models/" #endif +// External global variable from main.cpp for custom model path +extern std::string g_custom_model_path; + TEST(SamplingTests, BatchedSamplingTopPCpu) { std::vector input_ids{0, 1, 2, 3}; std::vector expected_output{1, 2, 3, 4}; @@ -555,8 +558,12 @@ struct NvTensorRtRtxTestSetup { static NvTensorRtRtxTestSetup Create(int vocab_size, int batch_size, int max_length = 10) { NvTensorRtRtxTestSetup setup; + // Use custom path if provided, otherwise use default + std::string nvtrt_path = MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt"; + std::string resolved_path = g_custom_model_path.empty() ? nvtrt_path : g_custom_model_path; + // Check if model is available - if (!std::filesystem::exists(MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt")) { + if (!std::filesystem::exists(resolved_path)) { setup.is_available = false; return setup; } @@ -564,7 +571,7 @@ struct NvTensorRtRtxTestSetup { setup.is_available = true; // Create config with vocab_size overlay - auto config = OgaConfig::Create(MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt"); + auto config = OgaConfig::Create(resolved_path.c_str()); std::string overlay = R"({ "model": { "vocab_size" : )" + std::to_string(vocab_size) + R"( } })"; config->Overlay(overlay.c_str()); config->ClearProviders();