Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 138 additions & 45 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <ggml-cpp.h>

#include <algorithm>
#include <atomic>
#include <array>
#include <cfloat>
#include <cinttypes>
Expand All @@ -33,6 +34,7 @@
#include <future>
#include <fstream>
#include <memory>
#include <mutex>
#include <random>
#include <regex>
#include <set>
Expand All @@ -55,33 +57,24 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
{
// parallel initialization
static const size_t n_threads = N_THREADS;
// static RNG initialization (revisit if n_threads stops being constant)
static std::vector<std::default_random_engine> generators = []() {
std::random_device rd;
std::vector<std::default_random_engine> vec;
vec.reserve(n_threads);
//for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
return vec;
}();

auto init_thread = [&](size_t ith, size_t start, size_t end) {

auto init_thread = [&](size_t start, size_t end) {
thread_local std::default_random_engine gen(std::random_device{}());
std::uniform_real_distribution<float> distribution(min, max);
auto & gen = generators[ith];
for (size_t i = start; i < end; i++) {
data[i] = distribution(gen);
}
};

if (n_threads == 1) {
init_thread(0, 0, nels);
init_thread(0, nels);
} else {
std::vector<std::future<void>> tasks;
tasks.reserve(n_threads);
for (size_t i = 0; i < n_threads; i++) {
size_t start = i*nels/n_threads;
size_t end = (i+1)*nels/n_threads;
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
tasks.push_back(std::async(std::launch::async, init_thread, start, end));
}
for (auto & t : tasks) {
t.get();
Expand Down Expand Up @@ -516,6 +509,25 @@ static bool output_format_from_str(const std::string & s, output_formats & forma
return true;
}

static std::string test_time_now() {
time_t t = time(NULL);
struct tm tm_buf;
#ifdef _WIN32
if (gmtime_s(&tm_buf, &t) != 0) {
return "";
}
#else
if (gmtime_r(&t, &tm_buf) == nullptr) {
return "";
}
#endif
char buf[32];
if (std::strftime(buf, sizeof(buf), "%FT%TZ", &tm_buf) == 0) {
return "";
}
return buf;
}

// Test result structure for SQL output
struct test_result {
std::string test_time;
Expand Down Expand Up @@ -545,11 +557,7 @@ struct test_result {
supported = false;
passed = false;

// Set test time
time_t t = time(NULL);
char buf[32];
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;
test_time = test_time_now();

// Set build info
build_commit = ggml_commit();
Expand All @@ -573,11 +581,7 @@ struct test_result {
n_runs(n_runs),
device_description(device_description),
backend_reg_name(backend_reg_name) {
// Set test time
time_t t = time(NULL);
char buf[32];
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
test_time = buf;
test_time = test_time_now();

// Set build info
build_commit = ggml_commit();
Expand Down Expand Up @@ -1110,6 +1114,17 @@ static std::unique_ptr<printer> create_printer(output_formats format) {
GGML_ABORT("invalid output format");
}

static std::mutex g_test_output_mutex;

static void print_test_result_locked(printer * output_printer, const test_result & result) {
if (output_printer == nullptr) {
return;
}

std::lock_guard<std::mutex> guard(g_test_output_mutex);
output_printer->print_test_result(result);
}

struct test_case {
virtual ~test_case() {}

Expand Down Expand Up @@ -1338,9 +1353,7 @@ struct test_case {
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test",
false, false, "not supported");

if (output_printer) {
output_printer->print_test_result(result);
}
print_test_result_locked(output_printer, result);

ggml_free(ctx);
return test_status_t::NOT_SUPPORTED;
Expand Down Expand Up @@ -1462,9 +1475,7 @@ struct test_case {
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed,
error_msg);

if (output_printer) {
output_printer->print_test_result(result);
}
print_test_result_locked(output_printer, result);

return test_passed ? test_status_t::OK : test_status_t::FAIL;
}
Expand Down Expand Up @@ -9492,8 +9503,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_from_file(const c
return test_cases;
}

static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
printer * output_printer, const char * test_file_path) {
static bool test_backend(ggml_backend_t backend, ggml_backend_dev_t dev, test_mode mode, const char * op_names_filter, const char * params_filter,
printer * output_printer, const char * test_file_path, int parallel_workers) {
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
if (params_filter == nullptr) {
return;
Expand Down Expand Up @@ -9546,21 +9557,90 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
set_use_ref(backend_cpu, true);
}

size_t n_ok = 0;
size_t tests_run = 0;
std::atomic<size_t> n_ok = 0;
std::atomic<size_t> tests_run = 0;
std::vector<std::string> failed_tests;
for (auto & test : test_cases) {
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
continue;
std::mutex failed_tests_mutex;

// Each worker grabs a chunk of cases at a time. The chunk shrinks as we
// run out of work so that a few slow tests at the tail get spread across
// workers instead of landing on one unlucky thread.
constexpr size_t MAX_TESTS_PER_ITER = 100;
std::atomic<size_t> test_idx = 0;

const auto & next_chunk = [&](size_t & my_begin, size_t & my_end) {
const size_t cur = test_idx.load(std::memory_order_relaxed);
const size_t remaining = cur < test_cases.size() ? test_cases.size() - cur : 0;
const size_t chunk = std::max<size_t>(1, std::min<size_t>(MAX_TESTS_PER_ITER, remaining / parallel_workers));
my_begin = test_idx.fetch_add(chunk);
my_end = std::min(my_begin + chunk, test_cases.size());
};

const auto & run_tests = [&](ggml_backend_t b, ggml_backend_t b_cpu) {
size_t my_begin, my_end;
next_chunk(my_begin, my_end);
while (my_begin < test_cases.size()) {
for (size_t i = my_begin; i < my_end; ++i) {
auto & test = test_cases[i];
test_status_t status = test->eval(b, b_cpu, op_names_filter, output_printer);
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
continue;
}
tests_run++;
if (status == test_status_t::OK) {
n_ok++;
} else if (status == test_status_t::FAIL) {
std::lock_guard<std::mutex> guard(failed_tests_mutex);
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
}
}
next_chunk(my_begin, my_end);
}
tests_run++;
if (status == test_status_t::OK) {
n_ok++;
} else if (status == test_status_t::FAIL) {
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
};

if (parallel_workers <= 1) {
// Reuse the outer backend / backend_cpu so we don't pay an
// extra CPU backend init.
run_tests(backend, backend_cpu);
} else {
std::atomic<size_t> workers_started = 0;

const auto & eval_worker = [&]() {
ggml_backend_t b = ggml_backend_dev_init(dev, NULL);
if (b == NULL) {
return;
}

ggml_backend_t b_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
if (b_cpu == NULL) {
ggml_backend_free(b);
return;
}

if (set_use_ref) {
set_use_ref(b_cpu, true);
}
workers_started++;
run_tests(b, b_cpu);
ggml_backend_free(b_cpu);
ggml_backend_free(b);
};

std::vector<std::thread> threads;
threads.reserve(parallel_workers);
for (int i = 0; i < parallel_workers; ++i) {
threads.emplace_back(eval_worker);
}
for (auto & t : threads) {
t.join();
}

if (workers_started == 0 && !test_cases.empty()) {
ggml_backend_free(backend_cpu);
return false;
}
}

output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
output_printer->print_failed_tests(failed_tests);

Expand Down Expand Up @@ -9708,7 +9788,7 @@ static void show_test_coverage() {

static void usage(char ** argv) {
printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops]", argv[0]);
printf(" [--show-coverage] [--test-file <path>]\n");
printf(" [--show-coverage] [--test-file <path>] [-j <n>]\n");
printf(" valid modes:\n");
printf(" - test (default, compare with CPU backend for correctness)\n");
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
Expand All @@ -9720,6 +9800,7 @@ static void usage(char ** argv) {
printf(" --list-ops lists all available GGML operations\n");
printf(" --show-coverage shows test coverage\n");
printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n");
printf(" -j <n> runs tests using <n> parallel worker threads (default: 1, test mode only)\n");
}

int main(int argc, char ** argv) {
Expand All @@ -9729,6 +9810,7 @@ int main(int argc, char ** argv) {
const char * backend_filter = nullptr;
const char * params_filter = nullptr;
const char * test_file_path = nullptr;
int parallel_workers = 1;

for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "test") == 0) {
Expand Down Expand Up @@ -9783,6 +9865,17 @@ int main(int argc, char ** argv) {
usage(argv);
return 1;
}
} else if (strcmp(argv[i], "-j") == 0) {
if (i + 1 < argc) {
parallel_workers = atoi(argv[++i]);
if (parallel_workers < 1) {
usage(argv);
return 1;
}
} else {
usage(argv);
return 1;
}
} else {
usage(argv);
return 1;
Expand Down Expand Up @@ -9835,7 +9928,7 @@ int main(int argc, char ** argv) {
false, "", ggml_backend_dev_description(dev),
total / 1024 / 1024, free / 1024 / 1024, true));

bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get(), test_file_path);
bool ok = test_backend(backend, dev, mode, op_names_filter, params_filter, output_printer.get(), test_file_path, parallel_workers);

if (ok) {
n_ok++;
Expand Down
Loading