diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 9b2a7f432a544..5f1d855621c93 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -864,10 +864,8 @@ class TargetTransformInfoImplBase { } virtual InstructionCost - getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, - bool VariableMask, Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I = nullptr) const { + getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { return InstructionCost::getInvalid(); } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 314830652f0b6..fceff5f93b765 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -1599,19 +1599,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { /*IsGatherScatter*/ true, CostKind); } - InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, - const Value *Ptr, bool VariableMask, - Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I) const override { + InstructionCost + getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override { // For a target without strided memory operations (or for an illegal // operation type on one which does), assume we lower to a gather/scatter // operation. (Which may in turn be scalarized.) - unsigned IID = Opcode == Instruction::Load ? Intrinsic::masked_gather - : Intrinsic::masked_scatter; + unsigned IID = MICA.getID() == Intrinsic::experimental_vp_strided_load + ? Intrinsic::masked_gather + : Intrinsic::masked_scatter; return thisT()->getGatherScatterOpCost( - MemIntrinsicCostAttributes(IID, DataTy, Ptr, VariableMask, Alignment, - I), + MemIntrinsicCostAttributes(IID, MICA.getDataType(), MICA.getPointer(), + MICA.getVariableMask(), MICA.getAlignment(), + MICA.getInst()), CostKind); } @@ -3062,21 +3062,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const override { unsigned Id = MICA.getID(); - Type *DataTy = MICA.getDataType(); - const Value *Ptr = MICA.getPointer(); - const Instruction *I = MICA.getInst(); - bool VariableMask = MICA.getVariableMask(); - Align Alignment = MICA.getAlignment(); switch (Id) { case Intrinsic::experimental_vp_strided_load: - case Intrinsic::experimental_vp_strided_store: { - unsigned Opcode = Id == Intrinsic::experimental_vp_strided_load - ? Instruction::Load - : Instruction::Store; - return thisT()->getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); - } + case Intrinsic::experimental_vp_strided_store: + return thisT()->getStridedMemoryOpCost(MICA, CostKind); case Intrinsic::masked_scatter: case Intrinsic::masked_gather: case Intrinsic::vp_scatter: diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 1d431959eaea3..74c2c896a8a88 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1212,14 +1212,20 @@ InstructionCost RISCVTTIImpl::getExpandCompressMemoryOpCost( LT.first * getRISCVInstructionCost(Opcodes, LT.second, CostKind); } -InstructionCost RISCVTTIImpl::getStridedMemoryOpCost( - unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, - Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { - if (((Opcode == Instruction::Load || Opcode == Instruction::Store) && - !isLegalStridedLoadStore(DataTy, Alignment)) || - (Opcode != Instruction::Load && Opcode != Instruction::Store)) - return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask, - Alignment, CostKind, I); +InstructionCost +RISCVTTIImpl::getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const { + + unsigned Opcode = MICA.getID() == Intrinsic::experimental_vp_strided_load + ? Instruction::Load + : Instruction::Store; + + Type *DataTy = MICA.getDataType(); + Align Alignment = MICA.getAlignment(); + const Instruction *I = MICA.getInst(); + + if (!isLegalStridedLoadStore(DataTy, Alignment)) + return BaseT::getStridedMemoryOpCost(MICA, CostKind); if (CostKind == TTI::TCK_CodeSize) return TTI::TCC_Basic; diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index e32b1c553c57a..c1746e6d13166 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -202,11 +202,9 @@ class RISCVTTIImpl final : public BasicTTIImplBase { getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const override; - InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, - const Value *Ptr, bool VariableMask, - Align Alignment, - TTI::TargetCostKind CostKind, - const Instruction *I) const override; + InstructionCost + getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA, + TTI::TargetCostKind CostKind) const override; InstructionCost getCostOfKeepingLiveOverCall(ArrayRef Tys) const override;