Skip to content

Commit

Permalink
wip: refactor xir
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Dec 29, 2024
1 parent e0d5b76 commit 33292bb
Show file tree
Hide file tree
Showing 18 changed files with 109 additions and 36 deletions.
1 change: 1 addition & 0 deletions include/luisa/luisa-compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@
#include <luisa/xir/instructions/break.h>
#include <luisa/xir/instructions/call.h>
#include <luisa/xir/instructions/cast.h>
#include <luisa/xir/instructions/clock.h>
#include <luisa/xir/instructions/continue.h>
#include <luisa/xir/instructions/gep.h>
#include <luisa/xir/instructions/if.h>
Expand Down
7 changes: 6 additions & 1 deletion include/luisa/xir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

#include <luisa/ast/type_registry.h>
#include <luisa/xir/constant.h>
#include <luisa/xir/special_register.h>
#include <luisa/xir/instructions/alloca.h>
#include <luisa/xir/instructions/assert.h>
#include <luisa/xir/instructions/assume.h>
#include <luisa/xir/instructions/branch.h>
#include <luisa/xir/instructions/break.h>
#include <luisa/xir/instructions/call.h>
#include <luisa/xir/instructions/cast.h>
#include <luisa/xir/instructions/clock.h>
#include <luisa/xir/instructions/continue.h>
#include <luisa/xir/instructions/gep.h>
#include <luisa/xir/instructions/if.h>
Expand All @@ -25,6 +25,9 @@
#include <luisa/xir/instructions/switch.h>
#include <luisa/xir/instructions/unreachable.h>

namespace luisa::compute::xir {
class ClockInst;
}// namespace luisa::compute::xir
namespace luisa::compute::xir {

class LC_XIR_API Builder {
Expand Down Expand Up @@ -101,6 +104,8 @@ class LC_XIR_API Builder {
LoadInst *load(const Type *type, Value *variable) noexcept;
StoreInst *store(Value *variable, Value *value) noexcept;

ClockInst *clock() noexcept;

OutlineInst *outline() noexcept;
RayQueryInst *ray_query(Value *query_object) noexcept;
};
Expand Down
46 changes: 39 additions & 7 deletions include/luisa/xir/instruction.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "instruction.h"

#include <luisa/xir/user.h>

namespace luisa::compute::xir {
Expand All @@ -24,17 +26,31 @@ enum struct DerivedInstructionTag {
RETURN, // basic block terminator: return (early returns are removed after control flow normalization)
PHI, // basic block beginning: phi nodes

// variable instructions
/* variable instructions */
ALLOCA,
LOAD,
STORE,
GEP,

/* atomic instructions */
ATOMIC,// operates on buffers / shared memory

/* ALU (arithmetic logic unit) instructions */
ALU,// all are pure functions, free to move/eliminate

/* CTA (cooperative thread array) instructions */
CTA,// volatile, may involve synchronization and cannot be moved/eliminated

/* resource instructions */
RESOURCE,

/* other instructions */
CALL, // user or external function calls
CALL, // user or external function calls
CAST, // type casts
PRINT,// kernel print
CLOCK,// kernel clock

INTRINSIC,// intrinsic function calls
CAST, // type casts
PRINT, // kernel print

ASSERT,// assertion
ASSUME,// assumption
Expand All @@ -44,6 +60,8 @@ enum struct DerivedInstructionTag {
RAY_QUERY,// ray queries
};

class ControlFlowMerge;

class LC_XIR_API Instruction : public IntrusiveNode<Instruction, DerivedValue<DerivedValueTag::INSTRUCTION, User>> {

private:
Expand All @@ -70,6 +88,9 @@ class LC_XIR_API Instruction : public IntrusiveNode<Instruction, DerivedValue<De
[[nodiscard]] virtual bool is_terminator() const noexcept { return false; }
[[nodiscard]] BasicBlock *parent_block() noexcept { return _parent_block; }
[[nodiscard]] const BasicBlock *parent_block() const noexcept { return _parent_block; }

[[nodiscard]] virtual ControlFlowMerge *control_flow_merge() noexcept { return nullptr; }
[[nodiscard]] const ControlFlowMerge *control_flow_merge() const noexcept;
};

using InstructionList = InlineIntrusiveList<Instruction>;
Expand Down Expand Up @@ -158,14 +179,14 @@ class DerivedConditionalBranchInstruction : public DerivedInstruction<tag, Condi
using DerivedInstruction<tag, ConditionalBranchTerminatorInstruction>::DerivedInstruction;
};

class LC_XIR_API InstructionMergeMixin {
class LC_XIR_API ControlFlowMerge {

private:
BasicBlock *_merge_block{nullptr};

protected:
InstructionMergeMixin() noexcept = default;
~InstructionMergeMixin() noexcept = default;
ControlFlowMerge() noexcept = default;
~ControlFlowMerge() noexcept = default;

public:
void set_merge_block(BasicBlock *block) noexcept { _merge_block = block; }
Expand All @@ -174,4 +195,15 @@ class LC_XIR_API InstructionMergeMixin {
BasicBlock *create_merge_block(bool overwrite_existing = false) noexcept;
};

template<typename Base>
requires std::derived_from<Base, Instruction>
class ControlFlowMergeMixin : public Base,
public ControlFlowMerge {
public:
using Base::Base;
[[nodiscard]] ControlFlowMerge *control_flow_merge() noexcept final {
return static_cast<ControlFlowMerge *>(this);
}
};

}// namespace luisa::compute::xir
12 changes: 12 additions & 0 deletions include/luisa/xir/instructions/clock.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <luisa/xir/instruction.h>

namespace luisa::compute::xir {

class LC_XIR_API ClockInst : public DerivedInstruction<DerivedInstructionTag::CLOCK> {
public:
ClockInst() noexcept;
};

}// namespace luisa::compute::xir
5 changes: 2 additions & 3 deletions include/luisa/xir/instructions/if.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ class BasicBlock;
// { merge_block }
//
// Note: this instruction must be the terminator of a basic block.
class IfInst final : public DerivedConditionalBranchInstruction<DerivedInstructionTag::IF>,
public InstructionMergeMixin {
class IfInst final : public ControlFlowMergeMixin<DerivedConditionalBranchInstruction<DerivedInstructionTag::IF>> {
public:
using DerivedConditionalBranchInstruction::DerivedConditionalBranchInstruction;
using ControlFlowMergeMixin::ControlFlowMergeMixin;
};

}// namespace luisa::compute::xir
3 changes: 0 additions & 3 deletions include/luisa/xir/instructions/intrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,6 @@ enum struct IntrinsicOp {

// shader execution re-ordering
SHADER_EXECUTION_REORDER,// (uint hint, uint hint_bits): void

// clock
CLOCK,// (): uint64
};

[[nodiscard]] LC_XIR_API luisa::string to_string(IntrinsicOp op) noexcept;
Expand Down
6 changes: 2 additions & 4 deletions include/luisa/xir/instructions/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
namespace luisa::compute::xir {

// Note: this instruction must be the terminator of a basic block.
class LC_XIR_API LoopInst final : public DerivedTerminatorInstruction<DerivedInstructionTag::LOOP>,
public InstructionMergeMixin {
class LC_XIR_API LoopInst final : public ControlFlowMergeMixin<DerivedTerminatorInstruction<DerivedInstructionTag::LOOP>> {

public:
static constexpr size_t operand_index_prepare_block = 0u;
Expand Down Expand Up @@ -36,8 +35,7 @@ class LC_XIR_API LoopInst final : public DerivedTerminatorInstruction<DerivedIns
[[nodiscard]] const BasicBlock *update_block() const noexcept;
};

class LC_XIR_API SimpleLoopInst final : public DerivedTerminatorInstruction<DerivedInstructionTag::SIMPLE_LOOP>,
public InstructionMergeMixin {
class LC_XIR_API SimpleLoopInst final : public ControlFlowMergeMixin<DerivedTerminatorInstruction<DerivedInstructionTag::SIMPLE_LOOP>> {
public:
static constexpr size_t operand_index_body_block = 0u;

Expand Down
5 changes: 2 additions & 3 deletions include/luisa/xir/instructions/outline.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ namespace luisa::compute::xir {

class BasicBlock;

class OutlineInst final : public DerivedBranchInstruction<DerivedInstructionTag::OUTLINE>,
public InstructionMergeMixin {
class OutlineInst final : public ControlFlowMergeMixin<DerivedBranchInstruction<DerivedInstructionTag::OUTLINE>> {
public:
using DerivedBranchInstruction::DerivedBranchInstruction;
using ControlFlowMergeMixin::ControlFlowMergeMixin;
};

}// namespace luisa::compute::xir
3 changes: 1 addition & 2 deletions include/luisa/xir/instructions/ray_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

namespace luisa::compute::xir {

class LC_XIR_API RayQueryInst final : public DerivedTerminatorInstruction<DerivedInstructionTag::RAY_QUERY>,
public InstructionMergeMixin {
class LC_XIR_API RayQueryInst final : public ControlFlowMergeMixin<DerivedTerminatorInstruction<DerivedInstructionTag::RAY_QUERY>> {

public:
static constexpr size_t operand_index_query_object = 0u;
Expand Down
3 changes: 1 addition & 2 deletions include/luisa/xir/instructions/switch.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ namespace luisa::compute::xir {
// { merge_block }
//
// Note: this instruction must be the terminator of a basic block.
class LC_XIR_API SwitchInst final : public DerivedTerminatorInstruction<DerivedInstructionTag::SWITCH>,
public InstructionMergeMixin {
class LC_XIR_API SwitchInst final : public ControlFlowMergeMixin<DerivedTerminatorInstruction<DerivedInstructionTag::SWITCH>> {

public:
using case_value_type = int;
Expand Down
14 changes: 9 additions & 5 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <luisa/runtime/rtx/hit.h>
#include <luisa/xir/module.h>
#include <luisa/xir/builder.h>
#include <luisa/xir/special_register.h>
#include <luisa/xir/metadata/name.h>
#include <luisa/xir/metadata/location.h>

Expand Down Expand Up @@ -2535,11 +2536,6 @@ class FallbackCodegen {
case xir::IntrinsicOp::INDIRECT_DISPATCH_SET_KERNEL: break;
case xir::IntrinsicOp::INDIRECT_DISPATCH_SET_COUNT: break;
case xir::IntrinsicOp::SHADER_EXECUTION_REORDER: return nullptr;// no-op on the LLVM side
case xir::IntrinsicOp::CLOCK: {
auto call = b.CreateIntrinsic(llvm::Intrinsic::readcyclecounter, {}, {});
auto llvm_result_type = _translate_type(inst->type(), true);
return b.CreateZExtOrTrunc(call, llvm_result_type);
}
}
LUISA_INFO("unsupported intrinsic op type: {}", static_cast<int>(inst->op()));
LUISA_NOT_IMPLEMENTED();
Expand Down Expand Up @@ -2872,6 +2868,14 @@ class FallbackCodegen {
}
case xir::DerivedInstructionTag::AUTO_DIFF: LUISA_NOT_IMPLEMENTED();
case xir::DerivedInstructionTag::RAY_QUERY: LUISA_NOT_IMPLEMENTED();
case xir::DerivedInstructionTag::CLOCK: {
auto call = b.CreateIntrinsic(llvm::Intrinsic::readcyclecounter, {}, {});
auto llvm_result_type = _translate_type(inst->type(), true);
return b.CreateZExtOrTrunc(call, llvm_result_type);
}
case xir::DerivedInstructionTag::ALU: break;
case xir::DerivedInstructionTag::CTA: break;
case xir::DerivedInstructionTag::RESOURCE: break;
}
LUISA_ERROR_WITH_LOCATION("Invalid instruction.");
}
Expand Down
1 change: 1 addition & 0 deletions src/xir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
instructions/assume.cpp
instructions/call.cpp
instructions/cast.cpp
instructions/clock.cpp
instructions/gep.cpp
instructions/intrinsic.cpp
instructions/load.cpp
Expand Down
8 changes: 6 additions & 2 deletions src/xir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,21 @@ GEPInst *Builder::gep(const Type *type, Value *base, luisa::span<Value *const> i

LoadInst *Builder::load(const Type *type, Value *variable) noexcept {
LUISA_ASSERT(variable->is_lvalue(), "Load source must be an lvalue.");
LUISA_ASSERT(type == variable->type(), "Type mismatch in Load");
LUISA_ASSERT(type == variable->type(), "Type mismatch in Load");
return _create_and_append_instruction<LoadInst>(type, variable);
}

StoreInst *Builder::store(Value *variable, Value *value) noexcept {
LUISA_ASSERT(variable->is_lvalue(), "Store destination must be an lvalue.");
LUISA_ASSERT(!value->is_lvalue(), "Store source cannot be an lvalue.");
LUISA_ASSERT(variable->type() == value->type(), "Type mismatch in Store");
LUISA_ASSERT(variable->type() == value->type(), "Type mismatch in Store");
return _create_and_append_instruction<StoreInst>(variable, value);
}

ClockInst *Builder::clock() noexcept {
return _create_and_append_instruction<ClockInst>();
}

OutlineInst *Builder::outline() noexcept {
return _create_and_append_instruction<OutlineInst>();
}
Expand Down
6 changes: 5 additions & 1 deletion src/xir/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ void Instruction::replace_self_with(Instruction *node) noexcept {
remove_self();
}

const ControlFlowMerge *Instruction::control_flow_merge() const noexcept {
return const_cast<Instruction *>(this)->control_flow_merge();
}

TerminatorInstruction::TerminatorInstruction() noexcept
: Instruction{nullptr} {}

Expand Down Expand Up @@ -141,7 +145,7 @@ const BasicBlock *ConditionalBranchTerminatorInstruction::false_block() const no
return const_cast<ConditionalBranchTerminatorInstruction *>(this)->false_block();
}

BasicBlock *InstructionMergeMixin::create_merge_block(bool overwrite_existing) noexcept {
BasicBlock *ControlFlowMerge::create_merge_block(bool overwrite_existing) noexcept {
LUISA_ASSERT(merge_block() == nullptr || overwrite_existing,
"Merge block already exists.");
auto block = Pool::current()->create<BasicBlock>();
Expand Down
9 changes: 9 additions & 0 deletions src/xir/instructions/clock.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include <luisa/ast/type_registry.h>
#include <luisa/xir/instructions/clock.h>

namespace luisa::compute::xir {

ClockInst::ClockInst() noexcept
: DerivedInstruction{Type::of<luisa::ulong>()} {}

}// namespace luisa::compute::xir
2 changes: 0 additions & 2 deletions src/xir/instructions/intrinsic_name_map.inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ luisa::string to_string(IntrinsicOp op) noexcept {
case IntrinsicOp::INDIRECT_DISPATCH_SET_KERNEL: return "indirect_dispatch_set_kernel";
case IntrinsicOp::INDIRECT_DISPATCH_SET_COUNT: return "indirect_dispatch_set_count";
case IntrinsicOp::SHADER_EXECUTION_REORDER: return "shader_execution_reorder";
case IntrinsicOp::CLOCK: return "clock";
}
LUISA_ERROR_WITH_LOCATION("Unknown intrinsic operation: {}.",
static_cast<uint32_t>(op));
Expand Down Expand Up @@ -454,7 +453,6 @@ IntrinsicOp intrinsic_op_from_string(luisa::string_view name) noexcept {
{"indirect_dispatch_set_kernel", IntrinsicOp::INDIRECT_DISPATCH_SET_KERNEL},
{"indirect_dispatch_set_count", IntrinsicOp::INDIRECT_DISPATCH_SET_COUNT},
{"shader_execution_reorder", IntrinsicOp::SHADER_EXECUTION_REORDER},
{"clock", IntrinsicOp::CLOCK},
};
auto iter = m.find(name);
LUISA_ASSERT(iter != m.end(), "Unknown intrinsic operation: {}.", name);
Expand Down
3 changes: 2 additions & 1 deletion src/xir/translators/ast2xir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <luisa/ast/statement.h>
#include <luisa/ast/function.h>
#include <luisa/xir/builder.h>
#include <luisa/xir/special_register.h>
#include <luisa/xir/translators/ast2xir.h>

namespace luisa::compute::xir {
Expand Down Expand Up @@ -709,7 +710,7 @@ class AST2XIRContext {
case CallOp::TEXTURE3D_SAMPLE_GRAD: return resource_call(IntrinsicOp::TEXTURE3D_SAMPLE_GRAD);
case CallOp::TEXTURE3D_SAMPLE_GRAD_LEVEL: return resource_call(IntrinsicOp::TEXTURE3D_SAMPLE_GRAD_LEVEL);
case CallOp::SHADER_EXECUTION_REORDER: return resource_call(IntrinsicOp::SHADER_EXECUTION_REORDER);
case CallOp::CLOCK: return pure_call(IntrinsicOp::CLOCK);
case CallOp::CLOCK: return b.clock();
}
LUISA_NOT_IMPLEMENTED();
}
Expand Down
11 changes: 11 additions & 0 deletions src/xir/translators/xir2text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <luisa/xir/instructions/break.h>
#include <luisa/xir/instructions/call.h>
#include <luisa/xir/instructions/cast.h>
#include <luisa/xir/instructions/clock.h>
#include <luisa/xir/instructions/continue.h>
#include <luisa/xir/instructions/gep.h>
#include <luisa/xir/instructions/intrinsic.h>
Expand Down Expand Up @@ -248,6 +249,10 @@ class XIR2TextTranslator final {
_main << " " << _value_ident(inst->condition());
}

void _emit_clock_inst(const ClockInst *inst [[maybe_unused]]) noexcept {
_main << "clock";
}

void _emit_if_inst(const IfInst *inst, int indent) noexcept {
_main << "if " << _value_ident(inst->condition()) << ", then ";
_emit_basic_block(inst->true_block(), indent);
Expand Down Expand Up @@ -469,6 +474,12 @@ class XIR2TextTranslator final {
case DerivedInstructionTag::ASSUME:
_emit_assume_inst(static_cast<const AssumeInst *>(inst));
break;
case DerivedInstructionTag::CLOCK:
_emit_clock_inst(static_cast<const ClockInst *>(inst));
break;
case DerivedInstructionTag::ALU: break;
case DerivedInstructionTag::CTA: break;
case DerivedInstructionTag::RESOURCE: break;
}
_main << ";";
_emit_use_debug_info(_main, inst->use_list());
Expand Down

0 comments on commit 33292bb

Please sign in to comment.