@@ -305,7 +305,10 @@ bool IsHLWaveSensitive(Function *F) {
305305 return attrSet.hasAttribute (AttributeSet::FunctionIndex, HLWaveSensitive);
306306}
307307
308- std::string GetHLFullName (HLOpcodeGroup op, unsigned opcode) {
308+ static std::string GetHLFunctionAttributeMangling (const AttributeSet &attribs);
309+
310+ std::string GetHLFullName (HLOpcodeGroup op, unsigned opcode,
311+ const AttributeSet &attribs = AttributeSet()) {
309312 assert (op != HLOpcodeGroup::HLExtIntrinsic && " else table name should be used" );
310313 std::string opName = GetHLOpcodeGroupFullName (op).str () + " ." ;
311314
@@ -321,22 +324,26 @@ std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
321324 case HLOpcodeGroup::HLIntrinsic: {
322325 // intrinsic with same signature will share the funciton now
323326 // The opcode is in arg0.
324- return opName;
327+ return opName + GetHLFunctionAttributeMangling (attribs) ;
325328 }
326329 case HLOpcodeGroup::HLMatLoadStore: {
327330 HLMatLoadStoreOpcode matOp = static_cast <HLMatLoadStoreOpcode>(opcode);
328331 return opName + GetHLOpcodeName (matOp).str ();
329332 }
330333 case HLOpcodeGroup::HLSubscript: {
331334 HLSubscriptOpcode subOp = static_cast <HLSubscriptOpcode>(opcode);
332- return opName + GetHLOpcodeName (subOp).str ();
335+ return opName + GetHLOpcodeName (subOp).str () + " ." +
336+ GetHLFunctionAttributeMangling (attribs);
333337 }
334338 case HLOpcodeGroup::HLCast: {
335339 HLCastOpcode castOp = static_cast <HLCastOpcode>(opcode);
336340 return opName + GetHLOpcodeName (castOp).str ();
337341 }
338- default :
342+ case HLOpcodeGroup::HLCreateHandle:
343+ case HLOpcodeGroup::HLAnnotateHandle:
339344 return opName;
345+ default :
346+ return opName + GetHLFunctionAttributeMangling (attribs);
340347 }
341348}
342349
@@ -417,38 +424,59 @@ HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode) {
417424 }
418425}
419426
420- static void SetHLFunctionAttribute (Function *F, HLOpcodeGroup group,
421- unsigned opcode) {
422- F->addFnAttr (Attribute::NoUnwind);
427+ static AttributeSet
428+ GetHLFunctionAttributes (LLVMContext &C, FunctionType *funcTy,
429+ const AttributeSet &origAttribs,
430+ HLOpcodeGroup group, unsigned opcode) {
431+ // Always add nounwind
432+ AttributeSet attribs =
433+ AttributeSet::get (C, AttributeSet::FunctionIndex,
434+ ArrayRef<Attribute::AttrKind>({Attribute::NoUnwind}));
435+
436+ auto addAttr = [&](Attribute::AttrKind Attr) {
437+ if (!attribs.hasAttribute (AttributeSet::FunctionIndex, Attr))
438+ attribs = attribs.addAttribute (C, AttributeSet::FunctionIndex, Attr);
439+ };
440+ auto copyAttr = [&](Attribute::AttrKind Attr) {
441+ if (origAttribs.hasAttribute (AttributeSet::FunctionIndex, Attr))
442+ addAttr (Attr);
443+ };
444+ auto copyStrAttr = [&](StringRef Kind) {
445+ if (origAttribs.hasAttribute (AttributeSet::FunctionIndex, Kind))
446+ attribs = attribs.addAttribute (
447+ C, AttributeSet::FunctionIndex, Kind,
448+ origAttribs.getAttribute (AttributeSet::FunctionIndex, Kind)
449+ .getValueAsString ());
450+ };
451+
452+ // Copy attributes we preserve from the original function.
453+ copyAttr (Attribute::ReadOnly);
454+ copyAttr (Attribute::ReadNone);
455+ copyStrAttr (HLWaveSensitive);
423456
424457 switch (group) {
425458 case HLOpcodeGroup::HLUnOp:
426459 case HLOpcodeGroup::HLBinOp:
427460 case HLOpcodeGroup::HLCast:
428461 case HLOpcodeGroup::HLSubscript:
429- if (!F->hasFnAttribute (Attribute::ReadNone)) {
430- F->addFnAttr (Attribute::ReadNone);
431- }
462+ addAttr (Attribute::ReadNone);
432463 break ;
433464 case HLOpcodeGroup::HLInit:
434- if (!F->hasFnAttribute (Attribute::ReadNone))
435- if (!F->getReturnType ()->isVoidTy ()) {
436- F->addFnAttr (Attribute::ReadNone);
437- }
465+ if (!funcTy->getReturnType ()->isVoidTy ()) {
466+ addAttr (Attribute::ReadNone);
467+ }
438468 break ;
439469 case HLOpcodeGroup::HLMatLoadStore: {
440470 HLMatLoadStoreOpcode matOp = static_cast <HLMatLoadStoreOpcode>(opcode);
441471 if (matOp == HLMatLoadStoreOpcode::ColMatLoad ||
442472 matOp == HLMatLoadStoreOpcode::RowMatLoad)
443- if (!F->hasFnAttribute (Attribute::ReadOnly)) {
444- F->addFnAttr (Attribute::ReadOnly);
445- }
473+ addAttr (Attribute::ReadOnly);
446474 } break ;
447475 case HLOpcodeGroup::HLCreateHandle: {
448- F-> addFnAttr (Attribute::ReadNone);
476+ addAttr (Attribute::ReadNone);
449477 } break ;
450478 case HLOpcodeGroup::HLAnnotateHandle: {
451- F-> addFnAttr (Attribute::ReadNone);
479+ addAttr (Attribute::ReadNone);
452480 } break ;
453481 case HLOpcodeGroup::HLIntrinsic: {
454482 IntrinsicOp intrinsicOp = static_cast <IntrinsicOp>(opcode);
@@ -461,7 +489,7 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
461489 case IntrinsicOp::IOP_GroupMemoryBarrier:
462490 case IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync:
463491 case IntrinsicOp::IOP_AllMemoryBarrier:
464- F-> addFnAttr (Attribute::NoDuplicate);
492+ addAttr (Attribute::NoDuplicate);
465493 break ;
466494 }
467495 } break ;
@@ -472,6 +500,75 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
472500 // No default attributes for these opcodes.
473501 break ;
474502 }
503+ assert (!(attribs.hasAttribute (AttributeSet::FunctionIndex,
504+ Attribute::ReadNone) &&
505+ attribs.hasAttribute (AttributeSet::FunctionIndex,
506+ Attribute::ReadOnly)) &&
507+ " conflicting ReadNone and ReadOnly attributes" );
508+ return attribs;
509+ }
510+
511+ static std::string GetHLFunctionAttributeMangling (const AttributeSet &attribs) {
512+ std::string mangledName;
513+ raw_string_ostream mangledNameStr (mangledName);
514+
515+ // Capture for adding in canonical order later.
516+ bool ReadNone = false ;
517+ bool ReadOnly = false ;
518+ bool NoDuplicate = false ;
519+ bool WaveSensitive = false ;
520+
521+ // Ensure every function attribute is recognized.
522+ for (unsigned Slot = 0 ; Slot < attribs.getNumSlots (); Slot++) {
523+ if (attribs.getSlotIndex (Slot) == AttributeSet::FunctionIndex) {
524+ for (auto it = attribs.begin (Slot), e = attribs.end (Slot); it != e;
525+ it++) {
526+ if (it->isEnumAttribute ()) {
527+ switch (it->getKindAsEnum ()) {
528+ case Attribute::ReadNone:
529+ ReadNone = true ;
530+ break ;
531+ case Attribute::ReadOnly:
532+ ReadOnly = true ;
533+ break ;
534+ case Attribute::NoDuplicate:
535+ NoDuplicate = true ;
536+ break ;
537+ case Attribute::NoUnwind:
538+ // All intrinsics have this attribute, so mangling is unaffected.
539+ break ;
540+ default :
541+ assert (false && " unexpected attribute for HLOperation" );
542+ }
543+ } else if (it->isStringAttribute ()) {
544+ StringRef Kind = it->getKindAsString ();
545+ if (Kind == HLWaveSensitive) {
546+ assert (it->getValueAsString () == " y" &&
547+ " otherwise, unexpected value for WaveSensitive attribute" );
548+ WaveSensitive = true ;
549+ } else {
550+ assert (false &&
551+ " unexpected string function attribute for HLOperation" );
552+ }
553+ }
554+ }
555+ }
556+ }
557+
558+ // Validate attribute combinations.
559+ assert (!(ReadNone && ReadOnly) &&
560+ " ReadNone and ReadOnly are mutually exclusive" );
561+
562+ // Add mangling in canonical order
563+ if (NoDuplicate)
564+ mangledNameStr << " nd" ;
565+ if (ReadNone)
566+ mangledNameStr << " rn" ;
567+ if (ReadOnly)
568+ mangledNameStr << " ro" ;
569+ if (WaveSensitive)
570+ mangledNameStr << " wave" ;
571+ return mangledName;
475572}
476573
477574
@@ -497,7 +594,11 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
497594Function *GetOrCreateHLFunction (Module &M, FunctionType *funcTy,
498595 HLOpcodeGroup group, StringRef *groupName,
499596 StringRef *fnName, unsigned opcode,
500- const AttributeSet &attribs) {
597+ const AttributeSet &origAttribs) {
598+ // Set/transfer all common attributes
599+ AttributeSet attribs = GetHLFunctionAttributes (
600+ M.getContext (), funcTy, origAttribs, group, opcode);
601+
501602 std::string mangledName;
502603 raw_string_ostream mangledNameStr (mangledName);
503604 if (group == HLOpcodeGroup::HLExtIntrinsic) {
@@ -506,33 +607,31 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
506607 mangledNameStr << *groupName;
507608 mangledNameStr << ' .' ;
508609 mangledNameStr << *fnName;
610+ attribs = attribs.addAttribute (M.getContext (), AttributeSet::FunctionIndex,
611+ hlsl::HLPrefix, *groupName);
509612 }
510613 else {
511- mangledNameStr << GetHLFullName (group, opcode);
512- // Need to add wave sensitivity to name to prevent clashes with non-wave intrinsic
513- if (attribs.hasAttribute (AttributeSet::FunctionIndex, HLWaveSensitive))
514- mangledNameStr << " wave" ;
614+ mangledNameStr << GetHLFullName (group, opcode, attribs);
515615 mangledNameStr << ' .' ;
516616 funcTy->print (mangledNameStr);
517617 }
518618
519619 mangledNameStr.flush ();
520620
521- Function *F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy));
522- if (group == HLOpcodeGroup::HLExtIntrinsic) {
523- F->addFnAttr (hlsl::HLPrefix, *groupName);
621+ // Avoid getOrInsertFunction to verify attributes and type without casting.
622+ Function *F = cast_or_null<Function>(M.getNamedValue (mangledName));
623+ if (F) {
624+ assert (F->getFunctionType () == funcTy &&
625+ " otherwise, function type mismatch not captured by mangling" );
626+ // Compare attribute mangling to ensure function attributes are as expected.
627+ assert (
628+ GetHLFunctionAttributeMangling (F->getAttributes ().getFnAttributes ()) ==
629+ GetHLFunctionAttributeMangling (attribs) &&
630+ " otherwise, function attribute mismatch not captured by mangling" );
631+ } else {
632+ F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy, attribs));
524633 }
525634
526- SetHLFunctionAttribute (F, group, opcode);
527-
528- // Copy attributes
529- if (attribs.hasAttribute (AttributeSet::FunctionIndex, Attribute::ReadNone))
530- F->addFnAttr (Attribute::ReadNone);
531- if (attribs.hasAttribute (AttributeSet::FunctionIndex, Attribute::ReadOnly))
532- F->addFnAttr (Attribute::ReadOnly);
533- if (attribs.hasAttribute (AttributeSet::FunctionIndex, HLWaveSensitive))
534- F->addFnAttr (HLWaveSensitive, " y" );
535-
536635 return F;
537636}
538637
@@ -541,15 +640,17 @@ Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
541640Function *GetOrCreateHLFunctionWithBody (Module &M, FunctionType *funcTy,
542641 HLOpcodeGroup group, unsigned opcode,
543642 StringRef name) {
544- std::string operatorName = GetHLFullName (group, opcode);
643+ // Set/transfer all common attributes
644+ AttributeSet attribs = GetHLFunctionAttributes (
645+ M.getContext (), funcTy, AttributeSet (), group, opcode);
646+
647+ std::string operatorName = GetHLFullName (group, opcode, attribs);
545648 std::string mangledName = operatorName + " ." + name.str ();
546649 raw_string_ostream mangledNameStr (mangledName);
547650 funcTy->print (mangledNameStr);
548651 mangledNameStr.flush ();
549652
550- Function *F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy));
551-
552- SetHLFunctionAttribute (F, group, opcode);
653+ Function *F = cast<Function>(M.getOrInsertFunction (mangledName, funcTy, attribs));
553654
554655 F->setLinkage (llvm::GlobalValue::LinkageTypes::InternalLinkage);
555656
0 commit comments