diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp
index 8d21a93c75..2fdb6b08a3 100644
--- a/source/adapters/cuda/command_buffer.cpp
+++ b/source/adapters/cuda/command_buffer.cpp
@@ -74,12 +74,11 @@ ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() {
 ur_exp_command_buffer_command_handle_t_::
     ur_exp_command_buffer_command_handle_t_(
         ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
-        std::shared_ptr<CUgraphNode> &&Node, CUDA_KERNEL_NODE_PARAMS Params,
-        uint32_t WorkDim, const size_t *GlobalWorkOffsetPtr,
-        const size_t *GlobalWorkSizePtr, const size_t *LocalWorkSizePtr)
-    : CommandBuffer(CommandBuffer), Kernel(Kernel), Node{std::move(Node)},
-      Params(Params), WorkDim(WorkDim), RefCountInternal(1),
-      RefCountExternal(1) {
+        CUgraphNode Node, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
+        const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
+        const size_t *LocalWorkSizePtr)
+    : CommandBuffer(CommandBuffer), Kernel(Kernel), Node(Node), Params(Params),
+      WorkDim(WorkDim), RefCountInternal(1), RefCountExternal(1) {
   CommandBuffer->incrementInternalReferenceCount();
 
   const size_t CopySize = sizeof(size_t) * WorkDim;
@@ -124,7 +123,7 @@ static ur_result_t getNodesFromSyncPoints(
   for (size_t i = 0; i < NumSyncPointsInWaitList; i++) {
     if (auto NodeHandle = SyncPoints.find(SyncPointWaitList[i]);
         NodeHandle != SyncPoints.end()) {
-      CuNodesList.push_back(*NodeHandle->second.get());
+      CuNodesList.push_back(NodeHandle->second);
     } else {
       return UR_RESULT_ERROR_INVALID_VALUE;
     }
@@ -161,22 +160,22 @@ static ur_result_t enqueueCommandBufferFillHelper(
     const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
     size_t Size, uint32_t NumSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
-    ur_exp_command_buffer_sync_point_t *SyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
+    ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
-                                 SyncPointWaitList, DepsList),
-          Result);
+  UR_CHECK_ERROR(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
+                                        SyncPointWaitList, DepsList));
 
   try {
+    // Graph node added to the graph; if multiple nodes are created, this
+    // will be set to the leaf node.
+    CUgraphNode GraphNode;
+
     const size_t N = Size / PatternSize;
     auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
                       ? *static_cast<CUdeviceptr *>(DstDevice)
                       : (CUdeviceptr)DstDevice;
 
     if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
-      // Create a new node
-      CUgraphNode GraphNode;
       CUDA_MEMSET_NODE_PARAMS NodeParams = {};
       NodeParams.dst = DstPtr;
       NodeParams.elementSize = PatternSize;
@@ -207,11 +206,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
           cuGraphAddMemsetNode(&GraphNode, CommandBuffer->CudaGraph,
                                DepsList.data(), DepsList.size(), &NodeParams,
                                CommandBuffer->Device->getNativeContext()));
-
-      // Get sync point and register the cuNode with it.
-      *SyncPoint =
-          CommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
-
     } else {
       // CUDA has no memset functions that allow setting values more than 4
      // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
@@ -222,10 +216,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
 
       size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
 
-      // Shared pointer that will point to the last node created
-      std::shared_ptr<CUgraphNode> GraphNodePtr;
-      // Create a new node
-      CUgraphNode GraphNodeFirst;
       // Update NodeParam
       CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
       NodeParamsStepFirst.dst = DstPtr;
@@ -236,16 +226,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
       NodeParamsStepFirst.width = 1;
 
       UR_CHECK_ERROR(cuGraphAddMemsetNode(
-          &GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
+          &GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
           DepsList.size(), &NodeParamsStepFirst,
           CommandBuffer->Device->getNativeContext()));
 
-      // Get sync point and register the cuNode with it.
-      *SyncPoint = CommandBuffer->addSyncPoint(
-          std::make_shared<CUgraphNode>(GraphNodeFirst));
-
       DepsList.clear();
-      DepsList.push_back(GraphNodeFirst);
+      DepsList.push_back(GraphNode);
 
       // we walk up the pattern in 1-byte steps, and call cuMemset for each
       // 1-byte chunk of the pattern.
@@ -256,8 +242,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
         // offset the pointer to the part of the buffer we want to write to
        auto OffsetPtr = DstPtr + (Step * sizeof(uint8_t));
 
-        // Create a new node
-        CUgraphNode GraphNode;
         // Update NodeParam
         CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
         NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
@@ -272,18 +256,20 @@ static ur_result_t enqueueCommandBufferFillHelper(
                                             DepsList.size(), &NodeParamsStep,
                                             CommandBuffer->Device->getNativeContext()));
 
-        GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
-        // Get sync point and register the cuNode with it.
-        *SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);
-
         DepsList.clear();
-        DepsList.push_back(*GraphNodePtr.get());
+        DepsList.push_back(GraphNode);
       }
     }
+
+    // Get sync point and register the cuNode with it.
+    auto SyncPoint = CommandBuffer->addSyncPoint(GraphNode);
+    if (RetSyncPoint) {
+      *RetSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
@@ -368,18 +354,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
   UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
   UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
 
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   if (*pGlobalWorkSize == 0) {
     try {
@@ -388,12 +367,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
           DepsList.data(), DepsList.size()));
 
       // Get sync point and register the cuNode with it.
-      *pSyncPoint = hCommandBuffer->addSyncPoint(
-          std::make_shared<CUgraphNode>(GraphNode));
+      auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+      if (pSyncPoint) {
+        *pSyncPoint = SyncPoint;
+      }
     } catch (ur_result_t Err) {
-      Result = Err;
+      return Err;
     }
-    return Result;
+    return UR_RESULT_SUCCESS;
   }
 
   // Set the number of threads per block to the number of threads per warp
@@ -403,13 +384,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
   uint32_t LocalSize = hKernel->getLocalSize();
   CUfunction CuFunc = hKernel->get();
 
-  Result =
+  UR_CHECK_ERROR(
       setKernelParams(hCommandBuffer->Context, hCommandBuffer->Device, workDim,
                       pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                      hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+                      hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid));
 
   try {
     // Set node param structure with the kernel related data
@@ -434,14 +412,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
     hKernel->clearLocalSize();
 
     // Get sync point and register the cuNode with it.
-    auto NodeSP = std::make_shared<CUgraphNode>(GraphNode);
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
     if (pSyncPoint) {
-      *pSyncPoint = hCommandBuffer->addSyncPoint(NodeSP);
+      *pSyncPoint = SyncPoint;
     }
 
     auto NewCommand = new ur_exp_command_buffer_command_handle_t_{
-        hCommandBuffer, hKernel, std::move(NodeSP), NodeParams,
-        workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize};
+        hCommandBuffer, hKernel, GraphNode, NodeParams,
+        workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize};
 
     NewCommand->incrementInternalReferenceCount();
     hCommandBuffer->CommandHandles.push_back(NewCommand);
@@ -451,9 +429,9 @@
     }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
@@ -461,16 +439,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
     size_t size, uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     CUDA_MEMCPY3D NodeParams = {};
@@ -482,12 +454,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
@@ -496,7 +470,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
     uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
 
@@ -505,13 +478,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
   UR_ASSERT(size + srcOffset <= std::get<BufferMem>(hSrcMem->Mem).getSize(),
             UR_RESULT_ERROR_INVALID_SIZE);
 
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto Src = std::get<BufferMem>(hSrcMem->Mem)
@@ -528,12 +496,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
@@ -544,16 +514,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
     uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto SrcPtr =
@@ -571,12 +535,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT
@@ -586,16 +552,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
     uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto Dst = std::get<BufferMem>(hBuffer->Mem)
@@ -610,12 +570,14 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT
@@ -624,16 +586,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
     size_t offset, size_t size, void *pDst, uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto Src = std::get<BufferMem>(hBuffer->Mem)
@@ -648,12 +604,14 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT
@@ -665,16 +623,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
     uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto DstPtr =
@@ -691,12 +643,14 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT
@@ -708,16 +662,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
     uint32_t numSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto SrcPtr =
@@ -734,12 +682,14 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
         &NodeParams, hCommandBuffer->Device->getNativeContext()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
@@ -750,13 +700,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
   // Prefetch cmd is not supported by Cuda Graph.
   // We implement it as an empty node to enforce dependencies.
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
 
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     // Add an empty node to preserve dependencies.
@@ -764,17 +712,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
                                       DepsList.data(), DepsList.size()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
 
-    setErrorMessage("Prefetch hint ignored and replaced with empty node as "
-                    "prefetch is not supported by CUDA Graph backend",
-                    UR_RESULT_SUCCESS);
-    Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
@@ -785,13 +731,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
     ur_exp_command_buffer_sync_point_t *pSyncPoint) {
   // Mem-Advise cmd is not supported by Cuda Graph.
   // We implement it as an empty node to enforce dependencies.
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUgraphNode GraphNode;
   std::vector<CUgraphNode> DepsList;
 
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     // Add an empty node to preserve dependencies.
@@ -799,18 +743,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
                                       DepsList.data(), DepsList.size()));
 
     // Get sync point and register the cuNode with it.
-    *pSyncPoint =
-        hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
-
-    setErrorMessage("Memory advice ignored and replaced with empty node as "
-                    "memory advice is not supported by CUDA Graph backend",
-                    UR_RESULT_SUCCESS);
-    Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
@@ -860,7 +801,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
 
   try {
     std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
@@ -870,10 +810,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     CUstream CuStream = hQueue->getNextComputeStream(
         numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
 
-    if ((Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
-                                    phEventWaitList)) != UR_RESULT_SUCCESS) {
-      return Result;
-    }
+    UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
+                                     phEventWaitList));
 
     if (phEvent) {
       RetImplEvent = std::unique_ptr<ur_event_handle_t_>(
@@ -890,10 +828,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
       *phEvent = RetImplEvent.release();
     }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
 
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
@@ -1067,7 +1005,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
   Params.sharedMemBytes = Kernel->getLocalSize();
   Params.kernelParams = const_cast<void **>(Kernel->getArgIndices().data());
 
-  CUgraphNode Node = *(hCommand->Node);
+  CUgraphNode Node = hCommand->Node;
   CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec;
   UR_CHECK_ERROR(cuGraphExecKernelNodeSetParams(CudaGraphExec, Node, &Params));
   return UR_RESULT_SUCCESS;
diff --git a/source/adapters/cuda/command_buffer.hpp b/source/adapters/cuda/command_buffer.hpp
index d83269f2ae..504095612b 100644
--- a/source/adapters/cuda/command_buffer.hpp
+++ b/source/adapters/cuda/command_buffer.hpp
@@ -42,9 +42,9 @@ struct ur_exp_command_buffer_command_handle_t_ {
   ur_exp_command_buffer_command_handle_t_(
       ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
-      std::shared_ptr<CUgraphNode> &&Node, CUDA_KERNEL_NODE_PARAMS Params,
-      uint32_t WorkDim, const size_t *GlobalWorkOffsetPtr,
-      const size_t *GlobalWorkSizePtr, const size_t *LocalWorkSizePtr);
+      CUgraphNode Node, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
+      const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
+      const size_t *LocalWorkSizePtr);
 
   void setGlobalOffset(const size_t *GlobalWorkOffsetPtr) {
     const size_t CopySize = sizeof(size_t) * WorkDim;
@@ -97,7 +97,7 @@ struct ur_exp_command_buffer_command_handle_t_ {
   ur_exp_command_buffer_handle_t CommandBuffer;
   ur_kernel_handle_t Kernel;
-  std::shared_ptr<CUgraphNode> Node;
+  CUgraphNode Node;
 
   CUDA_KERNEL_NODE_PARAMS Params;
 
   uint32_t WorkDim;
@@ -118,8 +118,8 @@ struct ur_exp_command_buffer_handle_t_ {
   ~ur_exp_command_buffer_handle_t_();
 
   void
  registerSyncPoint(ur_exp_command_buffer_sync_point_t SyncPoint,
-                    std::shared_ptr<CUgraphNode> CuNode) {
-    SyncPoints[SyncPoint] = std::move(CuNode);
+                    CUgraphNode CuNode) {
+    SyncPoints[SyncPoint] = CuNode;
     NextSyncPoint++;
   }
 
@@ -130,8 +130,7 @@ struct ur_exp_command_buffer_handle_t_ {
   // Helper to register next sync point
   // @param CuNode Node to register as next sync point
   // @return Pointer to the sync that registers the Node
-  ur_exp_command_buffer_sync_point_t
-  addSyncPoint(std::shared_ptr<CUgraphNode> CuNode) {
+  ur_exp_command_buffer_sync_point_t addSyncPoint(CUgraphNode CuNode) {
     ur_exp_command_buffer_sync_point_t SyncPoint = NextSyncPoint;
     registerSyncPoint(SyncPoint, std::move(CuNode));
     return SyncPoint;
   }
@@ -173,8 +172,7 @@ struct ur_exp_command_buffer_handle_t_ {
   std::atomic_uint32_t RefCountExternal;
 
   // Map of sync_points to ur_events
-  std::unordered_map<ur_exp_command_buffer_sync_point_t,
-                     std::shared_ptr<CUgraphNode>>
+  std::unordered_map<ur_exp_command_buffer_sync_point_t, CUgraphNode>
       SyncPoints;
   // Next sync_point value (may need to consider ways to reuse values if 32-bits
   // is not enough)
diff --git a/source/adapters/hip/command_buffer.cpp b/source/adapters/hip/command_buffer.cpp
index d9438eeb9c..4ff38626af 100644
--- a/source/adapters/hip/command_buffer.cpp
+++ b/source/adapters/hip/command_buffer.cpp
@@ -76,12 +76,11 @@ ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() {
 ur_exp_command_buffer_command_handle_t_::
     ur_exp_command_buffer_command_handle_t_(
         ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
-        std::shared_ptr<hipGraphNode_t> &&Node, hipKernelNodeParams Params,
-        uint32_t WorkDim, const size_t *GlobalWorkOffsetPtr,
-        const size_t *GlobalWorkSizePtr, const size_t *LocalWorkSizePtr)
-    : CommandBuffer(CommandBuffer), Kernel(Kernel), Node(std::move(Node)),
-      Params(Params), WorkDim(WorkDim), RefCountInternal(1),
-      RefCountExternal(1) {
+        hipGraphNode_t Node, hipKernelNodeParams Params, uint32_t WorkDim,
+        const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
+        const size_t *LocalWorkSizePtr)
+    : CommandBuffer(CommandBuffer), Kernel(Kernel), Node(Node), Params(Params),
+      WorkDim(WorkDim), RefCountInternal(1), RefCountExternal(1) {
   CommandBuffer->incrementInternalReferenceCount();
 
   const size_t CopySize = sizeof(size_t) * WorkDim;
@@ -125,7 +124,7 @@ static ur_result_t getNodesFromSyncPoints(
   for (size_t i = 0; i < NumSyncPointsInWaitList; i++) {
     if (auto NodeHandle = SyncPoints.find(SyncPointWaitList[i]);
         NodeHandle != SyncPoints.end()) {
-      HIPNodesList.push_back(*NodeHandle->second.get());
+      HIPNodesList.push_back(NodeHandle->second);
     } else {
       return UR_RESULT_ERROR_INVALID_VALUE;
     }
@@ -139,29 +138,23 @@ static ur_result_t enqueueCommandBufferFillHelper(
     const hipMemoryType DstType, const void *Pattern, size_t PatternSize,
     size_t Size, uint32_t NumSyncPointsInWaitList,
     const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
-    ur_exp_command_buffer_sync_point_t *SyncPoint) {
+    ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
   std::vector<hipGraphNode_t> DepsList;
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
-                                   SyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
+                                        SyncPointWaitList, DepsList));
 
   try {
+    // Graph node added to the graph; if multiple nodes are created, this
+    // will be set to the leaf node.
+    hipGraphNode_t GraphNode;
+
     const size_t N = Size / PatternSize;
     auto DstPtr = DstType == hipMemoryTypeDevice ?
                      *static_cast<hipDeviceptr_t *>(DstDevice) : DstDevice;
 
     if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
-      // Create a new node
-      hipGraphNode_t GraphNode;
       hipMemsetParams NodeParams = {};
       NodeParams.dst = DstPtr;
       NodeParams.elementSize = PatternSize;
@@ -192,10 +185,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
                                            DepsList.data(), DepsList.size(),
                                            &NodeParams));
 
-      // Get sync point and register the node with it.
-      *SyncPoint = CommandBuffer->addSyncPoint(
-          std::make_shared<hipGraphNode_t>(GraphNode));
-
     } else {
       // HIP has no memset functions that allow setting values more than 4
       // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
@@ -206,11 +195,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
 
       size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
 
-      // Shared pointer that will point to the last node created
-      std::shared_ptr<hipGraphNode_t> GraphNodePtr;
-
-      // Create a new node
-      hipGraphNode_t GraphNodeFirst;
       // Update NodeParam
       hipMemsetParams NodeParamsStepFirst = {};
       NodeParamsStepFirst.dst = DstPtr;
@@ -220,16 +204,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
       NodeParamsStepFirst.value = *(static_cast<const uint8_t *>(Pattern));
       NodeParamsStepFirst.width = 1;
 
-      UR_CHECK_ERROR(hipGraphAddMemsetNode(
-          &GraphNodeFirst, CommandBuffer->HIPGraph, DepsList.data(),
-          DepsList.size(), &NodeParamsStepFirst));
-
-      // Get sync point and register the node with it.
-      *SyncPoint = CommandBuffer->addSyncPoint(
-          std::make_shared<hipGraphNode_t>(GraphNodeFirst));
+      UR_CHECK_ERROR(hipGraphAddMemsetNode(&GraphNode, CommandBuffer->HIPGraph,
+                                           DepsList.data(), DepsList.size(),
+                                           &NodeParamsStepFirst));
 
       DepsList.clear();
-      DepsList.push_back(GraphNodeFirst);
+      DepsList.push_back(GraphNode);
 
       // we walk up the pattern in 1-byte steps, and add Memset node for each
       // 1-byte chunk of the pattern.
@@ -241,8 +221,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
         auto OffsetPtr = reinterpret_cast<void *>(
             reinterpret_cast<uint8_t *>(DstPtr) + (Step * sizeof(uint8_t)));
 
-        // Create a new node
-        hipGraphNode_t GraphNode;
         // Update NodeParam
         hipMemsetParams NodeParamsStep = {};
         NodeParamsStep.dst = reinterpret_cast<hipDeviceptr_t>(OffsetPtr);
@@ -256,14 +234,17 @@ static ur_result_t enqueueCommandBufferFillHelper(
             &GraphNode, CommandBuffer->HIPGraph, DepsList.data(),
             DepsList.size(), &NodeParamsStep));
 
-        GraphNodePtr = std::make_shared<hipGraphNode_t>(GraphNode);
-        // Get sync point and register the node with it.
-        *SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);
-
         DepsList.clear();
-        DepsList.push_back(*GraphNodePtr.get());
+        DepsList.push_back(GraphNode);
       }
     }
+
+    // Get sync point and register the node with it.
+    auto SyncPoint = CommandBuffer->addSyncPoint(GraphNode);
+    if (RetSyncPoint) {
+      *RetSyncPoint = SyncPoint;
+    }
+
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -346,14 +327,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
   hipGraphNode_t GraphNode;
   std::vector<hipGraphNode_t> DepsList;
 
-  ur_result_t Result = UR_RESULT_SUCCESS;
-  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                 pSyncPointWaitList, DepsList),
-          Result);
-
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   if (*pGlobalWorkSize == 0) {
     try {
@@ -362,8 +337,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
           DepsList.data(), DepsList.size()));
 
       // Get sync point and register the node with it.
-      *pSyncPoint = hCommandBuffer->addSyncPoint(
-          std::make_shared<hipGraphNode_t>(GraphNode));
+      auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+      if (pSyncPoint) {
+        *pSyncPoint = SyncPoint;
+      }
     } catch (ur_result_t Err) {
       return Err;
     }
@@ -377,13 +354,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
   uint32_t LocalSize = hKernel->getLocalSize();
   hipFunction_t HIPFunc = hKernel->get();
 
-  UR_CALL(setKernelParams(hCommandBuffer->Device, workDim, pGlobalWorkOffset,
-                          pGlobalWorkSize, pLocalWorkSize, hKernel, HIPFunc,
-                          ThreadsPerBlock, BlocksPerGrid),
-          Result);
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(setKernelParams(
+      hCommandBuffer->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+      pLocalWorkSize, hKernel, HIPFunc, ThreadsPerBlock, BlocksPerGrid));
 
   try {
     // Set node param structure with the kernel related data
@@ -409,14 +382,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
     hKernel->clearLocalSize();
 
     // Get sync point and register the node with it.
-    auto NodeSP = std::make_shared<hipGraphNode_t>(GraphNode);
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
     if (pSyncPoint) {
-      *pSyncPoint = hCommandBuffer->addSyncPoint(NodeSP);
+      *pSyncPoint = SyncPoint;
     }
 
     auto NewCommand = new ur_exp_command_buffer_command_handle_t_{
-        hCommandBuffer, hKernel, std::move(NodeSP), NodeParams,
-        workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize};
+        hCommandBuffer, hKernel, GraphNode, NodeParams,
+        workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize};
 
     NewCommand->incrementInternalReferenceCount();
     hCommandBuffer->CommandHandles.push_back(NewCommand);
@@ -442,25 +415,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
-    UR_CHECK_ERROR(hipGraphAddMemcpyNode1D(
-        &GraphNode, hCommandBuffer->HIPGraph, DepsList.data(), DepsList.size(),
-        pDst, pSrc, size, hipMemcpyHostToHost));
+    UR_CHECK_ERROR(hipGraphAddMemcpyNode1D(&GraphNode, hCommandBuffer->HIPGraph,
+                                           DepsList.data(), DepsList.size(),
+                                           pDst, pSrc, size, hipMemcpyDefault));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -483,16 +450,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
   UR_ASSERT(size + srcOffset <= std::get<BufferMem>(hSrcMem->Mem).getSize(),
             UR_RESULT_ERROR_INVALID_SIZE);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto Src = std::get<BufferMem>(hSrcMem->Mem)
@@ -505,8 +464,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
         Dst, Src, size, hipMemcpyDeviceToDevice));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -527,16 +488,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto SrcPtr =
@@ -554,8 +507,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
         &NodeParams));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -575,16 +530,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto Dst = std::get<BufferMem>(hBuffer->Mem)
@@ -595,8 +542,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
         Dst, pSrc, size, hipMemcpyHostToDevice));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -615,16 +564,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto Src = std::get<BufferMem>(hBuffer->Mem)
@@ -635,8 +576,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
         pDst, Src, size, hipMemcpyDeviceToHost));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -658,16 +601,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto DstPtr =
@@ -683,8 +618,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
         &NodeParams));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -706,16 +643,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     auto SrcPtr =
@@ -731,8 +660,10 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
         &NodeParams));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -753,16 +684,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     // Create an empty node if the kernel workload size is zero
@@ -770,13 +693,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
                                        DepsList.data(), DepsList.size()));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
-
-    setErrorMessage("Prefetch hint ignored and replaced with empty node as "
-                    "prefetch is not supported by HIP Graph backend",
-                    UR_RESULT_SUCCESS);
-    return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -797,16 +717,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
   UR_ASSERT(!(pSyncPointWaitList == NULL && numSyncPointsInWaitList > 0),
             UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
 
-  {
-    ur_result_t Result = UR_RESULT_SUCCESS;
-    UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
-                                   pSyncPointWaitList, DepsList),
-            Result);
-
-    if (Result != UR_RESULT_SUCCESS) {
-      return Result;
-    }
-  }
+  UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
+                                        pSyncPointWaitList, DepsList));
 
   try {
     // Create an empty node if the kernel workload size is zero
@@ -814,13 +726,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
                                        DepsList.data(), DepsList.size()));
 
     // Get sync point and register the node with it.
-    *pSyncPoint = hCommandBuffer->addSyncPoint(
-        std::make_shared<hipGraphNode_t>(GraphNode));
-
-    setErrorMessage("Memory advice ignored and replaced with empty node as "
-                    "memory advice is not supported by HIP Graph backend",
-                    UR_RESULT_SUCCESS);
-    return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
+    auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
+    if (pSyncPoint) {
+      *pSyncPoint = SyncPoint;
+    }
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -878,8 +787,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
-  ur_result_t Result = UR_RESULT_SUCCESS;
-
   try {
     std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
     ScopedContext Active(hQueue->getDevice());
@@ -888,10 +795,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     hipStream_t HIPStream = hQueue->getNextComputeStream(
         numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
 
-    if ((Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
-                                    phEventWaitList)) != UR_RESULT_SUCCESS) {
-      return Result;
-    }
+    UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
+                                     phEventWaitList));
 
     if (phEvent) {
      RetImplEvent = std::unique_ptr<ur_event_handle_t_>(
@@ -908,10 +813,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
       *phEvent = RetImplEvent.release();
     }
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
 
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
@@ -978,12 +883,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
     uint32_t ArgIndex = PointerArgDesc.argIndex;
     const void *ArgValue = PointerArgDesc.pNewPointerArg;
 
-    ur_result_t Result = UR_RESULT_SUCCESS;
     try {
       Kernel->setKernelArg(ArgIndex, sizeof(ArgValue), ArgValue);
     } catch (ur_result_t Err) {
-      Result = Err;
-      return Result;
+      return Err;
     }
   }
 
@@ -996,7 +899,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
     uint32_t ArgIndex = MemobjArgDesc.argIndex;
     ur_mem_handle_t ArgValue = MemobjArgDesc.hNewMemObjArg;
 
-    ur_result_t Result = UR_RESULT_SUCCESS;
     try {
       if (ArgValue == nullptr) {
         Kernel->setKernelArg(ArgIndex, 0, nullptr);
@@ -1005,8 +907,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
         Kernel->setKernelArg(ArgIndex, sizeof(void *), (void *)&HIPPtr);
       }
     } catch (ur_result_t Err) {
-      Result = Err;
-      return Result;
+      return Err;
     }
   }
 
@@ -1020,13 +921,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
     size_t ArgSize = ValueArgDesc.argSize;
     const void *ArgValue = ValueArgDesc.pNewValueArg;
 
-    ur_result_t Result = UR_RESULT_SUCCESS;
-
     try {
       Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue);
     } catch (ur_result_t Err) {
-      Result = Err;
-      return Result;
+      return Err;
     }
   }
 
@@ -1064,12 +962,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
   size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
   size_t BlocksPerGrid[3] = {1u, 1u, 1u};
   hipFunction_t HIPFunc = Kernel->get();
-  auto Result = setKernelParams(Device, WorkDim, GlobalWorkOffset,
-                                GlobalWorkSize, LocalWorkSize, Kernel, HIPFunc,
-                                ThreadsPerBlock, BlocksPerGrid);
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  UR_CHECK_ERROR(setKernelParams(Device, WorkDim, GlobalWorkOffset,
+                                 GlobalWorkSize, LocalWorkSize, Kernel, HIPFunc,
+                                 ThreadsPerBlock, BlocksPerGrid));
 
   hipKernelNodeParams &Params = hCommand->Params;
 
@@ -1083,7 +978,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
   Params.sharedMemBytes = Kernel->getLocalSize();
   Params.kernelParams = const_cast<void **>(Kernel->getArgIndices().data());
 
-  hipGraphNode_t Node = *(hCommand->Node);
+  hipGraphNode_t Node = hCommand->Node;
   hipGraphExec_t HipGraphExec = CommandBuffer->HIPGraphExec;
   UR_CHECK_ERROR(hipGraphExecKernelNodeSetParams(HipGraphExec, Node, &Params));
   return UR_RESULT_SUCCESS;
diff --git a/source/adapters/hip/command_buffer.hpp b/source/adapters/hip/command_buffer.hpp
index 751fde3720..d744a3544d 100644
--- a/source/adapters/hip/command_buffer.hpp
+++ b/source/adapters/hip/command_buffer.hpp
@@ -41,9 +41,9 @@ struct ur_exp_command_buffer_command_handle_t_ {
   ur_exp_command_buffer_command_handle_t_(
       ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
-      std::shared_ptr<hipGraphNode_t> &&Node, hipKernelNodeParams Params,
-      uint32_t WorkDim, const size_t *GlobalWorkOffsetPtr,
-      const size_t *GlobalWorkSizePtr, const size_t *LocalWorkSizePtr);
+      hipGraphNode_t Node, hipKernelNodeParams Params, uint32_t WorkDim,
+      const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
+      const size_t *LocalWorkSizePtr);
 
   void setGlobalOffset(const size_t *GlobalWorkOffsetPtr) {
     const size_t CopySize = sizeof(size_t) * WorkDim;
@@ -96,7 +96,7 @@ struct ur_exp_command_buffer_command_handle_t_ {
   ur_exp_command_buffer_handle_t CommandBuffer;
   ur_kernel_handle_t Kernel;
-  std::shared_ptr<hipGraphNode_t> Node;
+  hipGraphNode_t Node;
 
   hipKernelNodeParams Params;
 
   uint32_t WorkDim;
@@ -117,7 +117,7 @@ struct ur_exp_command_buffer_handle_t_ {
   ~ur_exp_command_buffer_handle_t_();
 
   void registerSyncPoint(ur_exp_command_buffer_sync_point_t SyncPoint,
-                         std::shared_ptr<hipGraphNode_t> &&HIPNode) {
+                         hipGraphNode_t HIPNode) {
     SyncPoints[SyncPoint] = std::move(HIPNode);
     NextSyncPoint++;
   }
@@ -129,8 +129,7 @@ struct ur_exp_command_buffer_handle_t_ {
   // Helper to register next sync point
   // @param HIPNode Node to register as next sync point
   // @return Pointer to the sync that registers the Node
-  ur_exp_command_buffer_sync_point_t
-  addSyncPoint(std::shared_ptr<hipGraphNode_t> HIPNode) {
+  ur_exp_command_buffer_sync_point_t addSyncPoint(hipGraphNode_t HIPNode) {
     ur_exp_command_buffer_sync_point_t SyncPoint = NextSyncPoint;
     registerSyncPoint(SyncPoint, std::move(HIPNode));
     return SyncPoint;
   }
@@ -171,8 +170,7 @@ struct ur_exp_command_buffer_handle_t_ {
   std::atomic_uint32_t RefCountExternal;
 
   // Map of sync_points to ur_events
-  std::unordered_map<ur_exp_command_buffer_sync_point_t,
-                     std::shared_ptr<hipGraphNode_t>>
+  std::unordered_map<ur_exp_command_buffer_sync_point_t, hipGraphNode_t>
       SyncPoints;
   // Next sync_point value (may need to consider ways to reuse values if 32-bits
   // is not enough)
diff --git a/test/conformance/exp_command_buffer/CMakeLists.txt b/test/conformance/exp_command_buffer/CMakeLists.txt
index a8ecf793ab..a28d692d9b 100644
--- a/test/conformance/exp_command_buffer/CMakeLists.txt
+++ b/test/conformance/exp_command_buffer/CMakeLists.txt
@@ -12,4 +12,6 @@ add_conformance_test_with_kernels_environment(exp_command_buffer
     release.cpp
     retain.cpp
     invalid_update.cpp
+    commands.cpp
+    fill.cpp
 )
diff --git a/test/conformance/exp_command_buffer/commands.cpp b/test/conformance/exp_command_buffer/commands.cpp
new file mode 100644
index 0000000000..412e4ab6de
--- /dev/null
+++ b/test/conformance/exp_command_buffer/commands.cpp
@@ -0,0 +1,204 @@
+// Copyright (C) 2024 Intel Corporation
+// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
+// See LICENSE.TXT
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "fixtures.h"
+
+struct urCommandBufferCommandsTest
+    : uur::command_buffer::urCommandBufferExpTest {
+
+    void SetUp() override {
+        UUR_RETURN_ON_FATAL_FAILURE(
+            uur::command_buffer::urCommandBufferExpTest::SetUp());
+
+        // Allocate USM pointers
+        for (auto &device_ptr : device_ptrs) {
+            ASSERT_SUCCESS(urUSMDeviceAlloc(context, device, nullptr, nullptr,
+                                            allocation_size, &device_ptr));
+            ASSERT_NE(device_ptr, nullptr);
+        }
+
+        for (auto &buffer : buffers) {
+            ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE,
+                                             allocation_size, nullptr,
+                                             &buffer));
+
+            ASSERT_NE(buffer, nullptr);
+        }
+    }
+
+    void TearDown() override {
+        for (auto &device_ptr : device_ptrs) {
+            if (device_ptr) {
+                EXPECT_SUCCESS(urUSMFree(context, device_ptr));
+            }
+        }
+
+        for (auto &buffer : buffers) {
+            if (buffer) {
+                EXPECT_SUCCESS(urMemRelease(buffer));
+            }
+        }
+
+        UUR_RETURN_ON_FATAL_FAILURE(
+            uur::command_buffer::urCommandBufferExpTest::TearDown());
+    }
+
+    static constexpr unsigned elements = 16;
+    static constexpr size_t allocation_size = elements * sizeof(uint32_t);
+
+    std::array<void *, 2> device_ptrs = {nullptr, nullptr};
+    std::array<ur_mem_handle_t, 2> buffers = {nullptr, nullptr};
+};
+
+UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urCommandBufferCommandsTest);
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMMemcpyExp) {
+    ASSERT_SUCCESS(urCommandBufferAppendUSMMemcpyExp(
+        cmd_buf_handle, device_ptrs[0], device_ptrs[1], allocation_size, 0,
+        nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMFillExp) {
+    uint32_t pattern = 42;
+    ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
+        cmd_buf_handle, device_ptrs[0], &pattern, sizeof(pattern),
+        allocation_size, 0, nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferCopyExp) {
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferCopyExp(
+        cmd_buf_handle, buffers[0], buffers[1], 0, 0, allocation_size, 0,
+        nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferCopyRectExp) {
+    ur_rect_offset_t origin{0, 0, 0};
+    ur_rect_region_t region{4, 4, 1};
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferCopyRectExp(
+        cmd_buf_handle, buffers[0], buffers[1], origin, origin, region, 4, 16,
+        4, 16, 0, nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferReadExp) {
+    std::array<uint32_t, elements> host_data{};
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferReadExp(
+        cmd_buf_handle, buffers[0], 0, allocation_size, host_data.data(), 0,
+        nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferReadRectExp) {
+    std::array<uint32_t, elements> host_data{};
+    ur_rect_offset_t origin{0, 0, 0};
+    ur_rect_region_t region{4, 4, 1};
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferReadRectExp(
+        cmd_buf_handle, buffers[0], origin, origin, region, 4, 16, 4, 16,
+        host_data.data(), 0, nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferWriteExp) {
+    std::array<uint32_t, elements> host_data{};
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferWriteExp(
+        cmd_buf_handle, buffers[0], 0, allocation_size, host_data.data(), 0,
+        nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest,
+       urCommandBufferAppendMemBufferWriteRectExp) {
+    std::array<uint32_t, elements> host_data{};
+    ur_rect_offset_t origin{0, 0, 0};
+    ur_rect_region_t region{4, 4, 1};
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferWriteRectExp(
+        cmd_buf_handle, buffers[0], origin, origin, region, 4, 16, 4, 16,
+        host_data.data(), 0, nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendMemBufferFillExp) {
+    uint32_t pattern = 42;
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
+        cmd_buf_handle, buffers[0], &pattern, sizeof(pattern), 0,
+        allocation_size, 0, nullptr, nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMPrefetchExp) {
+    ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp(
+        cmd_buf_handle, device_ptrs[0], allocation_size, 0, 0, nullptr,
+        nullptr));
+}
+
+TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMAdviseExp) {
+    ASSERT_SUCCESS(urCommandBufferAppendUSMAdviseExp(
+        cmd_buf_handle, device_ptrs[0], allocation_size, 0, 0, nullptr,
+        nullptr));
+}
+
+struct urCommandBufferAppendKernelLaunchExpTest
+    : uur::command_buffer::urCommandBufferExpExecutionTest {
+    virtual void SetUp() override {
+        program_name = "saxpy_usm";
+        UUR_RETURN_ON_FATAL_FAILURE(urCommandBufferExpExecutionTest::SetUp());
+
+        for (auto &shared_ptr : shared_ptrs) {
+            ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr,
+                                            allocation_size, &shared_ptr));
+            ASSERT_NE(shared_ptr, nullptr);
+        }
+
+        int32_t *ptrX = static_cast<int32_t *>(shared_ptrs[1]);
+        int32_t *ptrY = static_cast<int32_t *>(shared_ptrs[2]);
+        for (size_t i = 0; i < global_size; i++) {
+            ptrX[i] = i;
+            ptrY[i] = i * 2;
+        }
+
+        // Index 0 is output
+        ASSERT_SUCCESS(
+            urKernelSetArgPointer(kernel, 0, nullptr, shared_ptrs[0]));
+        // Index 1 is A
+        ASSERT_SUCCESS(urKernelSetArgValue(kernel, 1, sizeof(A), nullptr, &A));
+        // Index 2 is X
+        ASSERT_SUCCESS(
+            urKernelSetArgPointer(kernel, 2, nullptr, shared_ptrs[1]));
+        // Index 3 is Y
+        ASSERT_SUCCESS(
+            urKernelSetArgPointer(kernel, 3, nullptr, shared_ptrs[2]));
+    }
+
+    virtual void TearDown() override {
+        for (auto &shared_ptr : shared_ptrs) {
+            if (shared_ptr) {
+                EXPECT_SUCCESS(urUSMFree(context, shared_ptr));
+            }
+        }
+
+        UUR_RETURN_ON_FATAL_FAILURE(
+            urCommandBufferExpExecutionTest::TearDown());
+    }
+
+    static constexpr size_t local_size = 4;
+    static constexpr size_t global_size = 32;
+    static constexpr size_t global_offset = 0;
+    static constexpr size_t n_dimensions = 1;
+    static constexpr size_t allocation_size = sizeof(uint32_t) * global_size;
+    static constexpr uint32_t A = 42;
+
+    std::array<void *, 3> shared_ptrs = {nullptr, nullptr, nullptr};
+};
+
+UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urCommandBufferAppendKernelLaunchExpTest);
+
+TEST_P(urCommandBufferAppendKernelLaunchExpTest, Basic) {
+    ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
+        cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
+        &local_size, 0, nullptr, nullptr, nullptr));
+
+    ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
+
+    ASSERT_SUCCESS(
+        urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
+    ASSERT_SUCCESS(urQueueFinish(queue));
+
+    int32_t *ptrZ = static_cast<int32_t *>(shared_ptrs[0]);
+    for (size_t i = 0; i < global_size; i++) {
+        uint32_t result = (A * i) + (i * 2);
+        ASSERT_EQ(result, ptrZ[i]);
+    }
+}
diff --git a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_cuda.match b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_cuda.match
index e69de29bb2..8b13789179 100644
--- a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_cuda.match
+++ b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_cuda.match
@@ -0,0 +1 @@
+
diff --git a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_hip.match b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_hip.match
index e69de29bb2..8b13789179 100644
--- a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_hip.match
+++ b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_hip.match
@@ -0,0 +1 @@
+
diff --git a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_level_zero_v2.match b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_level_zero_v2.match
index e69de29bb2..95176fc51c 100644
--- a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_level_zero_v2.match
+++ b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_level_zero_v2.match
@@ -0,0 +1,15 @@
+urCommandBufferAppendKernelLaunchExpTest.Basic{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__1__patternSize__1{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__256{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__1024__patternSize__256{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__4{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__8{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__16{{.*}}
+urCommandBufferFillCommandsTest.Buffer/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__32{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__1__patternSize__1{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__256{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__1024__patternSize__256{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__4{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__8{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__16{{.*}}
+urCommandBufferFillCommandsTest.USM/Intel_R__oneAPI_Unified_Runtime_over_Level_Zero___{{.*}}_size__256__patternSize__32{{.*}}
diff --git a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match
index 0a5a2b1317..2508f92fed 100644
--- a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match
+++ b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match
@@ -25,3 +25,4 @@
 {{OPT}}InvalidUpdateTest.GlobalLocalSizeMistach/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
 {{OPT}}InvalidUpdateTest.ImplToUserDefinedLocalSize/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
 {{OPT}}InvalidUpdateTest.UserToImplDefinedLocalSize/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
+{{OPT}}urCommandBufferAppendKernelLaunchExpTest.Basic/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}}
diff --git a/test/conformance/exp_command_buffer/fill.cpp b/test/conformance/exp_command_buffer/fill.cpp
new file mode 100644
index 0000000000..2b9a27cf2a
--- /dev/null
+++ b/test/conformance/exp_command_buffer/fill.cpp
@@ -0,0 +1,140 @@
+// Copyright (C) 2024 Intel Corporation
+// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
+// See LICENSE.TXT
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "fixtures.h"
+
+struct testParametersFill {
+    size_t size;
+    size_t pattern_size;
+};
+
+struct urCommandBufferFillCommandsTest
+    : uur::command_buffer::urCommandBufferExpTestWithParam<testParametersFill> {
+    void SetUp() override {
+        UUR_RETURN_ON_FATAL_FAILURE(
+            uur::command_buffer::urCommandBufferExpTestWithParam<
+                testParametersFill>::SetUp());
+
+        size = std::get<1>(GetParam()).size;
+        pattern_size = std::get<1>(GetParam()).pattern_size;
+        pattern = std::vector<uint8_t>(pattern_size);
+        uur::generateMemFillPattern(pattern);
+
+        // Allocate USM pointers
+        ASSERT_SUCCESS(urUSMDeviceAlloc(context, device, nullptr, nullptr, size,
+                                        &device_ptr));
+        ASSERT_NE(device_ptr, nullptr);
+
+        ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, size,
+                                         nullptr, &buffer));
+
+        ASSERT_NE(buffer, nullptr);
+    }
+
+    void TearDown() override {
+        if (device_ptr) {
+            EXPECT_SUCCESS(urUSMFree(context, device_ptr));
+        }
+
+        if (buffer) {
+            EXPECT_SUCCESS(urMemRelease(buffer));
+        }
+
+        UUR_RETURN_ON_FATAL_FAILURE(
+            uur::command_buffer::urCommandBufferExpTestWithParam<
+                testParametersFill>::TearDown());
+    }
+
+    void verifyData(std::vector<uint8_t> &output, size_t verify_size) {
+        size_t pattern_index = 0;
+        for (size_t i = 0; i < verify_size; ++i) {
+            ASSERT_EQ(output[i], pattern[pattern_index])
+                << "Result mismatch at index: " << i;
+
+            ++pattern_index;
+            if (pattern_index % pattern_size == 0) {
+                pattern_index = 0;
+            }
+        }
+    }
+
+    static constexpr unsigned elements = 16;
+    static constexpr size_t allocation_size = elements * sizeof(uint32_t);
+
+    std::vector<uint8_t> pattern;
+    size_t size;
+    size_t pattern_size;
+
+    ur_exp_command_buffer_sync_point_t sync_point;
+    void *device_ptr = nullptr;
+    ur_mem_handle_t buffer = nullptr;
+};
+
+static std::vector<testParametersFill> test_cases{
+    /* Everything set to 1 */
+    {1, 1},
+    /* pattern_size == size */
+    {256, 256},
+    /* pattern_size < size */
+    {1024, 256},
+    /* pattern sizes corresponding to some common scalar and vector types */
+    {256, 4},
+    {256, 8},
+    {256, 16},
+    {256, 32}};
+
+template <typename T>
+static std::string
+printFillTestString(const testing::TestParamInfo<typename T::ParamType> &info) {
+    const auto device_handle = std::get<0>(info.param);
+    const auto platform_device_name =
+        uur::GetPlatformAndDeviceName(device_handle);
+    std::stringstream test_name;
+    test_name << platform_device_name << "__size__"
+              << std::get<1>(info.param).size << "__patternSize__"
+              << std::get<1>(info.param).pattern_size;
+    return test_name.str();
+}
+
+UUR_TEST_SUITE_P(urCommandBufferFillCommandsTest,
+                 testing::ValuesIn(test_cases),
+                 printFillTestString<urCommandBufferFillCommandsTest>);
+
+TEST_P(urCommandBufferFillCommandsTest, Buffer) {
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
+        cmd_buf_handle, buffer, pattern.data(), pattern_size, 0, size, 0,
+        nullptr, &sync_point));
+
+    std::vector<uint8_t> output(size, 1);
+    ASSERT_SUCCESS(urCommandBufferAppendMemBufferReadExp(
+        cmd_buf_handle, buffer, 0, size, output.data(), 1, &sync_point,
+        nullptr));
+
+    ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
+
+    ASSERT_SUCCESS(
+        urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
+    ASSERT_SUCCESS(urQueueFinish(queue));
+
+    verifyData(output, size);
+}
+
+TEST_P(urCommandBufferFillCommandsTest, USM) {
+    ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
+        cmd_buf_handle, device_ptr, pattern.data(), pattern_size, size, 0,
+        nullptr, &sync_point));
+
+    std::vector<uint8_t> output(size, 1);
+    ASSERT_SUCCESS(urCommandBufferAppendUSMMemcpyExp(
+        cmd_buf_handle, output.data(), device_ptr, size, 1, &sync_point,
+        nullptr));
+
+    ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
+
+    ASSERT_SUCCESS(
+        urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
+    ASSERT_SUCCESS(urQueueFinish(queue));
+
+    verifyData(output, size);
+}
diff --git a/test/conformance/exp_command_buffer/fixtures.h b/test/conformance/exp_command_buffer/fixtures.h
index eeb0a5d5d8..85457bea97 100644
--- a/test/conformance/exp_command_buffer/fixtures.h
+++ b/test/conformance/exp_command_buffer/fixtures.h
@@ -55,6 +55,46 @@ struct urCommandBufferExpTest : uur::urContextTest {
     ur_bool_t updatable_command_buffer_support = false;
 };
 
+template <class T>
+struct urCommandBufferExpTestWithParam : urQueueTestWithParam<T> {
+    void SetUp() override {
+        UUR_RETURN_ON_FATAL_FAILURE(uur::urQueueTestWithParam<T>::SetUp());
+
+        size_t returned_size;
+        ASSERT_SUCCESS(urDeviceGetInfo(this->device, UR_DEVICE_INFO_EXTENSIONS,
+                                       0, nullptr, &returned_size));
+
+        std::unique_ptr<char[]> returned_extensions(new char[returned_size]);
+
+        ASSERT_SUCCESS(urDeviceGetInfo(this->device, UR_DEVICE_INFO_EXTENSIONS,
+                                       returned_size, returned_extensions.get(),
+                                       nullptr));
+
+        std::string_view extensions_string(returned_extensions.get());
+        bool command_buffer_support =
+            extensions_string.find(UR_COMMAND_BUFFER_EXTENSION_STRING_EXP) !=
+            std::string::npos;
+
+        if (!command_buffer_support) {
+            GTEST_SKIP() << "EXP command-buffer feature is not supported.";
+        }
+
+        // Create a command-buffer
+        ASSERT_SUCCESS(urCommandBufferCreateExp(this->context, this->device,
+                                                nullptr, &cmd_buf_handle));
+        ASSERT_NE(cmd_buf_handle, nullptr);
+    }
+
+    void TearDown() override {
+        if (cmd_buf_handle) {
+            EXPECT_SUCCESS(urCommandBufferReleaseExp(cmd_buf_handle));
+        }
+        UUR_RETURN_ON_FATAL_FAILURE(uur::urQueueTestWithParam<T>::TearDown());
+    }
+
+    ur_exp_command_buffer_handle_t cmd_buf_handle = nullptr;
+};
+
 struct urCommandBufferExpExecutionTest : uur::urKernelExecutionTest {
     void SetUp() override {
         UUR_RETURN_ON_FATAL_FAILURE(uur::urKernelExecutionTest::SetUp());