@@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
587587 LOG (FATAL) << " do not support " << dtype;
588588 }
589589 }
590- if (dtype.lanes () != 1 ) {
590+ if (! dtype.is_scalar () ) {
591591#if TVM_LLVM_VERSION >= 110
592- return llvm::FixedVectorType::get (etype, dtype.lanes ());
592+ if (dtype.is_scalable_vector ()) {
593+ return llvm::VectorType::get (etype, dtype.vscale_factor (), true );
594+ } else {
595+ return llvm::FixedVectorType::get (etype, dtype.lanes ());
596+ }
593597#else
598+ ICHECK (!dtype.is_scalable_vector ())
599+ << " Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
600+ " version." ;
594601 return llvm::VectorType::get (etype, dtype.lanes ());
595602#endif
596603 } else {
@@ -749,26 +756,6 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
749756 return debug_info;
750757}
751758
752- llvm::Value* CodeGenLLVM::CreateBroadcast (llvm::Value* value, int lanes) {
753- #if TVM_LLVM_VERSION >= 110
754- llvm::Type* type = llvm::FixedVectorType::get (value->getType (), lanes);
755- #else
756- llvm::Type* type = llvm::VectorType::get (value->getType (), lanes);
757- #endif
758- llvm::Constant* undef = llvm::UndefValue::get (type);
759- llvm::Constant* zero = ConstInt32 (0 );
760- value = builder_->CreateInsertElement (undef, value, zero);
761- #if TVM_LLVM_VERSION >= 120
762- llvm::Constant* mask = llvm::ConstantVector::getSplat (llvm::ElementCount::getFixed (lanes), zero);
763- #elif TVM_LLVM_VERSION >= 110
764- llvm::Constant* mask =
765- llvm::ConstantVector::getSplat (llvm::ElementCount (lanes, /* Scalable=*/ false ), zero);
766- #else
767- llvm::Constant* mask = llvm::ConstantVector::getSplat (lanes, zero);
768- #endif
769- return builder_->CreateShuffleVector (value, undef, mask);
770- }
771-
772759llvm::Value* CodeGenLLVM::CreateVecSlice (llvm::Value* vec, int begin, int extent) {
773760 int num_elems = GetVectorNumElements (vec);
774761 if (extent == num_elems && begin == 0 ) return vec;
@@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper(
16931680 }
16941681
16951682 PrimExpr last_index = indices[indices.size () - 1 ];
1696- ICHECK_EQ (value_dtype.lanes (), last_index.dtype ().lanes () * buffer_element_dtype.lanes ());
1683+ ICHECK_EQ (value_dtype.get_lanes_or_vscale_factor (),
1684+ last_index.dtype ().get_lanes_or_vscale_factor () * buffer_element_dtype.lanes ());
16971685
16981686 // Record index and elemtype in original form used for alias info
16991687 PrimExpr last_index_origin = last_index;
@@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper(
17361724 llvm::Value* last_index_value;
17371725 int subelement_i = i;
17381726 if (const RampNode* ramp = last_index.as <RampNode>()) {
1739- // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
1740- ICHECK (!last_index.dtype ().is_scalable_vector ());
17411727 PrimExpr offset = ramp->base + (ramp->stride * i);
17421728 last_index_value = MakeValue (offset);
17431729 } else if (last_index.dtype ().lanes () > 1 ) {
@@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper(
17541740 all_index_values.push_back (last_index_value);
17551741
17561742 TypedPointer buffer_ptr =
1757- CreateBufferPtr (MakeValue (buffer->data ), buffer_element_dtype, all_index_values,
1758- value_dtype.with_lanes (value_dtype.lanes () / last_index.dtype ().lanes ()));
1743+ value_dtype.is_scalable_vector ()
1744+ ? CreateBufferPtr (MakeValue (buffer->data ), buffer_element_dtype, all_index_values,
1745+ value_dtype.with_scalable_vscale_factor (value_dtype.vscale_factor () /
1746+ last_index.dtype ().lanes ()))
1747+ : CreateBufferPtr (
1748+ MakeValue (buffer->data ), buffer_element_dtype, all_index_values,
1749+ value_dtype.with_lanes (value_dtype.lanes () / last_index.dtype ().lanes ()));
17591750 auto instruction = make_instruction (buffer_ptr, subelement_i, alignment, is_volatile);
17601751 AddAliasInfo (instruction, buffer->data .get (), last_index_origin, buffer_element_dtype_origin);
17611752 }
@@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
18701861}
18711862
18721863llvm::Value* CodeGenLLVM::VisitExpr_ (const BroadcastNode* op) {
1873- // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
1874- ICHECK (!op->dtype .is_scalable_vector ());
1875- int lanes = op->dtype .lanes ();
1876- return CreateBroadcast (MakeValue (op->value ), lanes);
1864+ DataType dtype = op->dtype ;
1865+ llvm::Value* value = MakeValue (op->value );
1866+ llvm::Type* type = DTypeToLLVMType (dtype);
1867+ llvm::Constant* undef = llvm::UndefValue::get (type);
1868+ llvm::Constant* zero = ConstInt32 (0 );
1869+ value = builder_->CreateInsertElement (undef, value, zero);
1870+ #if TVM_LLVM_VERSION >= 110
1871+ llvm::ElementCount ec =
1872+ llvm::ElementCount::get (dtype.get_lanes_or_vscale_factor (), dtype.is_scalable_vector ());
1873+ llvm::Constant* mask = llvm::ConstantVector::getSplat (ec, zero);
1874+ #else
1875+ ICHECK (!dtype.is_scalable_vector ())
1876+ << " Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
1877+ " version." ;
1878+ llvm::Constant* mask = llvm::ConstantVector::getSplat (dtype.lanes (), zero);
1879+ #endif
1880+ return builder_->CreateShuffleVector (value, undef, mask);
18771881}
18781882
18791883void CodeGenLLVM::VisitStmt_ (const BufferStoreNode* op) {
0 commit comments