Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions clients/common/rocblas_gentest.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,19 @@ def setdefaults(test):
# Do not put constant defaults here -- use rocblas_common.yaml for that.
# These are only for dynamic defaults
# TODO: Ideally this should be moved to the YAML file, with eval'd expressions.

if all([x in test for x in ('M', 'incx', 'strideScale')]) and test['function']=='ger_strided_batched':
test.setdefault('stride_x', int(test['M'] * abs(test['incx']) *
test['strideScale']))
else:
test.setdefault('stride_x', 0)

if all([x in test for x in ('N', 'incy', 'strideScale')]) and test['function']=='ger_strided_batched':
test.setdefault('stride_y', int(test['N'] * abs(test['incy']) *
test['strideScale']))
else:
test.setdefault('stride_y', 0)

if test['transA'] == '*' or test['transB'] == '*':
test.setdefault('lda', 0)
test.setdefault('ldb', 0)
Expand Down
109 changes: 87 additions & 22 deletions clients/gtest/ger_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,77 @@
#include "rocblas_datatype2string.hpp"
#include "rocblas_test.hpp"
#include "testing_ger.hpp"
#include "testing_ger_batched.hpp"
#include "testing_ger_strided_batched.hpp"
#include "type_dispatch.hpp"
#include <cstring>
#include <type_traits>

namespace
{
// Possible ger test cases. The enumerator is used as the GER_TYPE template
// argument of ger_template to select which ger variant a test suite covers.
enum ger_test_type
{
GER,
GER_BATCHED,
GER_STRIDED_BATCHED,
};

// Test-suite template shared by the ger, ger_batched and ger_strided_batched
// Google Test suites. FILTER is forwarded to RocBLAS_Test; GER_TYPE selects which
// function names the suite accepts and which fields appear in the test name.
template <template <typename...> class FILTER, ger_test_type GER_TYPE>
struct ger_template : RocBLAS_Test<ger_template<FILTER, GER_TYPE>, FILTER>
{
    // Decides whether the types requested by the arguments apply to this suite
    static bool type_filter(const Arguments& arg)
    {
        return rocblas_simple_dispatch<ger_template::template type_filter_functor>(arg);
    }

    // Decides whether the function named in the arguments applies to this suite
    static bool function_filter(const Arguments& arg)
    {
        // True when arg.function is exactly one of the two given names
        const auto is_either = [&arg](const char* plain, const char* bad_arg) {
            return !strcmp(arg.function, plain) || !strcmp(arg.function, bad_arg);
        };

        if(GER_TYPE == GER)
            return is_either("ger", "ger_bad_arg");

        if(GER_TYPE == GER_BATCHED)
            return is_either("ger_batched", "ger_batched_bad_arg");

        if(GER_TYPE == GER_STRIDED_BATCHED)
            return is_either("ger_strided_batched", "ger_strided_batched_bad_arg");

        return false;
    }

    // Builds the Google Test name suffix from the test parameters.
    // Field order: type, M, N, alpha, incx, [stride_x], incy, [stride_y], lda,
    // [stride_a], [batch_count]; bracketed fields depend on GER_TYPE.
    static std::string name_suffix(const Arguments& arg)
    {
        constexpr bool with_strides     = GER_TYPE == GER_STRIDED_BATCHED;
        constexpr bool with_batch_count = with_strides || GER_TYPE == GER_BATCHED;

        RocBLAS_TestName<ger_template> suffix;

        suffix << rocblas_datatype2string(arg.a_type);
        suffix << '_' << arg.M << '_' << arg.N;
        suffix << '_' << arg.alpha;

        suffix << '_' << arg.incx;
        if(with_strides)
            suffix << '_' << arg.stride_x;

        suffix << '_' << arg.incy;
        if(with_strides)
            suffix << '_' << arg.stride_y;

        suffix << '_' << arg.lda;
        if(with_strides)
            suffix << '_' << arg.stride_a;

        if(with_batch_count)
            suffix << '_' << arg.batch_count;

        // NOTE(review): the explicit move is kept as in the original; presumably the
        // conversion to std::string needs an rvalue -- confirm against RocBLAS_TestName.
        return std::move(suffix);
    }
};

// By default, this test does not apply to any types.
// The unnamed second parameter is used for enable_if below.
template <typename, typename = void>
Expand All @@ -37,38 +102,38 @@ namespace
testing_ger<T>(arg);
else if(!strcmp(arg.function, "ger_bad_arg"))
testing_ger_bad_arg<T>(arg);
else if(!strcmp(arg.function, "ger_batched"))
testing_ger_batched<T>(arg);
else if(!strcmp(arg.function, "ger_batched_bad_arg"))
testing_ger_batched_bad_arg<T>(arg);
else if(!strcmp(arg.function, "ger_strided_batched"))
testing_ger_strided_batched<T>(arg);
else if(!strcmp(arg.function, "ger_strided_batched_bad_arg"))
testing_ger_strided_batched_bad_arg<T>(arg);
else
FAIL() << "Internal error: Test called with unknown function: " << arg.function;
}
};

struct ger : RocBLAS_Test<ger, ger_testing>
using ger = ger_template<ger_testing, GER>;
TEST_P(ger, blas2)
{
// Filter for which types apply to this suite
static bool type_filter(const Arguments& arg)
{
return rocblas_simple_dispatch<type_filter_functor>(arg);
}

// Filter for which functions apply to this suite
static bool function_filter(const Arguments& arg)
{
return !strcmp(arg.function, "ger") || !strcmp(arg.function, "ger_bad_arg");
}
rocblas_simple_dispatch<ger_testing>(GetParam());
}
INSTANTIATE_TEST_CATEGORIES(ger);

// Google Test name suffix based on parameters
static std::string name_suffix(const Arguments& arg)
{
return RocBLAS_TestName<ger>{} << rocblas_datatype2string(arg.a_type) << '_' << arg.M
<< '_' << arg.N << '_' << arg.alpha << '_' << arg.incx
<< '_' << arg.incy << '_' << arg.lda;
}
};
// ger_batched test suite: ger_template instantiated with the ger_testing
// filter and the GER_BATCHED test type.
using ger_batched = ger_template<ger_testing, GER_BATCHED>;
// Value-parameterized test "blas2": hands the current test parameter to
// ger_testing through rocblas_simple_dispatch.
TEST_P(ger_batched, blas2)
{
rocblas_simple_dispatch<ger_testing>(GetParam());
}
// Instantiates the ger_batched suite via the INSTANTIATE_TEST_CATEGORIES macro.
INSTANTIATE_TEST_CATEGORIES(ger_batched);

TEST_P(ger, blas2)
using ger_strided_batched = ger_template<ger_testing, GER_STRIDED_BATCHED>;
TEST_P(ger_strided_batched, blas2)
{
rocblas_simple_dispatch<ger_testing>(GetParam());
}
INSTANTIATE_TEST_CATEGORIES(ger);
INSTANTIATE_TEST_CATEGORIES(ger_strided_batched);

} // namespace
128 changes: 101 additions & 27 deletions clients/gtest/ger_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,44 @@ include: known_bugs.yaml

Definitions:
- &small_matrix_size_range
- { M: -1, N: 1, lda: 1 }
- { M: 1, N: -1, lda: 1 }
- { M: 1, N: 1, lda: -1 }
- { M: 10, N: 1, lda: 9 }
- { M: 0, N: 1, lda: 1 }
- { M: 1, N: 0, lda: 1 }
- { M: 1, N: 1, lda: 0 }
- { M: 11, N: 12, lda: 13 }
- { M: 16, N: 16, lda: 16 }
- { M: 33, N: 32, lda: 33 }
- { M: 65, N: 65, lda: 66 }
- { M: -1, N: 1, lda: 1, stride_a: 1 }
- { M: 1, N: -1, lda: 1, stride_a: 1 }
- { M: 1, N: 1, lda: -1, stride_a: 1 }
- { M: 10, N: 1, lda: 9, stride_a: 1 }
- { M: 0, N: 1, lda: 1, stride_a: 1 }
- { M: 1, N: 0, lda: 1, stride_a: 1 }
- { M: 1, N: 1, lda: 0, stride_a: 1 }
- { M: 11, N: 12, lda: 13, stride_a: 1 }
- { M: 16, N: 16, lda: 16, stride_a: 256 }
- { M: 33, N: 32, lda: 33, stride_a: 1056 }
- { M: 65, N: 65, lda: 66, stride_a: 4300 }

- &medium_matrix_size_range
- { M: 10, N: 10, lda: 2 }
- { M: 600, N: 500, lda: 500 }
- { M: 1000, N: 1000, lda: 1000 }
- { M: 10, N: 10, lda: 2, stride_a: 1000 }
- { M: 600, N: 500, lda: 500, stride_a: 250000 }
- { M: 1000, N: 1000, lda: 1000, stride_a: 1000100 }

- &large_matrix_size_range
- { M: 2000, N: 2000, lda: 2000 }
- { M: 4011, N: 4011, lda: 4011 }
- { M: 8000, N: 8000, lda: 8000 }
- { M: 2000, N: 2000, lda: 2000, stride_a: 4000000 }
- { M: 4011, N: 4011, lda: 4011, stride_a: 16088200 }
- { M: 8000, N: 8000, lda: 8000, stride_a: 64000000 }

- &incx_incy_range
- { incx: 1, incy: 1 }
- { incx: -1, incy: -1 }
- { incx: 1, incy: -1 }
- { incx: -1, incy: -1 }
- { incx: 0, incy: -1 }
- { incx: 0, incy: 1 }
- { incx: 1, incy: 0 }
- { incx: 1, incy: 2 }
- { incx: 2, incy: 1 }
- { incx: 10, incy: 99 }
- { incx: 1, incy: 1}
- { incx: -1, incy: -1}
- { incx: 1, incy: -1}
- { incx: 0, incy: -1}
- { incx: 0, incy: 1}
- { incx: 1, incy: 0}
- { incx: 1, incy: 2}
- { incx: 2, incy: 1}
- { incx: 10, incy: 99}

- &nightly_incx_incy_range
- { incx: 1, incy: 1, strideScale: 1.5}
- { incx: 1, incy: -1, strideScale: 2}
- { incx: 1, incy: 2, strideScale: 1}
- { incx: 10, incy: 99, strideScale: 1}

Tests:
- name: ger_bad_arg
Expand Down Expand Up @@ -67,4 +72,73 @@ Tests:
matrix_size: *large_matrix_size_range
incx_incy: *incx_incy_range
alpha: [ -0.5, 2.0, 0.0, 0.6 ]

- name: ger_batched_bad_arg
category: pre_checkin
function: ger_batched_bad_arg
precision: *single_double_precisions
batch_count: [ -5, 0, 1, 5, 10 ]

- name: ger_batched_small
category: quick
function: ger_batched
precision: *single_double_precisions
matrix_size: *small_matrix_size_range
incx_incy: *incx_incy_range
alpha: [ -0.5, 2.0, 0.0 ]
batch_count: [ -5, 0, 1, 5, 10 ]

- name: ger_batched_medium
category: pre_checkin
function: ger_batched
precision: *single_double_precisions
matrix_size: *medium_matrix_size_range
incx_incy: *incx_incy_range
alpha: [ -0.5, 2.0, 0.0 ]
batch_count: [ 1, 5, 10 ]

- name: ger_batched_large
category: nightly
function: ger_batched
precision: *single_double_precisions
matrix_size: *large_matrix_size_range
incx_incy: *nightly_incx_incy_range
alpha: [ -0.5, 2.0, 0.0 ]
batch_count: [ 1, 3 ]

- name: ger_strided_batched_bad_arg
category: pre_checkin
function: ger_strided_batched_bad_arg
precision: *single_double_precisions
strideScale: [ -1, 0, 0.5, 1, 2 ]
batch_count: [ -5, 0, 1, 5, 10 ]

- name: ger_strided_batched_small
category: quick
function: ger_strided_batched
precision: *single_double_precisions
matrix_size: *small_matrix_size_range
incx_incy: *incx_incy_range
alpha: [ -0.5, 2.0, 0.0 ]
strideScale: [ 0.5, 1, 2 ]
batch_count: [ -5, 0, 1, 5, 10 ]

- name: ger_strided_batched_medium
category: pre_checkin
function: ger_strided_batched
precision: *single_double_precisions
matrix_size: *medium_matrix_size_range
incx_incy: *incx_incy_range
alpha: [ -0.5, 2.0, 0.0 ]
strideScale: [ 0.5, 1, 2 ]
batch_count: [ 1, 5, 10 ]

- name: ger_strided_batched_large
category: nightly
function: ger_strided_batched
precision: *single_double_precisions
matrix_size: *large_matrix_size_range
incx_incy: *nightly_incx_incy_range
alpha: [ -0.5, 2.0, 0.0 ]
batch_count: [ 1, 3 ]
...
41 changes: 41 additions & 0 deletions clients/include/rocblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,47 @@ static constexpr auto rocblas_ger<float> = rocblas_sger;
template <>
static constexpr auto rocblas_ger<double> = rocblas_dger;

// ger_batched
// Variable template holding a pointer to a ger_batched routine for element type T.
// x, y and A are passed as arrays of pointers; batch_count is the final argument.
// Specialized below for float and double.
template <typename T>
rocblas_status (*rocblas_ger_batched)(rocblas_handle handle,
rocblas_int m,
rocblas_int n,
const T* alpha,
const T* const x[],
rocblas_int incx,
const T* const y[],
rocblas_int incy,
T* const A[],
rocblas_int lda,
rocblas_int batch_count);

// float maps to rocblas_sger_batched
template <>
static constexpr auto rocblas_ger_batched<float> = rocblas_sger_batched;

// double maps to rocblas_dger_batched
template <>
static constexpr auto rocblas_ger_batched<double> = rocblas_dger_batched;

// ger_strided_batched
// Variable template holding a pointer to a ger_strided_batched routine for element
// type T. x, y and A are each passed as a single pointer followed by an increment
// (or lda) and a stride (stride_x, stride_y, stride_a); batch_count is the final
// argument. Specialized below for float and double.
template <typename T>
rocblas_status (*rocblas_ger_strided_batched)(rocblas_handle handle,
rocblas_int m,
rocblas_int n,
const T* alpha,
const T* x,
rocblas_int incx,
rocblas_int stride_x,
const T* y,
rocblas_int incy,
rocblas_int stride_y,
T* A,
rocblas_int lda,
rocblas_int stride_a,
rocblas_int batch_count);

// float maps to rocblas_sger_strided_batched
template <>
static constexpr auto rocblas_ger_strided_batched<float> = rocblas_sger_strided_batched;

// double maps to rocblas_dger_strided_batched
template <>
static constexpr auto rocblas_ger_strided_batched<double> = rocblas_dger_strided_batched;

// syr
template <typename T>
rocblas_status (*rocblas_syr)(rocblas_handle handle,
Expand Down
2 changes: 0 additions & 2 deletions clients/include/rocblas_common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ Defaults:
stride_b: 0
stride_c: 0
stride_d: 0
stride_x: 0
stride_y: 0
norm_check: 0
unit_check: 1
timing: 0
Expand Down
27 changes: 27 additions & 0 deletions clients/include/testing_gemv_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,34 @@ void testing_gemv_batched(const Arguments& arg)

//quick return
if(!M || !N || !batch_count)
{
device_vector<T*, 0, T> dAA1(1);
device_vector<T*, 0, T> dxA1(1);
device_vector<T*, 0, T> dy_1A1(1);

if(!dAA1 || !dxA1 || !dy_1A1)
{
CHECK_HIP_ERROR(hipErrorOutOfMemory);
return;
}

EXPECT_ROCBLAS_STATUS(rocblas_gemv_batched<T>(handle,
transA,
M,
N,
&h_alpha,
dAA1,
lda,
dxA1,
incx,
&h_beta,
dy_1A1,
incy,
batch_count),
rocblas_status_success);

return;
}

//Device-arrays of pointers to device memory
device_vector<T*, 0, T> dAA(batch_count);
Expand Down
Loading