Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/hipblaslt/clients/tests/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
target_sources(hipblaslt-test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/hipblaslt_gtest_main.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hipblaslt_parallel_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hipblaslt_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul_gtest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/auxiliary_gtest.cpp
Expand Down
80 changes: 79 additions & 1 deletion projects/hipblaslt/clients/tests/src/hipblaslt_gtest_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
#include "hipblaslt_test.hpp"
#include "test_cleanup.hpp"
#include "utility.hpp"
#ifndef _WIN32
#include "hipblaslt_parallel_test.hpp"
#endif
#include <string>

using namespace testing;
Expand All @@ -37,6 +40,7 @@ class ConfigurableEventListener : public TestEventListener
{
TestEventListener* const eventListener;
std::atomic_size_t skipped_tests{0}; // Number of skipped tests.
std::atomic_size_t current_test_number{0}; // Current test number (incremental counter).

public:
bool showTestCases = true; // Show the names of each test case.
Expand Down Expand Up @@ -86,8 +90,14 @@ class ConfigurableEventListener : public TestEventListener

void OnTestStart(const TestInfo& test_info) override
{
++current_test_number;
if(showTestNames)
{
// Print test number and delegate to default listener
int total_tests = UnitTest::GetInstance()->test_to_run_count();
hipblaslt_cout << "[Test #" << current_test_number << "/" << total_tests << "] " << std::flush;
eventListener->OnTestStart(test_info);
}
}

void OnTestPartResult(const TestPartResult& result) override
Expand Down Expand Up @@ -228,6 +238,74 @@ int main(int argc, char** argv)
{
std::string args = hipblaslt_capture_args(argc, argv);

// Check for --help to add our custom options
for(int i = 1; i < argc; i++)
{
std::string arg = argv[i];
if(arg == "--help" || arg == "-h" || arg == "-?" || arg == "/?" || arg == "--help-all")
{
hipblaslt_cout << "\nhipBLASLt Test Options:\n";
hipblaslt_cout << " --num_gpus=N\n";
hipblaslt_cout << " --num_gpus N\n";
hipblaslt_cout << " Run tests in parallel across N GPUs (Unix/Linux only).\n";
hipblaslt_cout << " Tests are automatically split evenly across the specified\n";
hipblaslt_cout << " number of GPUs. Each GPU runs its assigned tests independently.\n";
hipblaslt_cout << " Example: ./hipblaslt-test --num_gpus 8 --gtest_filter=\"*smoke*\"\n";
hipblaslt_cout << " Note: If --gtest_output=json:file.json is specified, per-GPU\n";
hipblaslt_cout << " results are saved as file_gpu0.json, file_gpu1.json, etc.\n";
hipblaslt_cout << "\n";
break;
}
}

// Check for --num_gpus argument and handle platform support
int num_gpus = 0;
bool has_num_gpus_flag = false;
int num_gpus_arg_start = -1; // Track where the flag starts for removal
int num_gpus_arg_count = 0; // How many argv entries to remove (1 or 2)

for(int i = 1; i < argc; i++)
{
std::string arg = argv[i];
if(arg.find("--num_gpus=") == 0)
{
num_gpus = std::atoi(arg.substr(11).c_str());
has_num_gpus_flag = true;
num_gpus_arg_start = i;
num_gpus_arg_count = 1;
break;
}
else if(arg == "--num_gpus" && i + 1 < argc)
{
num_gpus = std::atoi(argv[i + 1]);
has_num_gpus_flag = true;
num_gpus_arg_start = i;
num_gpus_arg_count = 2;
break;
}
}

#ifdef _WIN32
// On Windows, parallel GPU execution is not supported
if(has_num_gpus_flag)
{
hipblaslt_cerr << "Warning: --num_gpus is not supported on Windows. Ignoring flag." << std::endl;

// Remove --num_gpus from argv to prevent GTest from complaining about unknown flag
for(int i = num_gpus_arg_start; i + num_gpus_arg_count < argc; i++)
{
argv[i] = argv[i + num_gpus_arg_count];
}
argc -= num_gpus_arg_count;
}
#else
// If parallel GPUs requested, use parallel execution
if(num_gpus > 1)
{
return run_tests_parallel_gpus(argc, argv, num_gpus);
}
#endif

// Set signal handler
hipblaslt_test_sigaction();

Expand Down Expand Up @@ -256,7 +334,7 @@ int main(int argc, char** argv)
// Failures printed at end for reporting so repeat version info
hipblaslt_print_version();

// end test results with command line
// Print command line at the end
hipblaslt_print_args(args);

//hipblaslt_shutdown();
Expand Down
Loading
Loading