Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,23 @@ class SpirvBuilder {
bool isPrecise, bool isNointerp, llvm::StringRef name = "",
llvm::Optional<SpirvInstruction *> init = llvm::None,
SourceLocation loc = {});

// Adds a variable to the module.
SpirvVariable *
addModuleVar(const SpirvType *valueType, spv::StorageClass storageClass,
bool isPrecise, bool isNointerp, llvm::StringRef name = "",
llvm::Optional<SpirvInstruction *> init = llvm::None,
SourceLocation loc = {});

// Adds a variable to the module. It will be placed in the variable list
// before `pos`.
SpirvVariable *
addModuleVar(const SpirvType *valueType, spv::StorageClass storageClass,
bool isPrecise, bool isNointerp, SpirvInstruction *before,
llvm::StringRef name = "",
llvm::Optional<SpirvInstruction *> init = llvm::None,
SourceLocation loc = {});

/// \brief Decorates the given target with the given location.
void decorateLocation(SpirvInstruction *target, uint32_t location);

Expand Down
4 changes: 4 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ class SpirvModule {
// Adds a variable to the module.
void addVariable(SpirvVariable *);

// Adds a variable to the module immediately before `pos`.
// If `pos` is not found, `var` is added at the end of the variable list.
void addVariable(SpirvVariable *var, SpirvInstruction *pos);

// Adds a decoration to the module.
void addDecoration(SpirvDecoration *);

Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1861,7 +1861,7 @@ void DeclResultIdMapper::createCounterVar(
}

SpirvVariable *counterInstr = spvBuilder.addModuleVar(
counterType, sc, /*isPrecise*/ false, false, counterName);
counterType, sc, /*isPrecise*/ false, false, declInstr, counterName);

if (!isAlias) {
// Non-alias counter variables should be put in to resourceVars so that
Expand Down
148 changes: 78 additions & 70 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2644,80 +2644,15 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
// NodePayloadArray types
else if (const auto *npaType = dyn_cast<NodePayloadArrayType>(type)) {
const uint32_t elemTypeId = emitType(npaType->getElementType());

// Output the decorations for the type first. This will create other values
// that are on the decorations, and they must appear before the type.
emitDecorationsForNodePayloadArrayTypes(npaType, id);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this was moved for the same reason? But no test change seem related to the node payload?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is. Some tests failed validation with the latest spirv-tools without this. This change will be tested once we update spirv-tools. I want it to be a separate change in case there is a regression. It will be easier to identify what caused the change.


initTypeInstruction(spv::Op::OpTypeNodePayloadArrayAMDX);
curTypeInst.push_back(id);
curTypeInst.push_back(elemTypeId);
finalizeTypeInstruction();

// Emit decorations
const ParmVarDecl *nodeDecl = npaType->getNodeDecl();
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
StringRef name = nodeDecl->getName();
unsigned index = 0;
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
name = nodeID->getName();
index = nodeID->getArrayIndex();
}

auto *str = new (context) SpirvConstantString(name);
uint32_t nodeName = getOrCreateConstantString(str);
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
llvm::None, true);
if (index) {
uint32_t baseIndex = getOrCreateConstantInt(
llvm::APInt(32, index), context.getUIntType(32), false);
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX,
{baseIndex}, llvm::None, true);
}
}

uint32_t maxRecords;
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
context.getUIntType(32), false);
} else {
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
context.getUIntType(32), false);
}
emitDecoration(id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
llvm::None, true);

if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsSharedWithAttr>()) {
const DeclContext *dc = nodeDecl->getParentFunctionOrMethod();
if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
IdentifierInfo *ii = attr->getName();
bool alreadyExists = false;
for (auto *paramDecl : funDecl->params()) {
if (paramDecl->getIdentifier() == ii) {
assert(paramDecl != nodeDecl);
auto otherType = context.getNodeDeclPayloadType(paramDecl);
const uint32_t otherId =
getResultIdForType(otherType, &alreadyExists);
assert(alreadyExists && "forward references not allowed in "
"MaxRecordsSharedWith attribute");
emitDecoration(id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
{otherId}, llvm::None, true);
break;
}
}
assert(alreadyExists &&
"invalid reference in MaxRecordsSharedWith attribute");
}
}
if (const auto *attr = nodeDecl->getAttr<HLSLAllowSparseNodesAttr>()) {
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
llvm::None);
}
if (const auto *attr = nodeDecl->getAttr<HLSLUnboundedSparseNodesAttr>()) {
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
llvm::None);
}
if (const auto *attr = nodeDecl->getAttr<HLSLNodeArraySizeAttr>()) {
uint32_t arraySize = getOrCreateConstantInt(
llvm::APInt(32, attr->getCount()), context.getUIntType(32), false);
emitDecoration(id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
llvm::None, true);
}
}
// Structure types
else if (const auto *structType = dyn_cast<StructType>(type)) {
Expand Down Expand Up @@ -2998,5 +2933,78 @@ void EmitTypeHandler::emitNameForType(llvm::StringRef name,
nameInstr.end());
}

void EmitTypeHandler::emitDecorationsForNodePayloadArrayTypes(
const NodePayloadArrayType *npaType, uint32_t id) {
// Emit decorations
const ParmVarDecl *nodeDecl = npaType->getNodeDecl();
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
StringRef name = nodeDecl->getName();
unsigned index = 0;
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
name = nodeID->getName();
index = nodeID->getArrayIndex();
}

auto *str = new (context) SpirvConstantString(name);
uint32_t nodeName = getOrCreateConstantString(str);
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
llvm::None, true);
if (index) {
uint32_t baseIndex = getOrCreateConstantInt(
llvm::APInt(32, index), context.getUIntType(32), false);
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, {baseIndex},
llvm::None, true);
}
}

uint32_t maxRecords;
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
context.getUIntType(32), false);
} else {
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
context.getUIntType(32), false);
}
emitDecoration(id, spv::Decoration::NodeMaxPayloadsAMDX, {maxRecords},
llvm::None, true);

if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsSharedWithAttr>()) {
const DeclContext *dc = nodeDecl->getParentFunctionOrMethod();
if (const auto *funDecl = dyn_cast_or_null<FunctionDecl>(dc)) {
IdentifierInfo *ii = attr->getName();
bool alreadyExists = false;
for (auto *paramDecl : funDecl->params()) {
if (paramDecl->getIdentifier() == ii) {
assert(paramDecl != nodeDecl);
auto otherType = context.getNodeDeclPayloadType(paramDecl);
const uint32_t otherId =
getResultIdForType(otherType, &alreadyExists);
assert(alreadyExists && "forward references not allowed in "
"MaxRecordsSharedWith attribute");
emitDecoration(id, spv::Decoration::NodeSharesPayloadLimitsWithAMDX,
{otherId}, llvm::None, true);
break;
}
}
assert(alreadyExists &&
"invalid reference in MaxRecordsSharedWith attribute");
}
}
if (const auto *attr = nodeDecl->getAttr<HLSLAllowSparseNodesAttr>()) {
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
llvm::None);
}
if (const auto *attr = nodeDecl->getAttr<HLSLUnboundedSparseNodesAttr>()) {
emitDecoration(id, spv::Decoration::PayloadNodeSparseArrayAMDX, {},
llvm::None);
}
if (const auto *attr = nodeDecl->getAttr<HLSLNodeArraySizeAttr>()) {
uint32_t arraySize = getOrCreateConstantInt(
llvm::APInt(32, attr->getCount()), context.getUIntType(32), false);
emitDecoration(id, spv::Decoration::PayloadNodeArraySizeAMDX, {arraySize},
llvm::None, true);
}
}

} // end namespace spirv
} // end namespace clang
4 changes: 4 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class EmitTypeHandler {
void emitNameForType(llvm::StringRef name, uint32_t targetTypeId,
llvm::Optional<uint32_t> memberIndex = llvm::None);

void
emitDecorationsForNodePayloadArrayTypes(const NodePayloadArrayType *npaType,
uint32_t id);

// There is no guarantee that an instruction or a function or a basic block
// has been assigned result-id. This method returns the result-id for the
// given object. If a result-id has not been assigned yet, it'll assign
Expand Down
15 changes: 15 additions & 0 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,21 @@ SpirvVariable *SpirvBuilder::addModuleVar(
return var;
}

SpirvVariable *SpirvBuilder::addModuleVar(
const SpirvType *type, spv::StorageClass storageClass, bool isPrecise,
bool isNointerp, SpirvInstruction *pos, llvm::StringRef name,
llvm::Optional<SpirvInstruction *> init, SourceLocation loc) {
assert(storageClass != spv::StorageClass::Function);
// Note: We store the underlying type in the variable, *not* the pointer type.
auto *var = new (context)
SpirvVariable(type, loc, storageClass, isPrecise, isNointerp,
init.hasValue() ? init.getValue() : nullptr);
var->setResultType(type);
var->setDebugName(name);
mod->addVariable(var, pos);
return var;
}

void SpirvBuilder::decorateLocation(SpirvInstruction *target,
uint32_t location) {
auto *decor =
Expand Down
6 changes: 6 additions & 0 deletions tools/clang/lib/SPIRV/SpirvModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ void SpirvModule::addVariable(SpirvVariable *var) {
variables.push_back(var);
}

void SpirvModule::addVariable(SpirvVariable *var, SpirvInstruction *pos) {
assert(var && "cannot add null variable to the module");
auto location = std::find(variables.begin(), variables.end(), pos);
variables.insert(location, var);
}

void SpirvModule::addDecoration(SpirvDecoration *decor) {
assert(decor && "cannot add null decoration to the module");
decorations.insert(decor);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s

// CHECK: %counter_var_t_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
// CHECK: %counter_var_rw = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
// CHECK: %counter_var_t_1_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
// CHECK: %counter_var_s_0 = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private

RWStructuredBuffer<uint> rw : register(u0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ RWStructuredBuffer<S1> selectRWSBuffer(RWStructuredBuffer<S1> paramRWSBu
AppendStructuredBuffer<S2> selectASBuffer(AppendStructuredBuffer<S2> paramASBuffer, bool selector);
ConsumeStructuredBuffer<S3> selectCSBuffer(ConsumeStructuredBuffer<S3> paramCSBuffer, bool selector);

// CHECK: %counter_var_globalRWSBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
RWStructuredBuffer<S1> globalRWSBuffer;
// CHECK: %counter_var_globalASBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
AppendStructuredBuffer<S2> globalASBuffer;
Expand All @@ -38,7 +39,6 @@ static RWStructuredBuffer<S1> staticgRWSBuffer = globalRWSBuffer;
// CHECK: %counter_var_staticgASBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
static AppendStructuredBuffer<S2> staticgASBuffer = globalASBuffer;
// CHECK: %counter_var_staticgCSBuffer = OpVariable %_ptr_Private__ptr_Uniform_type_ACSBuffer_counter Private
// CHECK: %counter_var_globalRWSBuffer = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
static ConsumeStructuredBuffer<S3> staticgCSBuffer = globalCSBuffer;

// Counter variables for function returns, function parameters, and local variables have an extra level of pointer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ struct T {
// CHECK: OpDecorate %counter_var_myAppendStructuredBuffer DescriptorSet 0
// CHECK: OpDecorate %counter_var_myAppendStructuredBuffer Binding 1

// CHECK: %myAppendStructuredBuffer = OpVariable %_ptr_Uniform__arr_type_AppendStructuredBuffer_T_uint_5 Uniform
// CHECK: %counter_var_myAppendStructuredBuffer = OpVariable %_ptr_Uniform__arr_type_ACSBuffer_counter_uint_5 Uniform
// CHECK: %myAppendStructuredBuffer = OpVariable %_ptr_Uniform__arr_type_AppendStructuredBuffer_T_uint_5 Uniform
AppendStructuredBuffer<T> myAppendStructuredBuffer[5];

void main() {}
Expand Down
28 changes: 14 additions & 14 deletions tools/clang/test/CodeGenSPIRV/type.append-structured-buffer.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,31 @@

// CHECK: OpExtension "SPV_GOOGLE_hlsl_functionality1"

// CHECK: OpName %type_AppendStructuredBuffer_v4float "type.AppendStructuredBuffer.v4float"
// CHECK: OpName %buffer1 "buffer1"
// CHECK-DAG: OpName %type_AppendStructuredBuffer_v4float "type.AppendStructuredBuffer.v4float"
// CHECK-DAG: OpName %buffer1 "buffer1"

// CHECK: OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
// CHECK: OpName %counter_var_buffer1 "counter.var.buffer1"
// CHECK-DAG: OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
// CHECK-DAG: OpName %counter_var_buffer1 "counter.var.buffer1"

// CHECK: OpName %type_AppendStructuredBuffer_S "type.AppendStructuredBuffer.S"
// CHECK: OpName %buffer2 "buffer2"
// CHECK-DAG: OpName %type_AppendStructuredBuffer_S "type.AppendStructuredBuffer.S"
// CHECK-DAG: OpName %buffer2 "buffer2"

// CHECK: OpName %counter_var_buffer2 "counter.var.buffer2"
// CHECK-DAG: OpName %counter_var_buffer2 "counter.var.buffer2"

// CHECK: OpDecorateId %buffer1 CounterBuffer %counter_var_buffer1
// CHECK: OpDecorateId %buffer2 CounterBuffer %counter_var_buffer2

// CHECK: %type_AppendStructuredBuffer_v4float = OpTypeStruct %_runtimearr_v4float
// CHECK: %_ptr_Uniform_type_AppendStructuredBuffer_v4float = OpTypePointer Uniform %type_AppendStructuredBuffer_v4float
// CHECK-DAG: OpDecorateId %buffer1 CounterBuffer %counter_var_buffer1
// CHECK-DAG: OpDecorateId %buffer2 CounterBuffer %counter_var_buffer2

// CHECK: %type_ACSBuffer_counter = OpTypeStruct %int
// CHECK: %_ptr_Uniform_type_ACSBuffer_counter = OpTypePointer Uniform %type_ACSBuffer_counter

// CHECK: %type_AppendStructuredBuffer_v4float = OpTypeStruct %_runtimearr_v4float
// CHECK: %_ptr_Uniform_type_AppendStructuredBuffer_v4float = OpTypePointer Uniform %type_AppendStructuredBuffer_v4float

// CHECK: %type_AppendStructuredBuffer_S = OpTypeStruct %_runtimearr_S
// CHECK: %_ptr_Uniform_type_AppendStructuredBuffer_S = OpTypePointer Uniform %type_AppendStructuredBuffer_S

// CHECK: %buffer1 = OpVariable %_ptr_Uniform_type_AppendStructuredBuffer_v4float Uniform
// CHECK: %counter_var_buffer1 = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
// CHECK: %buffer1 = OpVariable %_ptr_Uniform_type_AppendStructuredBuffer_v4float Uniform
AppendStructuredBuffer<float4> buffer1;

struct S {
Expand All @@ -35,8 +35,8 @@ struct S {
float2x3 c;
};

// CHECK: %buffer2 = OpVariable %_ptr_Uniform_type_AppendStructuredBuffer_S Uniform
// CHECK: %counter_var_buffer2 = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
// CHECK: %buffer2 = OpVariable %_ptr_Uniform_type_AppendStructuredBuffer_S Uniform
AppendStructuredBuffer<S> buffer2;

float main() : A {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ struct T {
// CHECK: OpDecorate %counter_var_myConsumeStructuredBuffer DescriptorSet 0
// CHECK: OpDecorate %counter_var_myConsumeStructuredBuffer Binding 1

// CHECK: %myConsumeStructuredBuffer = OpVariable %_ptr_Uniform__arr_type_ConsumeStructuredBuffer_T_uint_2 Uniform
// CHECK: %counter_var_myConsumeStructuredBuffer = OpVariable %_ptr_Uniform__arr_type_ACSBuffer_counter_uint_2 Uniform
// CHECK: %myConsumeStructuredBuffer = OpVariable %_ptr_Uniform__arr_type_ConsumeStructuredBuffer_T_uint_2 Uniform
ConsumeStructuredBuffer<T> myConsumeStructuredBuffer[2];

void main() {}
Expand Down
Loading
Loading