Skip to content

Commit

Permalink
refactor xir atomic inst
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Dec 29, 2024
1 parent 33292bb commit 2115585
Show file tree
Hide file tree
Showing 13 changed files with 289 additions and 57 deletions.
1 change: 1 addition & 0 deletions include/luisa/luisa-compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@
#include <luisa/xir/instructions/alloca.h>
#include <luisa/xir/instructions/assert.h>
#include <luisa/xir/instructions/assume.h>
#include <luisa/xir/instructions/atomic.h>
#include <luisa/xir/instructions/branch.h>
#include <luisa/xir/instructions/break.h>
#include <luisa/xir/instructions/call.h>
Expand Down
12 changes: 12 additions & 0 deletions include/luisa/xir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <luisa/xir/instructions/alloca.h>
#include <luisa/xir/instructions/assert.h>
#include <luisa/xir/instructions/assume.h>
#include <luisa/xir/instructions/atomic.h>
#include <luisa/xir/instructions/branch.h>
#include <luisa/xir/instructions/break.h>
#include <luisa/xir/instructions/call.h>
Expand Down Expand Up @@ -108,6 +109,17 @@ class LC_XIR_API Builder {

OutlineInst *outline() noexcept;
RayQueryInst *ray_query(Value *query_object) noexcept;

AtomicInst *atomic(const Type *type, AtomicOp op, Value *base, luisa::span<Value *const> indices, luisa::span<Value *const> values) noexcept;
AtomicInst *atomic_fetch_add(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_fetch_sub(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_fetch_and(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_fetch_or(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_fetch_xor(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_fetch_min(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_fetch_max(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_exchange(const Type *type, Value *base, luisa::span<Value *const> indices, Value *value) noexcept;
AtomicInst *atomic_compare_exchange(const Type *type, Value *base, luisa::span<Value *const> indices, Value *expected, Value *desired) noexcept;
};

}// namespace luisa::compute::xir
2 changes: 1 addition & 1 deletion include/luisa/xir/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class ControlFlowMergeMixin : public Base,
public:
using Base::Base;
[[nodiscard]] ControlFlowMerge *control_flow_merge() noexcept final {
return static_cast<ControlFlowMerge *>(this);
return this;
}
};

Expand Down
75 changes: 75 additions & 0 deletions include/luisa/xir/instructions/atomic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#pragma once

#include <luisa/xir/instruction.h>

namespace luisa::compute::xir {

enum class AtomicOp {
EXCHANGE, /// [(base, indices..., desired) -> old]: stores desired, returns old.
COMPARE_EXCHANGE,/// [(base, indices..., expected, desired) -> old]: stores (old == expected ? desired : old), returns old.
FETCH_ADD, /// [(base, indices..., val) -> old]: stores (old + val), returns old.
FETCH_SUB, /// [(base, indices..., val) -> old]: stores (old - val), returns old.
FETCH_AND, /// [(base, indices..., val) -> old]: stores (old & val), returns old.
FETCH_OR, /// [(base, indices..., val) -> old]: stores (old | val), returns old.
FETCH_XOR, /// [(base, indices..., val) -> old]: stores (old ^ val), returns old.
FETCH_MIN, /// [(base, indices..., val) -> old]: stores min(old, val), returns old.
FETCH_MAX, /// [(base, indices..., val) -> old]: stores max(old, val), returns old.
};

[[nodiscard]] constexpr luisa::string_view to_string(AtomicOp op) noexcept {
using namespace std::string_view_literals;
switch (op) {
case AtomicOp::EXCHANGE: return "exchange"sv;
case AtomicOp::COMPARE_EXCHANGE: return "compare_exchange"sv;
case AtomicOp::FETCH_ADD: return "fetch_add"sv;
case AtomicOp::FETCH_SUB: return "fetch_sub"sv;
case AtomicOp::FETCH_AND: return "fetch_and"sv;
case AtomicOp::FETCH_OR: return "fetch_or"sv;
case AtomicOp::FETCH_XOR: return "fetch_xor"sv;
case AtomicOp::FETCH_MIN: return "fetch_min"sv;
case AtomicOp::FETCH_MAX: return "fetch_max"sv;
}
return "unknown"sv;
}

[[nodiscard]] constexpr auto atomic_op_value_count(AtomicOp op) noexcept {
return op == AtomicOp::COMPARE_EXCHANGE ? 2u : 1u;
}

class LC_XIR_API AtomicInst final : public DerivedInstruction<DerivedInstructionTag::ATOMIC> {

private:
AtomicOp _op;

public:
explicit AtomicInst(const Type *type = nullptr, AtomicOp op = {},
Value *base = nullptr,
luisa::span<Value *const> indices = {},
luisa::span<Value *const> values = {}) noexcept;
[[nodiscard]] auto op() const noexcept { return _op; }
void set_op(AtomicOp op) noexcept { _op = op; }

[[nodiscard]] Value *base() noexcept;
[[nodiscard]] const Value *base() const noexcept;
void set_base(Value *base) noexcept;

[[nodiscard]] Use *base_use() noexcept;
[[nodiscard]] const Use *base_use() const noexcept;

[[nodiscard]] size_t index_count() const noexcept;
void set_index_count(size_t count) noexcept;

[[nodiscard]] luisa::span<Use *const> index_uses() noexcept;
[[nodiscard]] luisa::span<const Use *const> index_uses() const noexcept;
void set_indices(luisa::span<Value *const> indices) noexcept;

[[nodiscard]] size_t value_count() const noexcept {
return atomic_op_value_count(_op);
}

[[nodiscard]] luisa::span<Use *const> value_uses() noexcept;
[[nodiscard]] luisa::span<const Use *const> value_uses() const noexcept;
void set_values(luisa::span<Value *const> values) noexcept;
};

}// namespace luisa::compute::xir
2 changes: 1 addition & 1 deletion include/luisa/xir/instructions/clock.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace luisa::compute::xir {

class LC_XIR_API ClockInst : public DerivedInstruction<DerivedInstructionTag::CLOCK> {
class LC_XIR_API ClockInst final : public DerivedInstruction<DerivedInstructionTag::CLOCK> {
public:
ClockInst() noexcept;
};
Expand Down
11 changes: 0 additions & 11 deletions include/luisa/xir/instructions/intrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,6 @@ enum struct IntrinsicOp {
MATRIX_TRANSPOSE, // (floatNxN) -> floatNxN
MATRIX_INVERSE, // (floatNxN) -> floatNxN

// atomic operations
ATOMIC_EXCHANGE, /// [(atomic_ref, desired) -> old]: stores desired, returns old.
ATOMIC_COMPARE_EXCHANGE,/// [(atomic_ref, expected, desired) -> old]: stores (old == expected ? desired : old), returns old.
ATOMIC_FETCH_ADD, /// [(atomic_ref, val) -> old]: stores (old + val), returns old.
ATOMIC_FETCH_SUB, /// [(atomic_ref, val) -> old]: stores (old - val), returns old.
ATOMIC_FETCH_AND, /// [(atomic_ref, val) -> old]: stores (old & val), returns old.
ATOMIC_FETCH_OR, /// [(atomic_ref, val) -> old]: stores (old | val), returns old.
ATOMIC_FETCH_XOR, /// [(atomic_ref, val) -> old]: stores (old ^ val), returns old.
ATOMIC_FETCH_MIN, /// [(atomic_ref, val) -> old]: stores min(old, val), returns old.
ATOMIC_FETCH_MAX, /// [(atomic_ref, val) -> old]: stores max(old, val), returns old.

// resource operations
BUFFER_READ, /// [(buffer, index) -> value]: reads the index-th element in buffer
BUFFER_WRITE,/// [(buffer, index, value) -> void]: writes value into the index-th element of buffer
Expand Down
41 changes: 24 additions & 17 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ class FallbackCodegen {
llvm_ptr = b.CreateStructGEP(llvm_struct_type->type, llvm_ptr, padded_index);
ptr_type = ptr_type->members()[static_index];
} else {
LUISA_ASSERT(ptr_type->is_array() || ptr_type->is_vector() || ptr_type->is_matrix(), "Invalid pointer type.");
LUISA_ASSERT(ptr_type->is_array() || ptr_type->is_vector() || ptr_type->is_matrix(),
"Invalid pointer type: {}", ptr_type->description());
auto llvm_index = _lookup_value(current, b, index);
llvm_index = b.CreateZExtOrTrunc(llvm_index, llvm::Type::getInt64Ty(_llvm_context));
auto llvm_ptr_type = _translate_type(ptr_type, false);
Expand Down Expand Up @@ -1912,20 +1913,20 @@ class FallbackCodegen {
}

[[nodiscard]] llvm::Value *_translate_atomic_op(CurrentFunction &current, IRBuilder &b,
const char *op_name, size_t value_count,
const xir::IntrinsicInst *inst,
const char *op_name, const xir::AtomicInst *inst,
bool byte_address = false) noexcept {
LUISA_ASSERT(2 + value_count <= inst->operand_count(), "Invalid atomic operation.");
auto buffer = inst->operand(0u);
auto value_count = inst->value_count();
auto buffer = inst->base();
LUISA_ASSERT(buffer->type()->is_buffer(), "Invalid buffer type.");
auto buffer_elem_type = buffer->type()->element();
LUISA_ASSERT(buffer_elem_type != nullptr, "Invalid buffer element type.");
auto slot = inst->operand(1u);
auto indices = inst->index_uses();
auto slot = indices.front()->value();
indices = indices.subspan(1);
auto llvm_elem_ptr = _get_buffer_element_ptr(current, b, buffer, slot, byte_address);
if (byte_address) {
LUISA_ASSERT(2 + value_count == inst->operand_count(), "Invalid atomic operation.");
}
auto indices = inst->operand_uses().subspan(2, inst->operand_count() - 2 - value_count);
if (!indices.empty()) {
llvm_elem_ptr = _translate_gep(current, b, inst->type(),
buffer_elem_type, llvm_elem_ptr, indices);
Expand All @@ -1944,7 +1945,7 @@ class FallbackCodegen {
auto llvm_func = _llvm_module->getFunction(llvm::StringRef{llvm_func_name});
LUISA_ASSERT(llvm_func != nullptr && llvm_func->arg_size() == 1 + value_count, "Invalid atomic operation function.");
llvm::SmallVector<llvm::Value *, 4u> llvm_args{llvm_elem_ptr};
for (auto value_uses : inst->operand_uses().subspan(inst->operand_count() - value_count)) {
for (auto value_uses : inst->value_uses()) {
auto value = value_uses->value();
LUISA_ASSERT(value->type() == inst->type(), "Atomic operation value type mismatch.");
auto llvm_value = _lookup_value(current, b, value);
Expand Down Expand Up @@ -2409,15 +2410,6 @@ class FallbackCodegen {
case xir::IntrinsicOp::MATRIX_DETERMINANT: return _translate_matrix_determinant(current, b, inst);
case xir::IntrinsicOp::MATRIX_TRANSPOSE: return _translate_matrix_transpose(current, b, inst);
case xir::IntrinsicOp::MATRIX_INVERSE: return _translate_matrix_inverse(current, b, inst);
case xir::IntrinsicOp::ATOMIC_EXCHANGE: return _translate_atomic_op(current, b, "exchange", 1, inst);
case xir::IntrinsicOp::ATOMIC_COMPARE_EXCHANGE: return _translate_atomic_op(current, b, "compare.exchange", 2, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_ADD: return _translate_atomic_op(current, b, "fetch.add", 1, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_SUB: return _translate_atomic_op(current, b, "fetch.sub", 1, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_AND: return _translate_atomic_op(current, b, "fetch.and", 1, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_OR: return _translate_atomic_op(current, b, "fetch.or", 1, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_XOR: return _translate_atomic_op(current, b, "fetch.xor", 1, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_MIN: return _translate_atomic_op(current, b, "fetch.min", 1, inst);
case xir::IntrinsicOp::ATOMIC_FETCH_MAX: return _translate_atomic_op(current, b, "fetch.max", 1, inst);
case xir::IntrinsicOp::BUFFER_READ: return _translate_buffer_read(current, b, inst);
case xir::IntrinsicOp::BUFFER_WRITE: return _translate_buffer_write(current, b, inst);
case xir::IntrinsicOp::BUFFER_SIZE: return _translate_buffer_size(current, b, inst);
Expand Down Expand Up @@ -2876,6 +2868,21 @@ class FallbackCodegen {
case xir::DerivedInstructionTag::ALU: break;
case xir::DerivedInstructionTag::CTA: break;
case xir::DerivedInstructionTag::RESOURCE: break;
case xir::DerivedInstructionTag::ATOMIC: {
auto atomic_inst = static_cast<const xir::AtomicInst *>(inst);
switch (atomic_inst->op()) {
case xir::AtomicOp::EXCHANGE: return _translate_atomic_op(current, b, "exchange", atomic_inst);
case xir::AtomicOp::COMPARE_EXCHANGE: return _translate_atomic_op(current, b, "compare.exchange", atomic_inst);
case xir::AtomicOp::FETCH_ADD: return _translate_atomic_op(current, b, "fetch.add", atomic_inst);
case xir::AtomicOp::FETCH_SUB: return _translate_atomic_op(current, b, "fetch.sub", atomic_inst);
case xir::AtomicOp::FETCH_AND: return _translate_atomic_op(current, b, "fetch.and", atomic_inst);
case xir::AtomicOp::FETCH_OR: return _translate_atomic_op(current, b, "fetch.or", atomic_inst);
case xir::AtomicOp::FETCH_XOR: return _translate_atomic_op(current, b, "fetch.xor", atomic_inst);
case xir::AtomicOp::FETCH_MIN: return _translate_atomic_op(current, b, "fetch.min", atomic_inst);
case xir::AtomicOp::FETCH_MAX: return _translate_atomic_op(current, b, "fetch.max", atomic_inst);
}
LUISA_ERROR_WITH_LOCATION("Invalid atomic operation.");
}
}
LUISA_ERROR_WITH_LOCATION("Invalid instruction.");
}
Expand Down
1 change: 1 addition & 0 deletions src/xir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
instructions/alloca.cpp
instructions/assert.cpp
instructions/assume.cpp
instructions/atomic.cpp
instructions/call.cpp
instructions/cast.cpp
instructions/clock.cpp
Expand Down
60 changes: 60 additions & 0 deletions src/xir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,66 @@ RayQueryInst *Builder::ray_query(Value *query_object) noexcept {
return _create_and_append_instruction<RayQueryInst>(query_object);
}

AtomicInst *Builder::atomic(const Type *type, AtomicOp op, Value *base,
luisa::span<Value *const> indices,
luisa::span<Value *const> values) noexcept {
return _create_and_append_instruction<AtomicInst>(type, op, base, indices, values);
}

AtomicInst *Builder::atomic_fetch_add(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_ADD, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_fetch_sub(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_SUB, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_fetch_and(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_AND, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_fetch_or(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_OR, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_fetch_xor(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_XOR, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_fetch_min(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_MIN, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_fetch_max(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::FETCH_MAX, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_exchange(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *value) noexcept {
return this->atomic(type, AtomicOp::EXCHANGE, base, indices, std::array{value});
}

AtomicInst *Builder::atomic_compare_exchange(const Type *type, Value *base,
luisa::span<Value *const> indices,
Value *expected, Value *desired) noexcept {
return this->atomic(type, AtomicOp::COMPARE_EXCHANGE, base, indices, std::array{expected, desired});
}

void Builder::set_insertion_point(Instruction *insertion_point) noexcept {
_insertion_point = insertion_point;
}
Expand Down
Loading

0 comments on commit 2115585

Please sign in to comment.