Skip to content

Commit

Permalink
begin removing storage_scope attr
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 29, 2021
1 parent d39a470 commit e20e195
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 25 deletions.
5 changes: 3 additions & 2 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == runtime::StorageRank::kLocal) {
auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
Expand All @@ -99,7 +100,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
buf = alloca;
} else {
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
Expand Down
8 changes: 2 additions & 6 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp
auto it = alloc_storage_info_.find(buf_var);
if (it != alloc_storage_info_.end()) {
const StorageInfo& info = it->second;
*p_native_bits = NativeVectorBits(info.scope);
*p_native_bits =
NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef<Var>(buf_var))));
max_align_bits = info.alignment * 8;
} else {
*p_native_bits = native_vector_bits_;
Expand Down Expand Up @@ -1390,11 +1391,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
}
}
} else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
alloc_storage_info_[v].scope =
runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == tir::attr::storage_alignment) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
Expand Down
2 changes: 0 additions & 2 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
protected:
/*! \brief The storage information */
struct StorageInfo {
/*! \brief The storage scope */
runtime::StorageScope scope;
/*! \brief The alignment of allocation */
int alignment{0};
};
Expand Down
6 changes: 3 additions & 3 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}

if (info.scope.rank == runtime::StorageRank::kLocal) {
auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
Expand All @@ -75,7 +75,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
buf = alloca;
} else {
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
Expand Down
10 changes: 4 additions & 6 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include "codegen_spirv.h"

#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -638,13 +639,14 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
spirv::Value buf;
StorageInfo& info = storage_info_[op->buffer_var.get()];
auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
spirv::SType etype = builder_->GetSType(op->dtype);
if (info.scope.rank == runtime::StorageRank::kLocal) {
if (storage_scope.rank == runtime::StorageRank::kLocal) {
buf =
builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
} else {
// shared memory
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory
buf =
Expand All @@ -667,10 +669,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
}
}
} else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
storage_info_[v].scope = runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
Expand Down
6 changes: 0 additions & 6 deletions src/tir/transforms/thread_storage_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,6 @@ class ThreadSyncInserter : public StmtExprMutator {
is_lead_ = PrimExpr();
}
return ret;
} else if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
return StmtExprMutator::VisitStmt_(op);
} else {
return StmtExprMutator::VisitStmt_(op);
}
Expand Down Expand Up @@ -335,8 +331,6 @@ class ThreadSyncInserter : public StmtExprMutator {
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_;
// The storage scope of each buffer
std::unordered_map<const VarNode*, StorageScope> storage_scope_;
// The read write statistics of storage
std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_;
// The statistics for global barrier
Expand Down

0 comments on commit e20e195

Please sign in to comment.