@@ -699,7 +699,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
699699 ZeKernelDesc.pKernelName = KernelName;
700700
701701 ze_kernel_handle_t ZeKernel;
702- ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
702+ auto ZeResult =
703+ ZE_CALL_NOCHECK (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
704+ // Gracefully handle the case that kernel create fails.
705+ if (ZeResult != ZE_RESULT_SUCCESS) {
706+ delete *RetKernel;
707+ *RetKernel = nullptr ;
708+ return ze2urResult (ZeResult);
709+ }
703710
704711 auto ZeDevice = It.first ;
705712
@@ -753,20 +760,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
753760 PArgValue = nullptr ;
754761 }
755762
763+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
764+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
765+ }
766+
756767 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
768+ ze_result_t ZeResult = ZE_RESULT_SUCCESS;
757769 if (Kernel->ZeKernelMap .empty ()) {
758770 auto ZeKernel = Kernel->ZeKernel ;
759- ZE2UR_CALL (zeKernelSetArgumentValue,
760- (ZeKernel, ArgIndex, ArgSize, PArgValue));
771+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
772+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
761773 } else {
762774 for (auto It : Kernel->ZeKernelMap ) {
763775 auto ZeKernel = It.second ;
764- ZE2UR_CALL (zeKernelSetArgumentValue,
765- (ZeKernel, ArgIndex, ArgSize, PArgValue));
776+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
777+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
766778 }
767779 }
768780
769- return UR_RESULT_SUCCESS;
781+ if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
782+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
783+ }
784+
785+ return ze2urResult (ZeResult);
770786}
771787
772788UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal (
@@ -815,6 +831,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(
815831 } catch (...) {
816832 return UR_RESULT_ERROR_UNKNOWN;
817833 }
834+ case UR_KERNEL_INFO_NUM_REGS:
818835 case UR_KERNEL_INFO_NUM_ARGS:
819836 return ReturnValue (uint32_t {Kernel->ZeKernelProperties ->numKernelArgs });
820837 case UR_KERNEL_INFO_REFERENCE_COUNT:
@@ -1065,6 +1082,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler(
10651082) {
10661083 std::ignore = Properties;
10671084 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
1085+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1086+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1087+ }
10681088 ZE2UR_CALL (zeKernelSetArgumentValue, (Kernel->ZeKernel , ArgIndex,
10691089 sizeof (void *), &ArgValue->ZeSampler ));
10701090
@@ -1080,6 +1100,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj(
10801100) {
10811101 std::ignore = Properties;
10821102
1103+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1104+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1105+ }
1106+
10831107 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
10841108 // The ArgValue may be a NULL pointer in which case a NULL value is used for
10851109 // the kernel argument declared as a pointer to global or constant memory.
0 commit comments