Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::YieldOp::create(*this, loc, value);
}

cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
mlir::Value stride) {
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride);
cir::PtrStrideOp
createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride,
std::optional<GEPNoWrapFlags> flags = std::nullopt) {
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride,
flags.value_or(GEPNoWrapFlags::none));
}

cir::CallOp createCallOp(mlir::Location loc,
Expand Down
26 changes: 21 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,22 @@ def CIR_PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
//===----------------------------------------------------------------------===//
// PtrStrideOp
//===----------------------------------------------------------------------===//
def GEPNone : I32BitEnumCaseNone<"none">;
def GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
def GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
def GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
def GEPInbounds : BitEnumCaseGroup<"inbounds", [GEPInboundsFlag, GEPNusw]>;

def GEPNoWrapFlags
: I32BitEnum<"GEPNoWrapFlags", "::cir::GEPNoWrapFlags",
[GEPNone, GEPInboundsFlag, GEPNusw, GEPNuw, GEPInbounds]> {
let cppNamespace = "::cir";
let printBitEnumPrimaryGroups = 1;
}

def GEPNoWrapFlagsProp : EnumProp<GEPNoWrapFlags> {
let defaultValue = interfaceType#"::none";
}

def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[
Pure, AllTypesMatch<["base", "result"]>
Expand All @@ -397,19 +413,19 @@ def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[

```mlir
%3 = cir.const 0 : i32
%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32>

%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32>
```
}];

let arguments = (ins
CIR_PointerType:$base,
CIR_AnyFundamentalIntType:$stride
);
let arguments = (ins CIR_PointerType:$base, CIR_AnyFundamentalIntType:$stride,
GEPNoWrapFlagsProp:$noWrapFlags);

let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
`(` $base `:` qualified(type($base)) `,` $stride `:` qualified(type($stride)) `)`
`(` $base `:` qualified(type($base)) `,` $stride `:` qualified(type($stride))(`,` $noWrapFlags^)?`)`
`,` qualified(type($result)) attr-dict
}];

Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2835,10 +2835,16 @@ mlir::Value CIRGenFunction::emitCheckedInBoundsGEP(
assert(IdxList.size() == 1 && "multi-index ptr arithmetic NYI");
mlir::Value GEPVal =
builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr, IdxList[0]);

// If the pointer overflow sanitizer isn't enabled, do nothing.
if (!SanOpts.has(SanitizerKind::PointerOverflow))
return GEPVal;
if (!SanOpts.has(SanitizerKind::PointerOverflow)) {
cir::GEPNoWrapFlags nwFlags = cir::GEPNoWrapFlags::inbounds;
if (!SignedIndices && !IsSubtraction)
nwFlags = nwFlags | cir::GEPNoWrapFlags::nuw;
return builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr,
IdxList[0], nwFlags);
}

return GEPVal;

// TODO(cir): the unreachable code below hides a substantial amount of code
// from the original codegen related with pointer overflow sanitizer.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,9 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index,
static_cast<mlir::LLVM::GEPNoWrapFlags>(adaptor.getNoWrapFlags()));
return mlir::success();
}

Expand Down
21 changes: 11 additions & 10 deletions clang/test/CIR/CodeGen/pointer-arith-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ void *f4(void *a, int b) { return a - b; }
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!void>, %[[SUB]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!void>, %[[SUB]] : !s32i, inbounds)

// LLVM-LABEL: f4
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: %[[SUB:.*]] = sub i64 0, %[[STRIDE]]
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[SUB]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[SUB]]

// Similar to f4, just make sure it does not crash.
void *f4_1(void *a, int b) { return (a -= b); }
Expand All @@ -52,13 +52,13 @@ FP f5(FP a, int b) { return a + b; }
// CIR-LABEL: f5
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[STRIDE]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[STRIDE]] : !s32i, inbounds)

// LLVM-LABEL: f5
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[STRIDE]]

// These test the same paths above, just make sure it does not crash.
FP f5_1(FP a, int b) { return (a += b); }
Expand All @@ -70,14 +70,14 @@ FP f7(FP a, int b) { return a - b; }
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[SUB]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[SUB]] : !s32i, inbounds)

// LLVM-LABEL: f7
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: %[[SUB:.*]] = sub i64 0, %[[STRIDE]]
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[SUB]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[SUB]]

// Similar to f7, just make sure it does not crash.
FP f7_1(FP a, int b) { return (a -= b); }
Expand All @@ -87,14 +87,14 @@ void f8(void *a, int b) { return *(id(a + b)); }
// CIR-LABEL: f8
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!void>, %[[STRIDE]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!void>, %[[STRIDE]] : !s32i, inbounds)
// CIR: cir.return

// LLVM-LABEL: f8
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: ret void

void f8_1(void *a, int b) { return a[b]; }
Expand All @@ -119,7 +119,8 @@ unsigned char *p(unsigned int x) {

// CIR-LABEL: @p
// CIR: %[[SUB:.*]] = cir.binop(sub
// CIR: cir.ptr_stride({{.*}} : !cir.ptr<!u8i>, %[[SUB]] : !u32i), !cir.ptr<!u8i>
// CIR: cir.ptr_stride({{.*}} : !cir.ptr<!u8i>, %[[SUB]] : !u32i, inbounds|nuw), !cir.ptr<!u8i>

// LLVM-LABEL: @p
// LLVM: getelementptr i8, ptr {{.*}}
// LLVM: getelementptr inbounds nuw i8, ptr {{.*}}

71 changes: 43 additions & 28 deletions clang/test/CIR/CodeGen/pointers.cpp
Original file line number Diff line number Diff line change
@@ -1,49 +1,64 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - 2>&1 | FileCheck %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o - 2>&1 | FileCheck %s --check-prefix=OGCG

// Should generate basic pointer arithmetics.
void foo(int *iptr, char *cptr, unsigned ustride) {
*(iptr + 2) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !s32i), !cir.ptr<!s32i>
// CIR: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !s32i, inbounds), !cir.ptr<!s32i>
// LLVM: getelementptr inbounds
// OGCG: getelementptr inbounds
*(cptr + 3) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#STRIDE]] : !s32i), !cir.ptr<!s8i>
// CIR: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#STRIDE]] : !s32i, inbounds), !cir.ptr<!s8i>
// LLVM: getelementptr inbounds
// OGCG: getelementptr inbounds
*(iptr - 2) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i), !cir.ptr<!s32i>
// CIR: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CIR: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i, inbounds), !cir.ptr<!s32i>
// LLVM: getelementptr inbounds
// OGCG: getelementptr inbounds
*(cptr - 3) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#NEGSTRIDE]] : !s32i), !cir.ptr<!s8i>
// CIR: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CIR: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s8i>, %[[#NEGSTRIDE]] : !s32i, inbounds), !cir.ptr<!s8i>
// LLVM: getelementptr inbounds
// OGCG: getelementptr inbounds
*(iptr + ustride) = 1;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !u32i), !cir.ptr<!s32i>
// CIR: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#STRIDE]] : !u32i, inbounds|nuw), !cir.ptr<!s32i>

// LLVM: getelementptr inbounds nuw
// OGCG: getelementptr inbounds nuw

// Must convert unsigned stride to a signed one.
*(iptr - ustride) = 1;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: %[[#SIGNSTRIDE:]] = cir.cast(integral, %[[#STRIDE]] : !u32i), !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#SIGNSTRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i), !cir.ptr<!s32i>
// CIR: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CIR: %[[#SIGNSTRIDE:]] = cir.cast(integral, %[[#STRIDE]] : !u32i), !s32i
// CIR: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#SIGNSTRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#NEGSTRIDE]] : !s32i, inbounds), !cir.ptr<!s32i>
// LLVM: getelementptr inbounds
// OGCG: getelementptr inbounds
}

void testPointerSubscriptAccess(int *ptr) {
// CHECK: testPointerSubscriptAccess
// CIR: testPointerSubscriptAccess
ptr[1] = 2;
// CHECK: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CHECK: %[[#V2:]] = cir.const #cir.int<1> : !s32i
// CHECK: cir.ptr_stride(%[[#V1]] : !cir.ptr<!s32i>, %[[#V2]] : !s32i), !cir.ptr<!s32i>
// CIR: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CIR: %[[#V2:]] = cir.const #cir.int<1> : !s32i
// CIR: cir.ptr_stride(%[[#V1]] : !cir.ptr<!s32i>, %[[#V2]] : !s32i), !cir.ptr<!s32i>
}

void testPointerMultiDimSubscriptAccess(int **ptr) {
// CHECK: testPointerMultiDimSubscriptAccess
// CIR: testPointerMultiDimSubscriptAccess
ptr[1][2] = 3;
// CHECK: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!cir.ptr<!s32i>>>, !cir.ptr<!cir.ptr<!s32i>>
// CHECK: %[[#V2:]] = cir.const #cir.int<1> : !s32i
// CHECK: %[[#V3:]] = cir.ptr_stride(%[[#V1]] : !cir.ptr<!cir.ptr<!s32i>>, %[[#V2]] : !s32i), !cir.ptr<!cir.ptr<!s32i>>
// CHECK: %[[#V4:]] = cir.load{{.*}} %[[#V3]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CHECK: %[[#V5:]] = cir.const #cir.int<2> : !s32i
// CHECK: cir.ptr_stride(%[[#V4]] : !cir.ptr<!s32i>, %[[#V5]] : !s32i), !cir.ptr<!s32i>
// CIR: %[[#V1:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!cir.ptr<!s32i>>>, !cir.ptr<!cir.ptr<!s32i>>
// CIR: %[[#V2:]] = cir.const #cir.int<1> : !s32i
// CIR: %[[#V3:]] = cir.ptr_stride(%[[#V1]] : !cir.ptr<!cir.ptr<!s32i>>, %[[#V2]] : !s32i), !cir.ptr<!cir.ptr<!s32i>>
// CIR: %[[#V4:]] = cir.load{{.*}} %[[#V3]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CIR: %[[#V5:]] = cir.const #cir.int<2> : !s32i
// CIR: cir.ptr_stride(%[[#V4]] : !cir.ptr<!s32i>, %[[#V5]] : !s32i), !cir.ptr<!s32i>
}
17 changes: 17 additions & 0 deletions clang/test/CIR/IR/ptr_stride.cir
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ module {
%4 = cir.ptr_stride(%2 : !cir.ptr<!s32i>, %3 : !s32i), !cir.ptr<!s32i>
cir.return
}

cir.func @gepflags(%arg0: !cir.ptr<!s32i>, %arg1: !s32i) {
%0 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i), !cir.ptr<!s32i>
%1 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i, nuw), !cir.ptr<!s32i>
%2 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i, inbounds|nuw), !cir.ptr<!s32i>
%3 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i, none), !cir.ptr<!s32i>
cir.return
}
}

// CHECK: cir.func @arraysubscript(%arg0: !s32i) {
Expand All @@ -20,3 +28,12 @@ module {
// CHECK-NEXT: %4 = cir.ptr_stride(%2 : !cir.ptr<!s32i>, %3 : !s32i), !cir.ptr<!s32i>
// CHECK-NEXT: cir.return
// CHECK-NEXT: }


// CHECK: cir.func @gepflags(%arg0: !cir.ptr<!s32i>, %arg1: !s32i) {
// CHECK-NEXT: %0 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i), !cir.ptr<!s32i>
// CHECK-NEXT: %1 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i, nuw), !cir.ptr<!s32i>
// CHECK-NEXT: %2 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i, inbounds|nuw), !cir.ptr<!s32i>
// CHECK-NEXT: %3 = cir.ptr_stride(%arg0 : !cir.ptr<!s32i>, %arg1 : !s32i), !cir.ptr<!s32i>
// CHECK-NEXT: cir.return
// CHECK-NEXT: }
Loading