Skip to content

Commit

Permalink
Fix GemmDriver uninitialized field
Browse files Browse the repository at this point in the history
  • Loading branch information
AngryLoki committed Nov 3, 2023
1 parent f185a64 commit 717861b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
4 changes: 2 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
printf("Usage: ./driver *base_arg* *other_args*\n");
printf("Supported Base Arguments: conv[fp16|int8|bfp16], CBAInfer[fp16], "
"pool[fp16], lrn[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
"tensorop[fp16], reduce[fp16,fp64]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}
Expand All @@ -160,7 +160,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "CBAInfer" && arg != "CBAInferfp16" && arg != "pool" && arg != "poolfp16" &&
arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" &&
arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" &&
arg != "rnn" && arg != "rnnfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
arg != "rnn" && arg != "rnnfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "--version")
{
Expand Down
17 changes: 15 additions & 2 deletions driver/gemm_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,19 @@ int GemmDriver<T>::GetandSetData()
gemm_desc.strideB = gemm_desc.k * gemm_desc.n;
gemm_desc.strideC = gemm_desc.m * gemm_desc.n;

if constexpr (std::is_same_v<T, float>)
{
gemm_desc.dataType = miopenFloat;
}
else if constexpr (std::is_same_v<T, float16>)
{
gemm_desc.dataType = miopenHalf;
}
else
{
static_assert(!"unsupported type");
}

return (0);
}

Expand All @@ -230,9 +243,9 @@ int GemmDriver<T>::AllocateBuffersAndCopy()
a = std::vector<T>(a_sz);
b = std::vector<T>(b_sz);
#if GEMM_DRIVER_DEBUG
c = std::vector<T>(c_sz, 1.);
c = std::vector<T>(c_sz, static_cast<T>(1.));
#else
c = std::vector<T>(c_sz, 0.);
c = std::vector<T>(c_sz, static_cast<T>(0.));
#endif
chost = c;

Expand Down
8 changes: 4 additions & 4 deletions driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ int main(int argc, char* argv[])
drv = new GemmDriver<float>();
}
// TODO half is not supported in gemm
// else if(base_arg == "gemmfp16")
// {
// drv = new GemmDriver<float16>();
// }
else if(base_arg == "gemmfp16")
{
drv = new GemmDriver<float16>();
}
#endif
else if(base_arg == "bnorm")
{
Expand Down

0 comments on commit 717861b

Please sign in to comment.