Skip to content

Commit

Permalink
remove alloc map usage in cuda codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 29, 2021
1 parent fd07b35 commit 0ca1924
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -705,12 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
const VarNode* buffer = op->buffer_var.as<VarNode>();
auto it = alloc_storage_scope_.find(buffer);
ICHECK(it != alloc_storage_scope_.end())
<< "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";

std::string scope = it->second;
std::string scope = GetStorageScope(op->buffer_var);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
Expand All @@ -724,6 +719,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
const VarNode* buffer = op->buffer_var.as<VarNode>();
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
Expand Down

0 comments on commit 0ca1924

Please sign in to comment.