Skip to content

Commit

Permalink
refactor special registers
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Dec 29, 2024
1 parent 621eb6f commit e0d5b76
Show file tree
Hide file tree
Showing 12 changed files with 623 additions and 530 deletions.
1 change: 1 addition & 0 deletions include/luisa/luisa-compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
#include <luisa/xir/passes/sink_alloca.h>
#include <luisa/xir/passes/trace_gep.h>
#include <luisa/xir/pool.h>
#include <luisa/xir/special_register.h>
#include <luisa/xir/translators/ast2xir.h>
#include <luisa/xir/translators/json2xir.h>
#include <luisa/xir/translators/xir2json.h>
Expand Down
1 change: 1 addition & 0 deletions include/luisa/xir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <luisa/ast/type_registry.h>
#include <luisa/xir/constant.h>
#include <luisa/xir/special_register.h>
#include <luisa/xir/instructions/alloca.h>
#include <luisa/xir/instructions/assert.h>
#include <luisa/xir/instructions/assume.h>
Expand Down
11 changes: 0 additions & 11 deletions include/luisa/xir/instructions/intrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,6 @@ enum struct IntrinsicOp {
BINARY_EQUAL,
BINARY_NOT_EQUAL,

// thread coordination
THREAD_ID,
BLOCK_ID,
WARP_LANE_ID,
DISPATCH_ID,
KERNEL_ID,
OBJECT_ID,
BLOCK_SIZE,
WARP_SIZE,
DISPATCH_SIZE,

// block synchronization
SYNCHRONIZE_BLOCK,// ()

Expand Down
87 changes: 87 additions & 0 deletions include/luisa/xir/special_register.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#pragma once

#include <luisa/xir/value.h>

namespace luisa::compute::xir {

// Identifies which built-in thread-coordination value a SpecialRegister stands for.
// The value type paired with each tag is fixed by the SPR_* aliases at the end of
// this header; to_string() maps each tag to its lowercase snake_case name.
// NOTE(review): enumerator order determines the underlying integer values —
// check for any uses that depend on them before reordering or inserting entries.
enum struct DerivedSpecialRegisterTag {
// uint3, see SPR_ThreadID
THREAD_ID,
// uint3, see SPR_BlockID
BLOCK_ID,
// uint, see SPR_WarpLaneID
WARP_LANE_ID,
// uint3, see SPR_DispatchID
DISPATCH_ID,
// uint, see SPR_KernelID
KERNEL_ID,
// uint, see SPR_ObjectID
OBJECT_ID,
// uint3, see SPR_BlockSize
BLOCK_SIZE,
// uint, see SPR_WarpSize
WARP_SIZE,
// uint3, see SPR_DispatchSize
DISPATCH_SIZE,
};

// Returns the lowercase snake_case name of a special-register tag
// (e.g. THREAD_ID -> "thread_id"). Any value that is not a declared
// enumerator yields "unknown".
[[nodiscard]] constexpr luisa::string_view to_string(DerivedSpecialRegisterTag tag) noexcept {
    using namespace std::string_view_literals;
    using Tag = DerivedSpecialRegisterTag;
    // Start from the fallback and overwrite it for every known tag.
    // The switch deliberately has no `default` label so that compilers can
    // still warn when a newly added enumerator is not handled here.
    auto name = "unknown"sv;
    switch (tag) {
        case Tag::THREAD_ID:
            name = "thread_id"sv;
            break;
        case Tag::BLOCK_ID:
            name = "block_id"sv;
            break;
        case Tag::WARP_LANE_ID:
            name = "warp_lane_id"sv;
            break;
        case Tag::DISPATCH_ID:
            name = "dispatch_id"sv;
            break;
        case Tag::KERNEL_ID:
            name = "kernel_id"sv;
            break;
        case Tag::OBJECT_ID:
            name = "object_id"sv;
            break;
        case Tag::BLOCK_SIZE:
            name = "block_size"sv;
            break;
        case Tag::WARP_SIZE:
            name = "warp_size"sv;
            break;
        case Tag::DISPATCH_SIZE:
            name = "dispatch_size"sv;
            break;
    }
    return name;
}

// Abstract base for XIR values that represent a special register.
// It is a DerivedValue tagged DerivedValueTag::SPECIAL_REGISTER; the concrete
// register kind is reported by derived_special_register_tag().
class LC_XIR_API SpecialRegister : public DerivedValue<DerivedValueTag::SPECIAL_REGISTER> {
public:
// Forwards the value type of the register to the DerivedValue base.
explicit SpecialRegister(const Type *type) noexcept : DerivedValue{type} {}
// Returns the tag identifying which special register this object is.
// Pure virtual: implemented by DerivedSpecialRegister<T, tag>.
[[nodiscard]] virtual DerivedSpecialRegisterTag derived_special_register_tag() const noexcept = 0;
// Returns the special register object for `tag`. Only declared here; the
// definition lives outside this header. DerivedSpecialRegister<T, tag>::create()
// static_casts the result to its own type, so the returned object is expected
// to be the DerivedSpecialRegister matching `tag`.
// NOTE(review): ownership/lifetime of the returned pointer is not visible
// from this header — confirm against the definition before deleting or caching.
[[nodiscard]] static SpecialRegister *create(DerivedSpecialRegisterTag tag) noexcept;
};

namespace detail {

// Out-of-line accessors for the two value types a special register may have.
// They are only declared here; the template below dispatches to them.
[[nodiscard]] LC_XIR_API const Type *special_register_type_uint() noexcept;
[[nodiscard]] LC_XIR_API const Type *special_register_type_uint3() noexcept;

// Maps the C++ element type `T` of a special register to its `const Type *`.
// Only `uint` and `uint3` are accepted; any other `T` is rejected at compile time.
template<typename T>
[[nodiscard]] auto get_special_register_type() noexcept {
    constexpr auto is_scalar = std::is_same_v<T, uint>;
    constexpr auto is_vector = std::is_same_v<T, uint3>;
    static_assert(is_scalar || is_vector, "Unsupported special register type.");
    if constexpr (is_scalar) {
        return special_register_type_uint();
    } else {
        return special_register_type_uint3();
    }
}

}// namespace detail

// Concrete special register whose value type is `T` (uint or uint3) and whose
// kind is the compile-time constant `tag`.
template<typename T, DerivedSpecialRegisterTag tag>
class DerivedSpecialRegister : public SpecialRegister {

public:
    // Initializes the base with the `const Type *` that corresponds to `T`.
    DerivedSpecialRegister() noexcept
        : SpecialRegister{detail::get_special_register_type<T>()} {}

    // Compile-time access to the tag, usable without an instance.
    [[nodiscard]] static constexpr auto static_derived_special_register_tag() noexcept {
        return tag;
    }

    // Runtime access to the same tag through the SpecialRegister interface.
    [[nodiscard]] DerivedSpecialRegisterTag derived_special_register_tag() const noexcept final {
        return tag;
    }

    // Obtains the register for `tag` from SpecialRegister::create and
    // downcasts it to this concrete type.
    [[nodiscard]] static auto create() noexcept {
        auto base = SpecialRegister::create(tag);
        return static_cast<DerivedSpecialRegister *>(base);
    }
};

// Concrete special register types.
// Each alias binds one DerivedSpecialRegisterTag to the value type of that register
// (uint3 or uint). The `SPR_` prefix is there to avoid potential name conflicts
// with macros.
using SPR_ThreadID = DerivedSpecialRegister<uint3, DerivedSpecialRegisterTag::THREAD_ID>;
using SPR_BlockID = DerivedSpecialRegister<uint3, DerivedSpecialRegisterTag::BLOCK_ID>;
using SPR_WarpLaneID = DerivedSpecialRegister<uint, DerivedSpecialRegisterTag::WARP_LANE_ID>;
using SPR_DispatchID = DerivedSpecialRegister<uint3, DerivedSpecialRegisterTag::DISPATCH_ID>;
using SPR_KernelID = DerivedSpecialRegister<uint, DerivedSpecialRegisterTag::KERNEL_ID>;
using SPR_ObjectID = DerivedSpecialRegister<uint, DerivedSpecialRegisterTag::OBJECT_ID>;
using SPR_BlockSize = DerivedSpecialRegister<uint3, DerivedSpecialRegisterTag::BLOCK_SIZE>;
using SPR_WarpSize = DerivedSpecialRegister<uint, DerivedSpecialRegisterTag::WARP_SIZE>;
using SPR_DispatchSize = DerivedSpecialRegister<uint3, DerivedSpecialRegisterTag::DISPATCH_SIZE>;

}// namespace luisa::compute::xir
1 change: 1 addition & 0 deletions include/luisa/xir/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum struct DerivedValueTag {
INSTRUCTION,
CONSTANT,
ARGUMENT,
SPECIAL_REGISTER,
};

class LC_XIR_API Value : public PooledObject,
Expand Down
29 changes: 20 additions & 9 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ class FallbackCodegen {
}
}

// Maps an XIR special register to the LLVM value that represents it in the
// function currently being generated.
//  - thread/block/dispatch ids and block/dispatch sizes are read from
//    `current.builtin_variables`;
//  - warp lane id and warp size are the 32-bit constants 0 and 1;
//  - kernel id and object id are not implemented.
// Any tag outside the enumeration is reported as an error.
[[nodiscard]] llvm::Value *_translate_special_register(CurrentFunction &current, IRBuilder &b,
                                                       const xir::SpecialRegister *sreg) noexcept {
    using Tag = xir::DerivedSpecialRegisterTag;
    using Fn = CurrentFunction;
    auto &builtins = current.builtin_variables;
    const auto tag = sreg->derived_special_register_tag();
    switch (tag) {
        case Tag::THREAD_ID:
            return builtins[Fn::builtin_variable_index_thread_id];
        case Tag::BLOCK_ID:
            return builtins[Fn::builtin_variable_index_block_id];
        case Tag::WARP_LANE_ID:
            // CPU only has one lane
            return llvm::ConstantInt::get(b.getInt32Ty(), 0);
        case Tag::DISPATCH_ID:
            return builtins[Fn::builtin_variable_index_dispatch_id];
        case Tag::KERNEL_ID:
            LUISA_NOT_IMPLEMENTED();
        case Tag::OBJECT_ID:
            LUISA_NOT_IMPLEMENTED();
        case Tag::BLOCK_SIZE:
            return builtins[Fn::builtin_variable_index_block_size];
        case Tag::WARP_SIZE:
            // CPU only has one lane
            return llvm::ConstantInt::get(b.getInt32Ty(), 1);
        case Tag::DISPATCH_SIZE:
            return builtins[Fn::builtin_variable_index_dispatch_size];
    }
    LUISA_ERROR_WITH_LOCATION("Invalid special register.");
}

[[nodiscard]] llvm::Value *_lookup_value(CurrentFunction &current, IRBuilder &b, const xir::Value *v, bool load_global = true) noexcept {
LUISA_ASSERT(v != nullptr, "Value is null.");
switch (v->derived_value_tag()) {
Expand All @@ -348,6 +364,10 @@ class FallbackCodegen {
LUISA_ASSERT(iter != current.value_map.end(), "Value not found.");
return iter->second;
}
case xir::DerivedValueTag::SPECIAL_REGISTER: {
auto sreg = static_cast<const xir::SpecialRegister *>(v);
return _translate_special_register(current, b, sreg);
}
}
LUISA_ERROR_WITH_LOCATION("Invalid value.");
}
Expand Down Expand Up @@ -1959,15 +1979,6 @@ class FallbackCodegen {
case xir::IntrinsicOp::BINARY_GREATER_EQUAL: return _translate_binary_greater_equal(current, b, inst->operand(0u), inst->operand(1u));
case xir::IntrinsicOp::BINARY_EQUAL: return _translate_binary_equal(current, b, inst->operand(0u), inst->operand(1u));
case xir::IntrinsicOp::BINARY_NOT_EQUAL: return _translate_binary_not_equal(current, b, inst->operand(0u), inst->operand(1u));
case xir::IntrinsicOp::THREAD_ID: return current.builtin_variables[CurrentFunction::builtin_variable_index_thread_id];
case xir::IntrinsicOp::BLOCK_ID: return current.builtin_variables[CurrentFunction::builtin_variable_index_block_id];
case xir::IntrinsicOp::WARP_LANE_ID: return llvm::ConstantInt::get(b.getInt32Ty(), 0);// CPU only has one lane
case xir::IntrinsicOp::DISPATCH_ID: return current.builtin_variables[CurrentFunction::builtin_variable_index_dispatch_id];
case xir::IntrinsicOp::KERNEL_ID: LUISA_NOT_IMPLEMENTED();
case xir::IntrinsicOp::OBJECT_ID: LUISA_NOT_IMPLEMENTED();
case xir::IntrinsicOp::BLOCK_SIZE: return current.builtin_variables[CurrentFunction::builtin_variable_index_block_size];
case xir::IntrinsicOp::WARP_SIZE: return llvm::ConstantInt::get(b.getInt32Ty(), 1);// CPU only has one lane
case xir::IntrinsicOp::DISPATCH_SIZE: return current.builtin_variables[CurrentFunction::builtin_variable_index_dispatch_size];
case xir::IntrinsicOp::SYNCHRONIZE_BLOCK: LUISA_NOT_IMPLEMENTED();
case xir::IntrinsicOp::ALL: {
auto llvm_operand = _lookup_value(current, b, inst->operand(0u));
Expand Down
2 changes: 1 addition & 1 deletion src/tests/test_xir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ int main() {
b.set_insertion_point(f->create_body_block());
auto add = b.call(Type::of<float>(), xir::IntrinsicOp::BINARY_MUL, {x, y});
auto mul = b.call(Type::of<float>(), xir::IntrinsicOp::BINARY_ADD, {add, y});
auto coord = b.call(Type::of<uint3>(), xir::IntrinsicOp::DISPATCH_ID, {});
auto coord = xir::SPR_DispatchID::create();
auto coord_x = b.call(Type::of<uint>(), xir::IntrinsicOp::EXTRACT, {coord, u32_zero});
auto outline = b.outline();
auto outline_body = outline->create_target_block();
Expand Down
1 change: 1 addition & 0 deletions src/xir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
metadata.cpp
module.cpp
pool.cpp
special_register.cpp
use.cpp
user.cpp
value.cpp
Expand Down
Loading

0 comments on commit e0d5b76

Please sign in to comment.