Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::YieldOp::create(*this, loc, value);
}

cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
mlir::Value stride) {
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride);
cir::PtrStrideOp
createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride,
std::optional<CIR_GEPNoWrapFlags> flags = std::nullopt) {
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride,
flags.value_or(CIR_GEPNoWrapFlags::none));
}

cir::CallOp createCallOp(mlir::Location loc,
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIREnumAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
include "mlir/IR/EnumAttr.td"
include "clang/CIR/Dialect/IR/CIRDialect.td"

class CIR_I32BitEnum<string name, string summary, list<BitEnumCaseBase> cases>
: I32BitEnum<name, summary, cases> {
let cppNamespace = "::cir";
}

class CIR_I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, summary, cases> {
let cppNamespace = "::cir";
Expand Down
32 changes: 27 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,24 @@ def CIR_PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
//===----------------------------------------------------------------------===//
// PtrStrideOp
//===----------------------------------------------------------------------===//
def CIR_GEPNone : I32BitEnumCaseNone<"none">;
def CIR_GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
def CIR_GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
def CIR_GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
def CIR_GEPInbounds
: BitEnumCaseGroup<"inbounds", [CIR_GEPInboundsFlag, CIR_GEPNusw]>;

def CIR_GEPNoWrapFlags
: CIR_I32BitEnum<"CIR_GEPNoWrapFlags", "::cir::CIR_GEPNoWrapFlags",
[CIR_GEPNone, CIR_GEPInboundsFlag, CIR_GEPNusw, CIR_GEPNuw,
CIR_GEPInbounds]> {
let cppNamespace = "::cir";
let printBitEnumPrimaryGroups = 1;
}

def CIR_GEPNoWrapFlagsProp : EnumProp<CIR_GEPNoWrapFlags> {
let defaultValue = interfaceType#"::none";
}

def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[
Pure, AllTypesMatch<["base", "result"]>
Expand All @@ -397,19 +415,23 @@ def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[

```mlir
%3 = cir.const 0 : i32

%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32>

%5 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds), !cir.ptr<i32>

%6 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds|nuw), !cir.ptr<i32>

```
}];

let arguments = (ins
CIR_PointerType:$base,
CIR_AnyFundamentalIntType:$stride
);
let arguments = (ins CIR_PointerType:$base, CIR_AnyFundamentalIntType:$stride,
CIR_GEPNoWrapFlagsProp:$noWrapFlags);

let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
$base`,` $stride `:` functional-type(operands, results) attr-dict
($noWrapFlags^)? $base`,` $stride `:` functional-type(operands, results) attr-dict
}];

let extraClassDeclaration = [{
Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2842,10 +2842,16 @@ mlir::Value CIRGenFunction::emitCheckedInBoundsGEP(
assert(IdxList.size() == 1 && "multi-index ptr arithmetic NYI");
mlir::Value GEPVal =
builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr, IdxList[0]);

// If the pointer overflow sanitizer isn't enabled, do nothing.
if (!SanOpts.has(SanitizerKind::PointerOverflow))
return GEPVal;
if (!SanOpts.has(SanitizerKind::PointerOverflow)) {
cir::CIR_GEPNoWrapFlags nwFlags = cir::CIR_GEPNoWrapFlags::inbounds;
if (!SignedIndices && !IsSubtraction)
nwFlags = nwFlags | cir::CIR_GEPNoWrapFlags::nuw;
return builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr,
IdxList[0], nwFlags);
}

return GEPVal;

// TODO(cir): the unreachable code below hides a substantial amount of code
// from the original codegen related with pointer overflow sanitizer.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
Expand Down
22 changes: 20 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@ void walkRegionSkipping(mlir::Region &region,
});
}

/// Convert from a CIR PtrStrideOp kind to an LLVM IR equivalent of GEP.
mlir::LLVM::GEPNoWrapFlags
convertPtrStrideKindToGEPFlags(cir::CIR_GEPNoWrapFlags flags) {
using CIRFlags = cir::CIR_GEPNoWrapFlags;
using LLVMFlags = mlir::LLVM::GEPNoWrapFlags;

LLVMFlags x = LLVMFlags::none;
if ((flags & CIRFlags::inboundsFlag) == CIRFlags::inboundsFlag)
x = x | LLVMFlags::inboundsFlag;
if ((flags & CIRFlags::nusw) == CIRFlags::nusw)
x = x | LLVMFlags::nusw;
if ((flags & CIRFlags::inbounds) == CIRFlags::inbounds)
x = x | LLVMFlags::inbounds;
if ((flags & CIRFlags::nuw) == CIRFlags::nuw)
x = x | LLVMFlags::nuw;
return x;
}

/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
mlir::LLVM::ICmpPredicate convertCmpKindToICmpPredicate(cir::CmpOpKind kind,
bool isSigned) {
Expand Down Expand Up @@ -1023,9 +1041,9 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index,
convertPtrStrideKindToGEPFlags(adaptor.getNoWrapFlags()));
return mlir::success();
}

Expand Down
21 changes: 11 additions & 10 deletions clang/test/CIR/CodeGen/pointer-arith-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ void *f4(void *a, int b) { return a - b; }
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[SUB]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[SUB]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>

// LLVM-LABEL: f4
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: %[[SUB:.*]] = sub i64 0, %[[STRIDE]]
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[SUB]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[SUB]]

// Similar to f4, just make sure it does not crash.
void *f4_1(void *a, int b) { return (a -= b); }
Expand All @@ -52,13 +52,13 @@ FP f5(FP a, int b) { return a + b; }
// CIR-LABEL: f5
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[STRIDE]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[STRIDE]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>

// LLVM-LABEL: f5
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[STRIDE]]

// These test the same paths above, just make sure it does not crash.
FP f5_1(FP a, int b) { return (a += b); }
Expand All @@ -70,14 +70,14 @@ FP f7(FP a, int b) { return a - b; }
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[SUB]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[SUB]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>

// LLVM-LABEL: f7
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: %[[SUB:.*]] = sub i64 0, %[[STRIDE]]
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[SUB]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[SUB]]

// Similar to f7, just make sure it does not crash.
FP f7_1(FP a, int b) { return (a -= b); }
Expand All @@ -87,14 +87,14 @@ void f8(void *a, int b) { return *(id(a + b)); }
// CIR-LABEL: f8
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[STRIDE]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[STRIDE]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>
// CIR: cir.return

// LLVM-LABEL: f8
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: ret void

void f8_1(void *a, int b) { return a[b]; }
Expand All @@ -119,7 +119,8 @@ unsigned char *p(unsigned int x) {

// CIR-LABEL: @p
// CIR: %[[SUB:.*]] = cir.binop(sub
// CIR: cir.ptr_stride {{.*}}, %[[SUB]] : (!cir.ptr<!u8i>, !u32i) -> !cir.ptr<!u8i>
// CIR: cir.ptr_stride inbounds|nuw {{.*}}, %[[SUB]] : (!cir.ptr<!u8i>, !u32i) -> !cir.ptr<!u8i>

// LLVM-LABEL: @p
// LLVM: getelementptr i8, ptr {{.*}}
// LLVM: getelementptr inbounds nuw i8, ptr {{.*}}

12 changes: 6 additions & 6 deletions clang/test/CIR/CodeGen/pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@
void foo(int *iptr, char *cptr, unsigned ustride) {
*(iptr + 2) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
*(cptr + 3) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
*(iptr - 2) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
*(cptr - 3) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
*(iptr + ustride) = 1;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !u32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds|nuw %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !u32i) -> !cir.ptr<!s32i>

// Must convert unsigned stride to a signed one.
*(iptr - ustride) = 1;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: %[[#SIGNSTRIDE:]] = cir.cast(integral, %[[#STRIDE]] : !u32i), !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#SIGNSTRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
}

void testPointerSubscriptAccess(int *ptr) {
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CIR/IR/ptr_stride.cir
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ module {
%4 = cir.ptr_stride %2, %3 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
cir.return
}

cir.func @gepflags(%arg0: !cir.ptr<!s32i>, %arg1: !s32i) {
%0 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%1 = cir.ptr_stride nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%2 = cir.ptr_stride inbounds|nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%3 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
cir.return
}
}

// CHECK: cir.func @arraysubscript(%arg0: !s32i) {
Expand All @@ -20,3 +28,12 @@ module {
// CHECK-NEXT: %4 = cir.ptr_stride %2, %3 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: cir.return
// CHECK-NEXT: }


// CHECK: cir.func @gepflags(%arg0: !cir.ptr<!s32i>, %arg1: !s32i) {
// CHECK-NEXT: %0 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: %1 = cir.ptr_stride nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: %2 = cir.ptr_stride inbounds|nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: %3 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: cir.return
// CHECK-NEXT: }
Loading