-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Mlas] optimize MlasConv using thread partition opt #25255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -729,6 +729,58 @@ | |||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| void | ||||||||||
| MlasConvExpandThenGemmSegmentedThreaded( | ||||||||||
| void* Context, | ||||||||||
| ptrdiff_t Index | ||||||||||
| ){ | ||||||||||
| MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; | ||||||||||
|
|
||||||||||
| const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters; | ||||||||||
|
|
||||||||||
| const size_t GroupCount = Parameters->GroupCount; | ||||||||||
| const size_t BatchGroupCount = Parameters->BatchCount * GroupCount; | ||||||||||
|
|
||||||||||
| const size_t TargetThreadCount = WorkBlock->TargetThreadCount; | ||||||||||
|
|
||||||||||
| const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount; | ||||||||||
| const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount; | ||||||||||
|
|
||||||||||
| size_t BatchGroupStart; | ||||||||||
| size_t BatchGroupEnd; | ||||||||||
|
|
||||||||||
| if (uint32_t(Index) < BatchGroupCountExtra) { | ||||||||||
| BatchGroupStart = (BatchGroupCountPerThread + 1) * Index; | ||||||||||
| BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1; | ||||||||||
| } else { | ||||||||||
| BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra; | ||||||||||
| BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| const size_t FilterCount = Parameters->FilterCount; | ||||||||||
| const size_t OutputSize = Parameters->OutputSize; | ||||||||||
| const size_t K = Parameters->K; | ||||||||||
|
|
||||||||||
| const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize; | ||||||||||
| const size_t OutputGroupSize = FilterCount * OutputSize; | ||||||||||
| const size_t FilterGroupSize = FilterCount * K; | ||||||||||
|
|
||||||||||
| // std::cout << "Address of WorkBlock->WorkingBuffer" << WorkBlock->WorkingBuffer << std::endl; | ||||||||||
|
|
||||||||||
| for(size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++){ | ||||||||||
| size_t group = bg % GroupCount; | ||||||||||
|
|
||||||||||
| const float* input = WorkBlock->Input + bg * InputGroupSize; | ||||||||||
| const float* filter = WorkBlock->Filter + group * FilterGroupSize; | ||||||||||
| float* output = WorkBlock->Output + bg * OutputGroupSize; | ||||||||||
| const float* bias = WorkBlock->Bias; | ||||||||||
| if(bias != nullptr){ bias += group * FilterCount; } | ||||||||||
| float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K; | ||||||||||
|
|
||||||||||
| MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize); | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| inline | ||||||||||
| bool | ||||||||||
| MlasConvTryMultithread( | ||||||||||
|
|
@@ -913,6 +965,32 @@ | |||||||||
|
|
||||||||||
| #endif | ||||||||||
|
|
||||||||||
| if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) { | ||||||||||
|
|
||||||||||
| const size_t BatchGroupCount = BatchCount * GroupCount; | ||||||||||
|
|
||||||||||
| int32_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); | ||||||||||
| // TargetThreadCount = 16; | ||||||||||
|
||||||||||
| // TargetThreadCount = 16; |
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release_ep_generic_interface
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release_ep_generic_interface
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_debug
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_debug
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release_vitisai
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release_vitisai
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release_xnnpack
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / build_x64_release_xnnpack
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / Windows GPU TensorRT CI Pipeline
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / Windows GPU TensorRT CI Pipeline
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / Windows GPU CUDA CI Pipeline
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / Windows GPU CUDA CI Pipeline
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / Windows GPU DML CI Pipeline
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / Windows GPU DML CI Pipeline
the following warning is treated as an error
Check failure on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo
the following warning is treated as an error
Check warning on line 978 in onnxruntime/core/mlas/lib/convolve.cpp
GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)
'initializing': conversion from 'ptrdiff_t' to 'int32_t', possible loss of data
Copilot
AI
Jul 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inconsistent indentation: line uses tab character while surrounding code uses spaces. This should be indented with spaces to match the existing code style.
| if(Parameters->BatchCount >1 || Parameters->GroupCount > 1){ | |
| if(Parameters->BatchCount >1 || Parameters->GroupCount > 1){ |
Copilot
AI
Jul 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space after 'if' and around operators. Should be formatted as 'if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {' to match C++ style conventions.
| if(Parameters->BatchCount >1 || Parameters->GroupCount > 1){ | |
| if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Debug output statement should be removed from production code. This commented-out debug line should be deleted.