Skip to content

Commit

Permalink
add support for device log
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Dec 17, 2024
1 parent 78501f2 commit ee69bbf
Show file tree
Hide file tree
Showing 7 changed files with 488 additions and 351 deletions.
580 changes: 291 additions & 289 deletions src/backends/common/shader_print_formatter.h

Large diffs are not rendered by default.

85 changes: 80 additions & 5 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class FallbackCodegen {
luisa::unordered_map<const Type *, luisa::unique_ptr<LLVMStruct>> _llvm_struct_types;
luisa::unordered_map<const xir::Constant *, llvm::Constant *> _llvm_constants;
luisa::unordered_map<const xir::Function *, llvm::Function *> _llvm_functions;
FallbackCodeGenFeedback::PrintInstMap _print_inst_map;

private:
void _reset() noexcept {
Expand Down Expand Up @@ -2606,6 +2607,78 @@ class FallbackCodegen {
LUISA_ERROR_WITH_LOCATION("Invalid cast operation.");
}

// Lowers an xir::PrintInst to a call of an external print hook.
//
// The operands of `inst` are stored into a 16-byte aligned, stack-allocated
// argument pack and the hook is called as
//     void print(const void *ctx, size_t fmt_id, const void *args);
// where `ctx` is the module global "luisa.print.context", `fmt_id` is the index
// of this instruction in `_print_inst_map`, and `args` points to the pack.
// A function declaration is created for every print instruction; the name read
// back from LLVM via getName() is recorded next to the instruction in
// `_print_inst_map` so the symbol can be resolved after code generation.
//
// Returns the emitted call instruction.
[[nodiscard]] llvm::Value *_translate_print_inst(CurrentFunction &current, IRBuilder &b,
                                                 const xir::PrintInst *inst) noexcept {
    // create argument struct
    // The layout is computed by hand: explicit i8-array padding fields are
    // inserted so every operand starts at a multiple of its own alignment.
    llvm::SmallVector<llvm::Type *, 8> llvm_field_types;
    // NOTE: despite the name, this holds the struct *field index* of each
    // operand (padding fields shift the indices), not a byte offset.
    llvm::SmallVector<size_t, 8> llvm_field_offsets;
    size_t accum_size = 0u;
    auto llvm_i8_type = b.getInt8Ty();
    for (auto o : inst->operand_uses()) {
        auto t = o->value()->type();
        auto alignment = _get_type_alignment(t);
        auto align_size = luisa::align(accum_size, alignment);
        if (align_size > accum_size) {
            auto llvm_pad_type = llvm::ArrayType::get(llvm_i8_type, align_size - accum_size);
            llvm_field_types.emplace_back(llvm_pad_type);
        }
        llvm_field_offsets.emplace_back(llvm_field_types.size());
        auto llvm_t = _translate_type(t, true);
        llvm_field_types.emplace_back(llvm_t);
        accum_size = align_size + _get_type_size(t);
    }
    // pad the whole pack to a multiple of 16 bytes
    auto total_size = luisa::align(accum_size, 16u);
    if (total_size > accum_size) {
        auto llvm_pad_type = llvm::ArrayType::get(llvm_i8_type, total_size - accum_size);
        llvm_field_types.emplace_back(llvm_pad_type);
    }
    auto llvm_struct_type = llvm::StructType::get(_llvm_context, llvm_field_types);
    // fill argument struct
    // The alloca is placed at the top of the function's entry block instead of
    // at the current insertion point. An alloca outside the entry block is a
    // dynamic stack allocation that is only released when the function returns,
    // so a print inside a loop would grow the stack on every iteration. The
    // guard restores the builder's insertion point and debug location.
    llvm::AllocaInst *llvm_struct_alloca = nullptr;
    {
        llvm::IRBuilderBase::InsertPointGuard guard(b);
        auto &llvm_entry_block = b.GetInsertBlock()->getParent()->getEntryBlock();
        b.SetInsertPoint(&llvm_entry_block, llvm_entry_block.getFirstInsertionPt());
        llvm_struct_alloca = b.CreateAlloca(llvm_struct_type);
        llvm_struct_alloca->setAlignment(llvm::Align{16});
    }
    // the lifetime markers stay at the use site so the slot is only live
    // around the call
    auto llvm_total_size = b.getInt64(total_size);
    b.CreateLifetimeStart(llvm_struct_alloca, llvm_total_size);
    for (auto i = 0u; i < inst->operand_count(); i++) {
        auto llvm_field_offset = llvm_field_offsets[i];
        auto llvm_field_ptr = b.CreateStructGEP(llvm_struct_type, llvm_struct_alloca, llvm_field_offset);
        auto field = inst->operand(i);
        auto llvm_field = _lookup_value(current, b, field);
        auto alignment = _get_type_alignment(field->type());
        b.CreateAlignedStore(llvm_field, llvm_field_ptr, llvm::MaybeAlign{alignment});
    }
    // declare print function
    auto llvm_i64_type = b.getInt64Ty();
    auto llvm_ptr_type = llvm::PointerType::get(_llvm_context, 0);
    // void print(const void *ctx, size_t fmt_id, const void *args);
    auto llvm_print_func_type = llvm::FunctionType::get(b.getVoidTy(), {llvm_ptr_type, llvm_i64_type, llvm_ptr_type}, false);
    auto llvm_print_func = llvm::Function::Create(llvm_print_func_type, llvm::Function::ExternalLinkage, "luisa.print", _llvm_module);
    auto llvm_print_context = _llvm_module->getOrInsertGlobal("luisa.print.context", llvm::StructType::get(_llvm_context));
    // the format id is the position of this instruction in the feedback map
    auto fmt_id = static_cast<uint64_t>(_print_inst_map.size());
    _print_inst_map.emplace_back(inst, llvm_print_func->getName());
    llvm_print_func->setCallingConv(llvm::CallingConv::C);
    llvm_print_func->setNoSync();
    llvm_print_func->setMustProgress();
    llvm_print_func->setWillReturn();
    llvm_print_func->setDoesNotThrow();
    llvm_print_func->setOnlyAccessesInaccessibleMemOrArgMem();
    llvm_print_func->setDoesNotFreeMemory();
    llvm_print_func->setUWTableKind(llvm::UWTableKind::None);
    // both pointer parameters (context and argument pack) are only read
    for (auto &&llvm_print_arg : llvm_print_func->args()) {
        if (llvm_print_arg.getType()->isPointerTy()) {
            llvm_print_arg.addAttr(llvm::Attribute::NoCapture);
            llvm_print_arg.addAttr(llvm::Attribute::NoAlias);
            llvm_print_arg.addAttr(llvm::Attribute::ReadOnly);
            llvm_print_arg.addAttr(llvm::Attribute::NoUndef);
            llvm_print_arg.addAttr(llvm::Attribute::NonNull);
        }
    }
    // call print function
    auto llvm_fmt_id = b.getInt64(fmt_id);
    auto llvm_call = b.CreateCall(llvm_print_func, {llvm_print_context, llvm_fmt_id, llvm_struct_alloca});
    b.CreateLifetimeEnd(llvm_struct_alloca, llvm_total_size);
    return llvm_call;
}

[[nodiscard]] llvm::Value *_translate_instruction(CurrentFunction &current, IRBuilder &b, const xir::Instruction *inst) noexcept {
switch (inst->derived_instruction_tag()) {
case xir::DerivedInstructionTag::SENTINEL: {
Expand Down Expand Up @@ -2756,8 +2829,8 @@ class FallbackCodegen {
return _translate_cast_inst(current, b, cast_inst->type(), cast_inst->op(), cast_inst->value());
}
case xir::DerivedInstructionTag::PRINT: {
LUISA_WARNING_WITH_LOCATION("Ignoring print instruction.");// TODO...
return nullptr;
auto print_inst = static_cast<const xir::PrintInst *>(inst);
return _translate_print_inst(current, b, print_inst);
}
case xir::DerivedInstructionTag::ASSERT: {
auto assert_inst = static_cast<const xir::AssertInst *>(inst);
Expand Down Expand Up @@ -3107,7 +3180,7 @@ class FallbackCodegen {
// Constructs a code generator that initialises its `_llvm_context` member from `ctx`.
explicit FallbackCodegen(llvm::LLVMContext &ctx) noexcept
: _llvm_context{ctx} {}

void emit(llvm::Module *llvm_module, const xir::Module *module) noexcept {
FallbackCodeGenFeedback emit(llvm::Module *llvm_module, const xir::Module *module) noexcept {
auto location_md = module->find_metadata<xir::LocationMD>();
auto module_location = location_md ? location_md->file().string() : "unknown";
llvm_module->setSourceFileName(location_md ? location_md->file().string() : "unknown");
Expand All @@ -3117,12 +3190,14 @@ class FallbackCodegen {
_llvm_module = llvm_module;
_translate_module(module);
_reset();
return {.print_inst_map = std::exchange(_print_inst_map, {})};
}
};

void luisa_fallback_backend_codegen(llvm::LLVMContext &llvm_ctx, llvm::Module *llvm_module, const xir::Module *module) noexcept {
FallbackCodeGenFeedback
luisa_fallback_backend_codegen(llvm::LLVMContext &llvm_ctx, llvm::Module *llvm_module, const xir::Module *module) noexcept {
FallbackCodegen codegen{llvm_ctx};
codegen.emit(llvm_module, module);
return codegen.emit(llvm_module, module);
}

}// namespace luisa::compute::fallback
17 changes: 14 additions & 3 deletions src/backends/fallback/fallback_codegen.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
#pragma once

#include <luisa/core/stl/unordered_map.h>

namespace llvm {
class Module;
class LLVMContext;
}// namespace llvm

namespace luisa::compute::xir {
class PrintInst;
class Module;
}// namespace luisa::compute::xir

namespace luisa::compute::fallback {

void luisa_fallback_backend_codegen(llvm::LLVMContext &llvm_ctx,
llvm::Module *llvm_module,
const xir::Module *module) noexcept;
// Side information returned by luisa_fallback_backend_codegen() in addition
// to the LLVM IR written into the module.
struct FallbackCodeGenFeedback {
    // One entry per print instruction: the XIR instruction paired with the
    // name of the LLVM symbol associated with it.
    using PrintInstMap = luisa::vector<std::pair<const xir::PrintInst *, luisa::string /* llvm symbol */>>;

    // Entries for the print instructions found during code generation.
    PrintInstMap print_inst_map;
};

[[nodiscard]] FallbackCodeGenFeedback
luisa_fallback_backend_codegen(llvm::LLVMContext &llvm_ctx,
llvm::Module *llvm_module,
const xir::Module *module) noexcept;

}// namespace luisa::compute::fallback
5 changes: 5 additions & 0 deletions src/backends/fallback/fallback_command_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <luisa/core/basic_types.h>
#include <luisa/core/stl/queue.h>
#include <luisa/core/stl/functional.h>
#include <luisa/runtime/rhi/device_interface.h>

#if defined(LUISA_PLATFORM_APPLE)
#define LUISA_FALLBACK_USE_DISPATCH_QUEUE
Expand Down Expand Up @@ -41,6 +42,7 @@ class FallbackCommandQueue {
std::atomic_size_t _total_enqueue_count{0u};
std::atomic_size_t _total_finish_count{0u};
size_t _worker_count{0u};
DeviceInterface::StreamLogCallback _log_callback;

#if defined(LUISA_FALLBACK_USE_DISPATCH_QUEUE)
dispatch_queue_t _dispatch_queue{nullptr};
Expand All @@ -59,6 +61,9 @@ class FallbackCommandQueue {
void enqueue(luisa::move_only_function<void()> &&task) noexcept;
void enqueue_parallel(uint n, luisa::move_only_function<void(uint)> &&task) noexcept;
void synchronize() noexcept;

// Stores `callback` as this queue's log callback, replacing the current one.
void set_log_callback(DeviceInterface::StreamLogCallback callback) noexcept {
    _log_callback = std::move(callback);
}
[[nodiscard]] auto &log_callback() const noexcept { return _log_callback; }
};

}// namespace luisa::compute::fallback
3 changes: 2 additions & 1 deletion src/backends/fallback/fallback_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ void FallbackDevice::dispatch(uint64_t stream_handle, CommandList &&list) noexce
}

// Installs `callback` as the log callback of the command queue owned by the
// stream identified by `stream_handle`.
// NOTE(review): `stream_handle` is reinterpreted as a FallbackStream pointer
// without any validation — callers must pass a handle of a live stream.
void FallbackDevice::set_stream_log_callback(uint64_t stream_handle, const DeviceInterface::StreamLogCallback &callback) noexcept {
// NOTE(review): what the base-class implementation does is not visible here —
// confirm that calling it is still intended now that the callback is stored
// on the queue below.
DeviceInterface::set_stream_log_callback(stream_handle, callback);
auto stream = reinterpret_cast<FallbackStream *>(stream_handle);
stream->queue()->set_log_callback(callback);
}

SwapchainCreationInfo FallbackDevice::create_swapchain(const SwapchainOption &option, uint64_t stream_handle) noexcept {
Expand Down
138 changes: 88 additions & 50 deletions src/backends/fallback/fallback_shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
// Created by swfly on 2024/11/21.
//
#include <fstream>
#include <luisa/xir/translators/ast2xir.h>
#include <luisa/xir/translators/xir2text.h>
#include <luisa/core/stl.h>
#include <luisa/core/logging.h>
#include <luisa/core/clock.h>

#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/MC/TargetRegistry.h>
Expand All @@ -20,6 +15,15 @@
#include <llvm/Passes/PassBuilder.h>
#include <llvm/IR/LegacyPassManager.h>

#include <luisa/xir/translators/ast2xir.h>
#include <luisa/xir/translators/xir2text.h>
#include <luisa/core/stl.h>
#include <luisa/core/logging.h>
#include <luisa/core/clock.h>
#include <luisa/xir/instructions/print.h>

#include "../common/shader_print_formatter.h"

#include "fallback_codegen.h"
#include "fallback_texture.h"
#include "fallback_accel.h"
Expand Down Expand Up @@ -52,6 +56,20 @@ static void luisa_fallback_assert(bool condition, const char *message) noexcept
if (!condition) { LUISA_ERROR_WITH_LOCATION("Assertion failed: {}.", message); }
}

static thread_local const DeviceInterface::StreamLogCallback *current_device_log_callback{nullptr};

static void luisa_fallback_print(const FallbackShader *shader, size_t fmt_id, const std::byte *args) noexcept {
static thread_local luisa::string scratch;
scratch.clear();
auto formatter = shader->print_formatter(fmt_id);
(*formatter)(scratch, {args, formatter->size()});
if (current_device_log_callback) {
(*current_device_log_callback)(scratch);
} else {
LUISA_INFO("[DEVICE] {}", scratch);
}
}

struct FallbackShaderLaunchConfig {
uint3 block_id;
uint3 dispatch_size;
Expand Down Expand Up @@ -114,10 +132,51 @@ FallbackShader::FallbackShader(const ShaderOption &option, Function kernel) noex
LUISA_ERROR_WITH_LOCATION("Failed to create LLJIT.");
}

// if (auto generator = ::llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
// _jit->getDataLayout().getGlobalPrefix())) {
// _jit->getMainJITDylib().addGenerator(std::move(generator.get()));
// } else {
// ::llvm::handleAllErrors(generator.takeError(), [](const ::llvm::ErrorInfoBase &err) {
// LUISA_WARNING_WITH_LOCATION("DynamicLibrarySearchGenerator::GetForCurrentProcess(): {}", err.message());
// });
// LUISA_ERROR_WITH_LOCATION("Failed to add generator.");
// }

_block_size = kernel.block_size();
_build_bound_arguments(kernel.bound_arguments());

xir::Pool pool;
xir::PoolGuard guard{&pool};
auto xir_module = xir::ast_to_xir_translate(kernel, {});
xir_module->set_name(luisa::format("kernel_{:016x}", kernel.hash()));
if (!option.name.empty()) { xir_module->set_location(option.name); }
// LUISA_INFO("Kernel XIR:\n{}", xir::xir_to_text_translate(xir_module, true));

auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
auto builtin_module = fallback_backend_device_builtin_module();
llvm::SMDiagnostic parse_error;
auto llvm_module = llvm::parseIR(llvm::MemoryBufferRef{builtin_module, ""}, parse_error, *llvm_ctx);
if (!llvm_module) {
LUISA_ERROR_WITH_LOCATION("Failed to generate LLVM IR: {}.",
luisa::string_view{parse_error.getMessage()});
}
auto codegen_feedback = luisa_fallback_backend_codegen(*llvm_ctx, llvm_module.get(), xir_module);
//llvm_module->print(llvm::errs(), nullptr, true, true);
//llvm_module->print(llvm::outs(), nullptr, true, true);
if (llvm::verifyModule(*llvm_module, &llvm::errs())) {
LUISA_ERROR_WITH_LOCATION("LLVM module verification failed.");
}
// {
// llvm_module->print(llvm::errs(), nullptr, true, true);
// // std::error_code EC;
// // llvm::raw_fd_ostream file_stream("H:/abc.ll", EC, llvm::sys::fs::OF_None);
// // llvm_module->print(file_stream, nullptr, true, true);
// // file_stream.close();
// }

// map symbols
llvm::orc::SymbolMap symbol_map{};
auto map_symbol = [jit = _jit.get(), &symbol_map]<typename T>(const char *name, T *f) noexcept {
static_assert(std::is_function_v<T>);
auto addr = llvm::orc::ExecutorAddr::fromPtr(f);
auto symbol = llvm::orc::ExecutorSymbolDef{addr, llvm::JITSymbolFlags::Callable};
symbol_map.try_emplace(jit->mangleAndIntern(name), symbol);
Expand All @@ -142,6 +201,25 @@ FallbackShader::FallbackShader(const ShaderOption &option, Function kernel) noex
// assert
map_symbol("luisa.assert", &luisa_fallback_assert);

// bind print instructions
if (!codegen_feedback.print_inst_map.empty()) {
map_symbol("luisa.print.context", this);
_print_formatters.reserve(codegen_feedback.print_inst_map.size());
for (auto fmt_id = 0u; fmt_id < codegen_feedback.print_inst_map.size(); fmt_id++) {
auto &&[print_inst, llvm_symbol] = codegen_feedback.print_inst_map[fmt_id];
map_symbol(llvm_symbol.c_str(), &luisa_fallback_print);
LUISA_INFO("Mapping print instruction #{}: \"{}\" -> {}", fmt_id, print_inst->format(), llvm_symbol);
llvm::SmallVector<const Type *, 8u> arg_types;
for (auto o : print_inst->operand_uses()) {
arg_types.emplace_back(o->value()->type());
}
auto arg_pack_type = Type::structure(16u, arg_types);
_print_formatters.emplace_back(luisa::make_unique<ShaderPrintFormatter>(
print_inst->format(), arg_pack_type, false));
}
}

// define symbols
if (auto error = _jit->getMainJITDylib().define(
::llvm::orc::absoluteSymbols(std::move(symbol_map)))) {
::llvm::handleAllErrors(std::move(error), [](const ::llvm::ErrorInfoBase &err) {
Expand All @@ -150,48 +228,6 @@ FallbackShader::FallbackShader(const ShaderOption &option, Function kernel) noex
LUISA_ERROR_WITH_LOCATION("Failed to define symbols.");
}

if (auto generator = ::llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
_jit->getDataLayout().getGlobalPrefix())) {
_jit->getMainJITDylib().addGenerator(std::move(generator.get()));
} else {
::llvm::handleAllErrors(generator.takeError(), [](const ::llvm::ErrorInfoBase &err) {
LUISA_WARNING_WITH_LOCATION("DynamicLibrarySearchGenerator::GetForCurrentProcess(): {}", err.message());
});
LUISA_ERROR_WITH_LOCATION("Failed to add generator.");
}

_block_size = kernel.block_size();
_build_bound_arguments(kernel.bound_arguments());

xir::Pool pool;
xir::PoolGuard guard{&pool};
auto xir_module = xir::ast_to_xir_translate(kernel, {});
xir_module->set_name(luisa::format("kernel_{:016x}", kernel.hash()));
if (!option.name.empty()) { xir_module->set_location(option.name); }
// LUISA_INFO("Kernel XIR:\n{}", xir::xir_to_text_translate(xir_module, true));

auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
auto builtin_module = fallback_backend_device_builtin_module();
llvm::SMDiagnostic parse_error;
auto llvm_module = llvm::parseIR(llvm::MemoryBufferRef{builtin_module, ""}, parse_error, *llvm_ctx);
if (!llvm_module) {
LUISA_ERROR_WITH_LOCATION("Failed to generate LLVM IR: {}.",
luisa::string_view{parse_error.getMessage()});
}
luisa_fallback_backend_codegen(*llvm_ctx, llvm_module.get(), xir_module);
//llvm_module->print(llvm::errs(), nullptr, true, true);
//llvm_module->print(llvm::outs(), nullptr, true, true);
if (llvm::verifyModule(*llvm_module, &llvm::errs())) {
LUISA_ERROR_WITH_LOCATION("LLVM module verification failed.");
}
// {
// llvm_module->print(llvm::errs(), nullptr, true, true);
// // std::error_code EC;
// // llvm::raw_fd_ostream file_stream("H:/abc.ll", EC, llvm::sys::fs::OF_None);
// // llvm_module->print(file_stream, nullptr, true, true);
// // file_stream.close();
// }

// optimize
llvm_module->setDataLayout(_target_machine->createDataLayout());
llvm_module->setTargetTriple(_target_machine->getTargetTriple().str());
Expand Down Expand Up @@ -436,7 +472,7 @@ void FallbackShader::dispatch(FallbackCommandQueue *queue, luisa::unique_ptr<Sha
auto grid_size = roundup_div(dispatch_size, block_size);
auto grid_count = grid_size.x * grid_size.y * grid_size.z;

queue->enqueue_parallel(grid_count, [dispatch_buffer = std::move(dispatch_buffer)](auto block) noexcept {
queue->enqueue_parallel(grid_count, [queue, dispatch_buffer = std::move(dispatch_buffer)](auto block) noexcept {
auto config = dispatch_buffer.config();
auto dispatch_size = config->dispatch_size;
auto block_size = config->block_size;
Expand All @@ -451,7 +487,9 @@ void FallbackShader::dispatch(FallbackCommandQueue *queue, luisa::unique_ptr<Sha
.block_size = {block_size[0], block_size[1], block_size[2]},
};
auto launch_params = dispatch_buffer.argument_buffer();
(config->kernel)(launch_params, &launch_config);
current_device_log_callback = queue->log_callback() ? &queue->log_callback() : nullptr;
config->kernel(launch_params, &launch_config);
current_device_log_callback = nullptr;
});
}

Expand Down
Loading

0 comments on commit ee69bbf

Please sign in to comment.