diff --git a/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir b/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir index 5b855218b4..1a0a727492 100644 --- a/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir +++ b/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir @@ -269,3 +269,83 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_ %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> llvm.return } + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1UC_L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<1, Cached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1UC_L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1C_L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Cached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1C_L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Streaming, 0>, #triton_gen.load_cache_control<1, Uncached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1S_L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Streaming, 0>, #triton_gen.load_cache_control<1, Cached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1S_L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, InvalidateAfterRead, 0>, #triton_gen.load_cache_control<1, Cached, 0>> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=L1IAR_L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-NOT: triton_gen.DecorationCacheControlINTEL + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index bc6afece9f..0a8f761abe 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -216,12 +216,60 @@ static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) { op.getTileWidth() == 32 && op.getVBlocks() == 1) return false; - if (op.getCacheControl() != TritonGEN::LoadCacheControl::DEFAULT) - return false; - return true; } +static SmallVector +loadCacheControlToDecoration(Builder &builder, uint32_t operandNum, + TritonGEN::LoadCacheControl orig) { + const auto build = [&builder, + operandNum](TritonGEN::LoadCacheControlDecorationEnum l1, + TritonGEN::LoadCacheControlDecorationEnum l3) + -> SmallVector { + return {builder.getAttr( + 0, l1, operandNum), + builder.getAttr( + 1, l3, operandNum)}; + }; + switch (orig) { + case TritonGEN::LoadCacheControl::DEFAULT: + return {}; + case TritonGEN::LoadCacheControl::L1UC_L3UC: + return build(TritonGEN::LoadCacheControlDecorationEnum::Uncached, + TritonGEN::LoadCacheControlDecorationEnum::Uncached); + case TritonGEN::LoadCacheControl::L1UC_L3C: + return build(TritonGEN::LoadCacheControlDecorationEnum::Uncached, + TritonGEN::LoadCacheControlDecorationEnum::Cached); + case TritonGEN::LoadCacheControl::L1C_L3UC: + return build(TritonGEN::LoadCacheControlDecorationEnum::Cached, + TritonGEN::LoadCacheControlDecorationEnum::Uncached); + case TritonGEN::LoadCacheControl::L1C_L3C: + return build(TritonGEN::LoadCacheControlDecorationEnum::Cached, + TritonGEN::LoadCacheControlDecorationEnum::Cached); + case TritonGEN::LoadCacheControl::L1S_L3UC: + return build(TritonGEN::LoadCacheControlDecorationEnum::Streaming, + TritonGEN::LoadCacheControlDecorationEnum::Uncached); + case TritonGEN::LoadCacheControl::L1S_L3C: + return build(TritonGEN::LoadCacheControlDecorationEnum::Streaming, + TritonGEN::LoadCacheControlDecorationEnum::Cached); + case TritonGEN::LoadCacheControl::L1IAR_L3C: + return build(TritonGEN::LoadCacheControlDecorationEnum::InvalidateAfterRead, + TritonGEN::LoadCacheControlDecorationEnum::Cached); + } + llvm_unreachable("Unhandled case"); +} + +static std::optional +loadCacheControlToCacheControls(Builder &builder, + TritonGEN::LoadCacheControl orig, + uint32_t operandNum) { + SmallVector decorations = + loadCacheControlToDecoration(builder, operandNum, orig); + if (decorations.empty()) + return {}; + return builder.getAttr(decorations); +} + static Value createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, ConversionPatternRewriter &rewriter) { MLIRContext *context = rewriter.getContext(); @@ -269,8 +317,15 @@ static Value createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, paramAttrs[5] = param5AttrBuilder.getAttributes(); intel::AttributeList attrs = getAttrList(funcAttrBuilder, paramAttrs); - createDeviceFunctionCall(rewriter, fnName, void_ty(context), argTypes, args, - attrs); + LLVM::CallOp call = createDeviceFunctionCall( + rewriter, fnName, void_ty(context), argTypes, args, attrs); + constexpr uint32_t ptrOperandIndex = 0; + if (std::optional optCacheControls = + loadCacheControlToCacheControls(rewriter, op.getCacheControl(), + ptrOperandIndex)) { + call->setAttr(TritonGEN::TritonGENDialect::getCacheControlsAttrName(), + *optCacheControls); + } return rewriter.create(loc, resType, dest); }