@@ -48,8 +48,7 @@ template <typename SYCLType> static bool isMemRefOf(const Type &type) {
4848}
4949
5050// Returns the element type of 'memref<?xSYCLType>'.
51- template <typename SYCLType>
52- static SYCLType getElementType (const Type &type) {
51+ template <typename SYCLType> static SYCLType getElementType (const Type &type) {
5352 assert (isMemRefOf<SYCLType>(type) && " Expecting memref<?xsycl::<type>>" );
5453 Type elemType = type.cast <MemRefType>().getElementType ();
5554 return elemType.cast <SYCLType>();
@@ -121,36 +120,74 @@ static Optional<Type> convertRangeType(sycl::RangeType type,
121120 converter);
122121}
123122
123+ // / Create a LLVM struct type with name \p name, and the converted \p body as
124+ // / the body.
125+ static Optional<Type> convertBodyType (StringRef name,
126+ llvm::ArrayRef<mlir::Type> body,
127+ LLVMTypeConverter &converter) {
128+ auto convertedTy =
129+ LLVM::LLVMStructType::getIdentified (&converter.getContext (), name);
130+ if (!convertedTy.isInitialized ()) {
131+ SmallVector<Type> convertedElemTypes;
132+ convertedElemTypes.reserve (body.size ());
133+ if (failed (converter.convertTypes (body, convertedElemTypes)))
134+ return llvm::None;
135+ if (failed (convertedTy.setBody (convertedElemTypes, /* isPacked=*/ false )))
136+ return llvm::None;
137+ }
138+
139+ return convertedTy;
140+ }
141+
124142// / Converts SYCL accessor implement device type to LLVM type.
125143static Optional<Type>
126144convertAccessorImplDeviceType (sycl::AccessorImplDeviceType type,
127145 LLVMTypeConverter &converter) {
128- SmallVector<Type> convertedElemTypes;
129- convertedElemTypes.reserve (type.getBody ().size ());
130- if (failed (converter.convertTypes (type.getBody (), convertedElemTypes)))
131- return llvm::None;
132-
133- return LLVM::LLVMStructType::getNewIdentified (
134- &converter.getContext (), " class.cl::sycl::detail::AccessorImplDevice" ,
135- convertedElemTypes, /* isPacked=*/ false );
146+ return convertBodyType (" class.cl::sycl::detail::AccessorImplDevice" +
147+ std::to_string (type.getDimension ()),
148+ type.getBody (), converter);
136149}
137150
138151// / Converts SYCL accessor type to LLVM type.
139152static Optional<Type> convertAccessorType (sycl::AccessorType type,
140153 LLVMTypeConverter &converter) {
141- SmallVector<Type> convertedElemTypes;
142- convertedElemTypes.reserve (type.getBody ().size ());
143- if (failed (converter.convertTypes (type.getBody (), convertedElemTypes)))
144- return llvm::None;
154+ auto convertedTy = LLVM::LLVMStructType::getIdentified (
155+ &converter.getContext (),
156+ " class.cl::sycl::accessor" + std::to_string (type.getDimension ()));
157+ if (!convertedTy.isInitialized ()) {
158+ SmallVector<Type> convertedElemTypes;
159+ convertedElemTypes.reserve (type.getBody ().size ());
160+ if (failed (converter.convertTypes (type.getBody (), convertedElemTypes)))
161+ return llvm::None;
162+
163+ auto ptrTy = LLVM::LLVMPointerType::get (type.getType (), /* addressSpace=*/ 1 );
164+ auto structTy =
165+ LLVM::LLVMStructType::getLiteral (&converter.getContext (), ptrTy);
166+ convertedElemTypes.push_back (structTy);
167+
168+ if (failed (convertedTy.setBody (convertedElemTypes, /* isPacked=*/ false )))
169+ return llvm::None;
170+ }
171+
172+ return convertedTy;
173+ }
145174
146- auto ptrTy = LLVM::LLVMPointerType::get (type.getType (), /* addressSpace=*/ 1 );
147- auto structTy =
148- LLVM::LLVMStructType::getLiteral (&converter.getContext (), ptrTy);
149- convertedElemTypes.push_back (structTy);
175+ // / Converts SYCL item base type to LLVM type.
176+ static Optional<Type> convertItemBaseType (sycl::ItemBaseType type,
177+ LLVMTypeConverter &converter) {
178+ return convertBodyType (" class.cl::sycl::detail::ItemBase." +
179+ std::to_string (type.getDimension ()) +
180+ (type.getWithOffset () ? " .true" : " .false" ),
181+ type.getBody (), converter);
182+ }
150183
151- return LLVM::LLVMStructType::getNewIdentified (
152- &converter.getContext (), " class.cl::sycl::accessor" , convertedElemTypes,
153- /* isPacked=*/ false );
184+ // / Converts SYCL item type to LLVM type.
185+ static Optional<Type> convertItemType (sycl::ItemType type,
186+ LLVMTypeConverter &converter) {
187+ return convertBodyType (" class.cl::sycl::item." +
188+ std::to_string (type.getDimension ()) +
189+ (type.getWithOffset () ? " .true" : " .false" ),
190+ type.getBody (), converter);
154191}
155192
156193// ===----------------------------------------------------------------------===//
@@ -188,7 +225,7 @@ class ConstructorPattern final
188225 MLIRContext *context = module .getContext ();
189226
190227 // Lookup the ctor function to use.
191- const auto ®istry = SYCLFuncRegistry::create (module , rewriter);
228+ const auto ®istry = SYCLFuncRegistry::create (module , rewriter);
192229 auto voidTy = LLVM::LLVMVoidType::get (context);
193230 SYCLFuncDescriptor::FuncId funcId =
194231 registry.getFuncId (SYCLFuncDescriptor::FuncIdKind::IdCtor, voidTy,
@@ -235,12 +272,10 @@ void mlir::sycl::populateSYCLToLLVMTypeConversion(
235272 typeConverter.addConversion (
236273 [&](sycl::IDType type) { return convertIDType (type, typeConverter); });
237274 typeConverter.addConversion ([&](sycl::ItemBaseType type) {
238- llvm_unreachable (" SYCLToLLVM - sycl::ItemBaseType not handle (yet)" );
239- return llvm::None;
275+ return convertItemBaseType (type, typeConverter);
240276 });
241277 typeConverter.addConversion ([&](sycl::ItemType type) {
242- llvm_unreachable (" SYCLToLLVM - sycl::ItemType not handle (yet)" );
243- return llvm::None;
278+ return convertItemType (type, typeConverter);
244279 });
245280 typeConverter.addConversion ([&](sycl::NdItemType type) {
246281 llvm_unreachable (" SYCLToLLVM - sycl::NdItemType not handle (yet)" );
0 commit comments