Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions source/slang/slang-reflection-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ static inline Type* convert(SlangReflectionType* type)

static inline SlangReflectionType* convert(Type* type)
{
// Prevent the AtomicType struct from being visible to the user
// through the reflection API.
if (auto atomicType = as<AtomicType>(type))
{
return (SlangReflectionType*)atomicType->getElementType();
}
return (SlangReflectionType*)type;
}

Expand Down Expand Up @@ -586,7 +592,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy

if (auto arrayType = as<ArrayExpressionType>(type))
{
return (SlangReflectionType*)arrayType->getElementType();
return convert(arrayType->getElementType());
}
else if (auto parameterGroupType = as<ParameterGroupType>(type))
{
Expand Down Expand Up @@ -1026,7 +1032,7 @@ SLANG_API SlangReflectionType* spReflection_FindTypeByName(

if (as<ErrorType>(result))
return nullptr;
return (SlangReflectionType*)result;
return convert(result);
}
catch (...)
{
Expand Down Expand Up @@ -1158,7 +1164,7 @@ SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType(
if (!typeLayout)
return nullptr;

return (SlangReflectionType*)typeLayout->type;
return convert(typeLayout->type);
}

SLANG_API SlangTypeKind spReflectionTypeLayout_getKind(SlangReflectionTypeLayout* inTypeLayout)
Expand Down Expand Up @@ -4227,7 +4233,7 @@ SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(
{
auto constraints =
globalGenericParamDecl->getMembersOfType<GenericTypeConstraintDecl>();
return (SlangReflectionType*)constraints[index]->sup.Ptr();
return convert(constraints[index]->sup.Ptr());
}
// TODO: Add case for entry-point generic parameters.
}
Expand Down
135 changes: 135 additions & 0 deletions tools/slang-unit-test/unit-test-atomic-reflection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// unit-test-atomic-reflection.cpp

#include "../../source/core/slang-io.h"
#include "../../source/core/slang-process.h"
#include "slang-com-ptr.h"
#include "slang.h"
#include "unit-test/slang-unit-test.h"

#include <stdio.h>
#include <stdlib.h>

using namespace Slang;

SLANG_UNIT_TEST(atomicReflection)
{
const char* userSourceBody = R"(
RWStructuredBuffer<uint> bufA; // for reference
RWStructuredBuffer<Atomic<uint>> bufB;
struct TestStruct {
Atomic<uint> fieldA;
Atomic<uint> fieldB[2];
vector<Atomic<uint>, 2> fieldC;
};
RWStructuredBuffer<TestStruct> bufC;

[shader("compute")]
[numthreads(64, 1, 1)]
void main(uint2 dispatchThreadId : SV_DispatchThreadID)
{
}
)";
String userSource = userSourceBody;
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_HLSL;
targetDesc.profile = globalSession->findProfile("sm_5_0");
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"m",
"m.slang",
userSourceBody,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

auto reflection = module->getLayout();

auto bufAVariableLayout = reflection->getParameterByIndex(0);
auto bufBVariableLayout = reflection->getParameterByIndex(1);
auto bufCVariableLayout = reflection->getParameterByIndex(2);

SLANG_CHECK(bufAVariableLayout != nullptr);
SLANG_CHECK(bufBVariableLayout != nullptr);
SLANG_CHECK(bufCVariableLayout != nullptr);

// This test walks down the parameters in a different way from the JSON
// option. The slangc option "-reflection-json" walks:
// param->getTypeLayout()->getElementTypeLayout()->getType()->getKind()
// This test uses an alternate walk:
// param->getVariable()->getType()->getElementType()->getKind()

auto bufAVariable = bufAVariableLayout->getVariable();
auto bufBVariable = bufBVariableLayout->getVariable();
auto bufCVariable = bufCVariableLayout->getVariable();

SLANG_CHECK(bufAVariable != nullptr);
SLANG_CHECK(bufBVariable != nullptr);
SLANG_CHECK(bufCVariable != nullptr);

auto bufAType = bufAVariable->getType();
auto bufBType = bufBVariable->getType();
auto bufCType = bufCVariable->getType();

SLANG_CHECK(bufAType->getKind() == slang::TypeReflection::Kind::Resource);
SLANG_CHECK(bufBType->getKind() == slang::TypeReflection::Kind::Resource);
SLANG_CHECK(bufCType->getKind() == slang::TypeReflection::Kind::Resource);

auto bufAElementType = bufAType->getElementType();
auto bufBElementType = bufBType->getElementType();
auto bufCElementType = bufCType->getElementType();

SLANG_CHECK(bufAElementType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(bufBElementType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(bufAElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
SLANG_CHECK(bufBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);

auto bufAResResultType = bufAType->getResourceResultType();
auto bufBResResultType = bufBType->getResourceResultType();

SLANG_CHECK(bufAResResultType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(bufBResResultType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(bufAResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
SLANG_CHECK(bufBResResultType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);

// Atomics embedded in structs require traversing the fields
SLANG_CHECK(bufCElementType->getKind() == slang::TypeReflection::Kind::Struct);

auto bufCFieldCount = bufCElementType->getFieldCount();
SLANG_CHECK(bufCFieldCount == 3);

auto fieldA = bufCElementType->getFieldByIndex(0);
auto fieldB = bufCElementType->getFieldByIndex(1);
auto fieldC = bufCElementType->getFieldByIndex(2);

SLANG_CHECK(fieldA != nullptr);
SLANG_CHECK(fieldB != nullptr);
SLANG_CHECK(fieldC != nullptr);

auto fieldAType = fieldA->getType();
auto fieldBType = fieldB->getType();
auto fieldCType = fieldC->getType();

SLANG_CHECK(fieldAType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(fieldAType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);

SLANG_CHECK(fieldBType->getKind() == slang::TypeReflection::Kind::Array);
SLANG_CHECK(fieldBType->getElementCount() == 2);
auto fieldBElementType = fieldBType->getElementType();
SLANG_CHECK(fieldBElementType != nullptr);
SLANG_CHECK(fieldBElementType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(fieldBElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);

SLANG_CHECK(fieldCType->getKind() == slang::TypeReflection::Kind::Vector);
SLANG_CHECK(fieldCType->getElementCount() == 2);
auto fieldCElementType = fieldCType->getElementType();
SLANG_CHECK(fieldCElementType != nullptr);
SLANG_CHECK(fieldCElementType->getKind() == slang::TypeReflection::Kind::Scalar);
SLANG_CHECK(fieldCElementType->getScalarType() == slang::TypeReflection::ScalarType::UInt32);
}
Loading