diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 8688f5e4194..b0d88e95432 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -40,6 +40,12 @@ static inline Type* convert(SlangReflectionType* type) static inline SlangReflectionType* convert(Type* type) { + // Prevent the AtomicType struct from being visible to the user + // through the reflection API. + if (auto atomicType = as(type)) + { + return (SlangReflectionType*)atomicType->getElementType(); + } return (SlangReflectionType*)type; } @@ -586,7 +592,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy if (auto arrayType = as(type)) { - return (SlangReflectionType*)arrayType->getElementType(); + return convert(arrayType->getElementType()); } else if (auto parameterGroupType = as(type)) { @@ -1026,7 +1032,7 @@ SLANG_API SlangReflectionType* spReflection_FindTypeByName( if (as(result)) return nullptr; - return (SlangReflectionType*)result; + return convert(result); } catch (...) { @@ -1158,7 +1164,7 @@ SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType( if (!typeLayout) return nullptr; - return (SlangReflectionType*)typeLayout->type; + return convert(typeLayout->type); } SLANG_API SlangTypeKind spReflectionTypeLayout_getKind(SlangReflectionTypeLayout* inTypeLayout) @@ -4227,7 +4233,7 @@ SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex( { auto constraints = globalGenericParamDecl->getMembersOfType(); - return (SlangReflectionType*)constraints[index]->sup.Ptr(); + return convert(constraints[index]->sup.Ptr()); } // TODO: Add case for entry-point generic parameters. } diff --git a/tools/slang-unit-test/unit-test-atomic-reflection.cpp b/tools/slang-unit-test/unit-test-atomic-reflection.cpp new file mode 100644 index 00000000000..9817f2aa619 --- /dev/null +++ b/tools/slang-unit-test/unit-test-atomic-reflection.cpp @@ -0,0 +1,135 @@ +// unit-test-atomic-reflection.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include + +using namespace Slang; + +SLANG_UNIT_TEST(atomicReflection) +{ + const char* userSourceBody = R"( + RWStructuredBuffer bufA; // for reference + RWStructuredBuffer> bufB; + struct TestStruct { + Atomic fieldA; + Atomic fieldB[2]; + vector, 2> fieldC; + }; + RWStructuredBuffer bufC; + + [shader("compute")] + [numthreads(64, 1, 1)] + void main(uint2 dispatchThreadId : SV_DispatchThreadID) + { + } + )"; + String userSource = userSourceBody; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto reflection = module->getLayout(); + + auto bufAVariableLayout = reflection->getParameterByIndex(0); + auto bufBVariableLayout = reflection->getParameterByIndex(1); + auto bufCVariableLayout = reflection->getParameterByIndex(2); + + SLANG_CHECK(bufAVariableLayout != nullptr); + SLANG_CHECK(bufBVariableLayout != nullptr); + SLANG_CHECK(bufCVariableLayout != nullptr); + + // This test walks down the parameters in a different way from the JSON + // option. The slangc option "-reflection-json" walks: + // param->getTypeLayout()->getElementTypeLayout()->getType()->getKind() + // This test uses an alternate walk: + // param->getVariable()->getType()->getElementType()->getKind() + + auto bufAVariable = bufAVariableLayout->getVariable(); + auto bufBVariable = bufBVariableLayout->getVariable(); + auto bufCVariable = bufCVariableLayout->getVariable(); + + SLANG_CHECK(bufAVariable != nullptr); + SLANG_CHECK(bufBVariable != nullptr); + SLANG_CHECK(bufCVariable != nullptr); + + auto bufAType = bufAVariable->getType(); + auto bufBType = bufBVariable->getType(); + auto bufCType = bufCVariable->getType(); + + SLANG_CHECK(bufAType->getKind() == slang::TypeReflection::Kind::Resource); + SLANG_CHECK(bufBType->getKind() == slang::TypeReflection::Kind::Resource); + SLANG_CHECK(bufCType->getKind() == slang::TypeReflection::Kind::Resource); + + auto bufAElementType = bufAType->getElementType(); + auto bufBElementType = bufBType->getElementType(); + auto bufCElementType = bufCType->getElementType(); + + SLANG_CHECK(bufAElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufBElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufAElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + SLANG_CHECK(bufBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + auto bufAResResultType = bufAType->getResourceResultType(); + auto bufBResResultType = bufBType->getResourceResultType(); + + SLANG_CHECK(bufAResResultType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufBResResultType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(bufAResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + SLANG_CHECK(bufBResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + // Atomics embedded in structs require traversing the fields + SLANG_CHECK(bufCElementType->getKind() == slang::TypeReflection::Kind::Struct); + + auto bufCFieldCount = bufCElementType->getFieldCount(); + SLANG_CHECK(bufCFieldCount == 3); + + auto fieldA = bufCElementType->getFieldByIndex(0); + auto fieldB = bufCElementType->getFieldByIndex(1); + auto fieldC = bufCElementType->getFieldByIndex(2); + + SLANG_CHECK(fieldA != nullptr); + SLANG_CHECK(fieldB != nullptr); + SLANG_CHECK(fieldC != nullptr); + + auto fieldAType = fieldA->getType(); + auto fieldBType = fieldB->getType(); + auto fieldCType = fieldC->getType(); + + SLANG_CHECK(fieldAType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(fieldAType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + SLANG_CHECK(fieldBType->getKind() == slang::TypeReflection::Kind::Array); + SLANG_CHECK(fieldBType->getElementCount() == 2); + auto fieldBElementType = fieldBType->getElementType(); + SLANG_CHECK(fieldBElementType != nullptr); + SLANG_CHECK(fieldBElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(fieldBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); + + SLANG_CHECK(fieldCType->getKind() == slang::TypeReflection::Kind::Vector); + SLANG_CHECK(fieldCType->getElementCount() == 2); + auto fieldCElementType = fieldCType->getElementType(); + SLANG_CHECK(fieldCElementType != nullptr); + SLANG_CHECK(fieldCElementType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(fieldCElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32); +}