diff --git a/mat_transpose/README.md b/mat_transpose/README.md index 4417c305..080b2e8f 100755 --- a/mat_transpose/README.md +++ b/mat_transpose/README.md @@ -9,13 +9,13 @@ - [X] mat_transpose_f32x4_col2row_kernel(float4向量化版本) - [X] mat_transpose_f32x4_row2col_kernel(float4向量化版本) - [X] mat_transpose_f32_diagnonal(对角轴应用于S=K) -- [ ] mat_transpose_f32x4_shared_col2row_kernel(float4向量化版本,共享内存) (施工中) -- [ ] mat_transpose_f32x4_shared_row2col_kernel(float4向量化版本,共享内存) (施工中) -- [ ] mat_transpose_f32x4_shared_bcf_col2row_kernel(float4向量化版本,共享内存,去bank conflict) (施工中) -- [ ] mat_transpose_f32x4_shared_bcf_row2col_kernel(float4向量化版本,共享内存,去bank conflict) (施工中) +- [X] mat_transpose_f32x4_shared_col2row_kernel(float4向量化版本,共享内存) +- [X] mat_transpose_f32x4_shared_row2col_kernel(float4向量化版本,共享内存) +- [X] mat_transpose_f32x4_shared_bcf_col2row_kernel(float4向量化版本,共享内存,去bank conflict) +- [X] mat_transpose_f32x4_shared_bcf_row2col_kernel(float4向量化版本,共享内存,去bank conflict) - [X] PyTorch bindings - +虽然是基础操作但是很适合练手,比矩阵乘法难度低一点但是可以其中可以用到的优化技巧都可以想办法用到这里来。 ## 测试 @@ -30,122 +30,158 @@ python3 mat_transpose.py ```bash ------------------------------------------------------------------------------------------------------------------------ S=1024, K=1024 - out_original: [1.22921503, 1.82269871, -0.72512561], validate False, time:0.00008798ms - out_f32_col2row: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.03252983ms - out_f32_row2col: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.02068520ms - out_f32_col2row(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.02265215ms - out_f32_row2col(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01682043ms - out_f32_diagnonal: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01259637ms - out_f32x4_col2row: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.03311539ms - out_f32x4_row2col: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01966453ms - out_f32x4_col2row(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01993465ms - out_f32x4_row2col(2d): [1.22921503, -0.72512561, 1.82269871], validate True , time:0.01886630ms - out_f32_th: [1.22921503, -0.72512561, 1.82269871], validate True , time:0.04084969ms + out_original: [0.2706067, 1.89055979, 0.62714416], validate False, time:0.00007796ms + out_f32_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03732634ms + out_f32_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03055906ms + out_f32_col2row(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02096868ms + out_f32_row2col(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03112197ms + out_f32_diagnonal: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02037907ms + out_f32x4_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.06107259ms + out_f32x4_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02692676ms + out_f32x4_col2row(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.03207874ms + out_f32x4_row2col(2d): [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01719213ms + out_f32x4_shared_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01326251ms + out_f32x4_shared_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.02352262ms + out_f32x4_shared_bcf_col2row: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01917195ms + out_f32x4_shared_bcf_row2col: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.01389265ms + out_f32_th: [0.2706067, 0.62714416, 1.89055979], validate True , time:0.05057526ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=1024, K=2048 - out_original: [1.68499732, 0.07425918, -0.02102743], validate False, time:0.00008655ms - out_f32_col2row: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05558133ms - out_f32_row2col: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03320456ms - out_f32_col2row(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02773643ms - out_f32_row2col(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02775192ms - out_f32x4_col2row: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05540919ms - out_f32x4_row2col: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03241920ms - out_f32x4_col2row(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.03086519ms - out_f32x4_row2col(2d): [1.68499732, -0.02102743, 0.07425918], validate True , time:0.02918243ms - out_f32_th: [1.68499732, -0.02102743, 0.07425918], validate True , time:0.05527997ms + out_original: [0.1013972, 0.10635406, 0.45091254], validate False, time:0.00007367ms + out_f32_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.11233115ms + out_f32_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05733228ms + out_f32_col2row(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.04851723ms + out_f32_row2col(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05224919ms + out_f32x4_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.10379744ms + out_f32x4_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05431175ms + out_f32x4_col2row(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.05774999ms + out_f32x4_row2col(2d): [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03115702ms + out_f32x4_shared_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03814983ms + out_f32x4_shared_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03473568ms + out_f32x4_shared_bcf_col2row: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03495407ms + out_f32x4_shared_bcf_row2col: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.03433728ms + out_f32_th: [0.1013972, 0.45091254, 0.10635406], validate True , time:0.08867288ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=1024, K=4096 - out_original: [-1.25576293, -1.05169642, 0.3411217], validate False, time:0.00008583ms - out_f32_col2row: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.10143566ms - out_f32_row2col: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05657411ms - out_f32_col2row(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04857659ms - out_f32_row2col(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04864573ms - out_f32x4_col2row: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.10081601ms - out_f32x4_row2col: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05694509ms - out_f32x4_col2row(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.05282903ms - out_f32x4_row2col(2d): [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.04989004ms - out_f32_th: [-1.25576293, 0.3411217, -1.05169642], validate True , time:0.08283234ms + out_original: [1.78550363, -1.60489535, -0.16560346], validate False, time:0.00007296ms + out_f32_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.19823909ms + out_f32_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.11195445ms + out_f32_col2row(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.09996772ms + out_f32_row2col(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.09864736ms + out_f32x4_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.19718719ms + out_f32x4_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.11092091ms + out_f32x4_col2row(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.10105634ms + out_f32x4_row2col(2d): [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06530714ms + out_f32x4_shared_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06287837ms + out_f32x4_shared_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.07055283ms + out_f32x4_shared_bcf_col2row: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06612253ms + out_f32x4_shared_bcf_row2col: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.06411195ms + out_f32_th: [1.78550363, -0.16560346, -1.60489535], validate True , time:0.17973542ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=2048, K=1024 - out_original: [-0.47698042, -0.33631387, -0.16439888], validate False, time:0.00008464ms - out_f32_col2row: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05773354ms - out_f32_row2col: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03202701ms - out_f32_col2row(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02529335ms - out_f32_row2col(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02532363ms - out_f32x4_col2row: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05734038ms - out_f32x4_row2col: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03257370ms - out_f32x4_col2row(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.03162861ms - out_f32x4_row2col(2d): [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.02920556ms - out_f32_th: [-0.47698042, -0.16439888, -0.33631387], validate True , time:0.05421734ms + out_original: [-0.96589017, -0.53940338, 1.51841831], validate False, time:0.00007153ms + out_f32_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.10408664ms + out_f32_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.05784106ms + out_f32_col2row(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04911971ms + out_f32_row2col(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04792857ms + out_f32x4_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.15571523ms + out_f32x4_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.07688594ms + out_f32x4_col2row(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.05413485ms + out_f32x4_row2col(2d): [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.03497577ms + out_f32x4_shared_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04818010ms + out_f32x4_shared_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.05148292ms + out_f32x4_shared_bcf_col2row: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.04849076ms + out_f32x4_shared_bcf_row2col: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.03030324ms + out_f32_th: [-0.96589017, 1.51841831, -0.53940338], validate True , time:0.09853792ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=2048, K=2048 - out_original: [-1.11287403, -0.41300669, 0.3849003], validate False, time:0.00008488ms - out_f32_col2row: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.10564256ms - out_f32_row2col: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05567479ms - out_f32_col2row(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.04766870ms - out_f32_row2col(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.04748774ms - out_f32_diagnonal: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.02389789ms - out_f32x4_col2row: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.10338593ms - out_f32x4_row2col: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05683303ms - out_f32x4_col2row(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05457044ms - out_f32x4_row2col(2d): [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.05046129ms - out_f32_th: [-1.11287403, 0.3849003, -0.41300669], validate True , time:0.08376551ms + out_original: [0.66138971, 0.43854904, -1.19618118], validate False, time:0.00007439ms + out_f32_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.24223709ms + out_f32_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.15707016ms + out_f32_col2row(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.09814286ms + out_f32_row2col(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.13747311ms + out_f32_diagnonal: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.08852434ms + out_f32x4_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.26274681ms + out_f32x4_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.12002778ms + out_f32x4_col2row(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.15025878ms + out_f32x4_row2col(2d): [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07008457ms + out_f32x4_shared_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07605863ms + out_f32x4_shared_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.09375811ms + out_f32x4_shared_bcf_col2row: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07940960ms + out_f32x4_shared_bcf_row2col: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.07159257ms + out_f32_th: [0.66138971, -1.19618118, 0.43854904], validate True , time:0.25392270ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=2048, K=4096 - out_original: [1.41623259, -0.94387418, 0.48682433], validate False, time:0.00008965ms - out_f32_col2row: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.19712996ms - out_f32_row2col: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.10346484ms - out_f32_col2row(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.08918452ms - out_f32_row2col(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.08975387ms - out_f32x4_col2row: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.19636393ms - out_f32x4_row2col: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.10541511ms - out_f32x4_col2row(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.09951663ms - out_f32x4_row2col(2d): [1.41623259, 0.48682433, -0.94387418], validate True , time:0.09154367ms - out_f32_th: [1.41623259, 0.48682433, -0.94387418], validate True , time:0.14955282ms + out_original: [0.21140628, 0.86610204, -0.61084032], validate False, time:0.00007534ms + out_f32_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.51111245ms + out_f32_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.29512668ms + out_f32_col2row(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.25763965ms + out_f32_row2col(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.25509524ms + out_f32x4_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.47753954ms + out_f32x4_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.27053690ms + out_f32x4_col2row(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.26033616ms + out_f32x4_row2col(2d): [0.21140628, -0.61084032, 0.86610204], validate True , time:0.16601658ms + out_f32x4_shared_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.14935517ms + out_f32x4_shared_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.17617536ms + out_f32x4_shared_bcf_col2row: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.14183927ms + out_f32x4_shared_bcf_row2col: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.17589092ms + out_f32_th: [0.21140628, -0.61084032, 0.86610204], validate True , time:0.43119144ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=4096, K=1024 - out_original: [-0.58965021, 0.14326878, -0.19429833], validate False, time:0.00008726ms - out_f32_col2row: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.10833144ms - out_f32_row2col: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05539703ms - out_f32_col2row(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.04996872ms - out_f32_row2col(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.04996324ms - out_f32x4_col2row: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.10815549ms - out_f32x4_row2col: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05626845ms - out_f32x4_col2row(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05652213ms - out_f32x4_row2col(2d): [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.05046964ms - out_f32_th: [-0.58965021, -0.19429833, 0.14326878], validate True , time:0.08028626ms + out_original: [-0.33594334, -0.13206008, 0.8452214], validate False, time:0.00007868ms + out_f32_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.26727128ms + out_f32_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.17777562ms + out_f32_col2row(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.09764647ms + out_f32_row2col(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.13735604ms + out_f32x4_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.25628328ms + out_f32x4_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.15057874ms + out_f32x4_col2row(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.12607431ms + out_f32x4_row2col(2d): [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.09281611ms + out_f32x4_shared_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.07143378ms + out_f32x4_shared_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.08804989ms + out_f32x4_shared_bcf_col2row: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.09320903ms + out_f32x4_shared_bcf_row2col: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.07376838ms + out_f32_th: [-0.33594334, 0.8452214, -0.13206008], validate True , time:0.25272131ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=4096, K=2048 - out_original: [-0.86244643, 0.61793995, -0.78971046], validate False, time:0.00008225ms - out_f32_col2row: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.20896244ms - out_f32_row2col: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10261559ms - out_f32_col2row(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09091687ms - out_f32_row2col(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09096813ms - out_f32x4_col2row: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.20603800ms - out_f32x4_row2col: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10330606ms - out_f32x4_col2row(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.10366035ms - out_f32x4_row2col(2d): [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.09077668ms - out_f32_th: [-0.86244643, -0.78971046, 0.61793995], validate True , time:0.14721990ms + out_original: [1.44601941, 1.46612203, -2.00953078], validate False, time:0.00007796ms + out_f32_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.51826644ms + out_f32_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.31751609ms + out_f32_col2row(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.26685858ms + out_f32_row2col(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.18520737ms + out_f32x4_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.29121876ms + out_f32x4_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.16650081ms + out_f32x4_col2row(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.14630580ms + out_f32x4_row2col(2d): [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09408069ms + out_f32x4_shared_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09475493ms + out_f32x4_shared_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09508491ms + out_f32x4_shared_bcf_col2row: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09532118ms + out_f32x4_shared_bcf_row2col: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.09467864ms + out_f32_th: [1.44601941, -2.00953078, 1.46612203], validate True , time:0.26716113ms ------------------------------------------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------------------------------------------ S=4096, K=4096 - out_original: [-1.41012037, 0.45044342, 0.36045134], validate False, time:0.00008726ms - out_f32_col2row: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.38568211ms - out_f32_row2col: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.41187572ms - out_f32_col2row(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.21557069ms - out_f32_row2col(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.21556497ms - out_f32_diagnonal: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.30571437ms - out_f32x4_col2row: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.38697243ms - out_f32x4_row2col: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.30080318ms - out_f32x4_col2row(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.23044729ms - out_f32x4_row2col(2d): [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.34491825ms - out_f32_th: [-1.41012037, 0.36045134, 0.45044342], validate True , time:0.56499386ms + out_original: [-1.07092094, -1.13755226, 0.99070781], validate False, time:0.00007606ms + out_f32_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.75331712ms + out_f32_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.52119255ms + out_f32_col2row(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.36621094ms + out_f32_row2col(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.36603284ms + out_f32_diagnonal: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.37416911ms + out_f32x4_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.96249247ms + out_f32x4_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.56916833ms + out_f32x4_col2row(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.48158646ms + out_f32x4_row2col(2d): [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.30216074ms + out_f32x4_shared_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.32637930ms + out_f32x4_shared_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.32455182ms + out_f32x4_shared_bcf_col2row: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.30707669ms + out_f32x4_shared_bcf_row2col: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.31853962ms + out_f32_th: [-1.07092094, 0.99070781, -1.13755226], validate True , time:0.91187215ms ------------------------------------------------------------------------------------------------------------------------ ``` diff --git a/mat_transpose/mat_transpose.cu b/mat_transpose/mat_transpose.cu index b700325e..9c76958a 100644 --- a/mat_transpose/mat_transpose.cu +++ b/mat_transpose/mat_transpose.cu @@ -12,6 +12,7 @@ #define WARP_SIZE 256 #define WARP_SIZE_S 16 +#define PAD 1 #define INT4(value) (reinterpret_cast(&(value))[0]) #define FLOAT4(value) (reinterpret_cast(&(value))[0]) #define HALF2(value) (reinterpret_cast(&(value))[0]) @@ -132,20 +133,146 @@ __global__ void mat_transpose_f32x4_row2col2d_kernel( } } -// TODO: may support shared memory optimize ? + __global__ void mat_transpose_f32x4_shared_col2row2d_kernel( - float *x, float *y, const int row, const int col) { - return; + float *x, float *y, const int row, const int col){ + const int global_x = blockIdx.x * blockDim.x + threadIdx.x; + const int global_y = blockIdx.y * blockDim.y + threadIdx.y; + const int local_x = threadIdx.x; + const int local_y = threadIdx.y; + __shared__ float tile[WARP_SIZE_S][WARP_SIZE_S * 4]; + if(global_x * 4 + 3 < col + 3 && global_y < row) { + // load value from x to shared memory + float4 x_val = reinterpret_cast(x)[global_y * col / 4 + global_x]; + FLOAT4(tile[local_y][local_x * 4]) = FLOAT4(x_val); + __syncthreads(); + float4 smem_val; + // load value from shared memory to y. + // add STRIDE to satisfied different block size. + constexpr int STRIDE = WARP_SIZE_S / 4; + smem_val.x = tile[(local_y % STRIDE) * 4 ][local_x * 4 + local_y / STRIDE]; + smem_val.y = tile[(local_y % STRIDE) * 4 + 1][local_x * 4 + local_y / STRIDE]; + smem_val.z = tile[(local_y % STRIDE) * 4 + 2][local_x * 4 + local_y / STRIDE]; + smem_val.w = tile[(local_y % STRIDE) * 4 + 3][local_x * 4 + local_y / STRIDE]; + //map index n*n to (n/4)*(n*4) + const int bid_y = blockIdx.y * blockDim.y; + const int out_y = global_x * 4 + local_y / STRIDE; + const int out_x = (local_y % STRIDE) * 4 + bid_y; + reinterpret_cast(y)[(out_y * row + out_x) / 4] = FLOAT4(smem_val); + } } + __global__ void mat_transpose_f32x4_shared_row2col2d_kernel( - float *x, float *y, const int row, const int col) { - return; + float *x, float *y, const int row, const int col){ + const int global_x = blockIdx.x * blockDim.x + threadIdx.x; + const int global_y = blockIdx.y * blockDim.y + threadIdx.y; + const int local_x = threadIdx.x; + const int local_y = threadIdx.y; + __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S]; + if(global_y * 4 < row && global_x < col) { + // load value from x to shared memory + float4 x_val; + x_val.x = x[(global_y * 4) * col + global_x]; + x_val.y = x[(global_y * 4 + 1) * col + global_x]; + x_val.z = x[(global_y * 4 + 2) * col + global_x]; + x_val.w = x[(global_y * 4 + 3) * col + global_x]; + tile[local_y * 4 ][local_x] = x_val.x; + tile[local_y * 4 + 1][local_x] = x_val.y; + tile[local_y * 4 + 2][local_x] = x_val.z; + tile[local_y * 4 + 3][local_x] = x_val.w; + __syncthreads(); + float4 smem_val; + // load value from shared memory to y. + // add STRIDE to satisfied different block size. + //map index n*n to (n/4)*(n*4) + constexpr int STRIDE = WARP_SIZE_S / 4; + smem_val.x = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4]; + smem_val.y = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4 + 1]; + smem_val.z = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4 + 2]; + smem_val.w = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4 + 3]; + const int bid_x = blockIdx.x * blockDim.x; + const int bid_y = blockIdx.y * blockDim.y; + + const int out_y = bid_x + (local_y % STRIDE) * 4; + const int out_x = bid_y * 4 + local_x * 4 + (local_y / STRIDE); + y[out_y * row + out_x] = smem_val.x; + y[(out_y + 1) * row + out_x] = smem_val.y; + y[(out_y + 2) * row + out_x] = smem_val.z; + y[(out_y + 3) * row + out_x] = smem_val.w; + } } + __global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel( - float *x, float *y, const int row, const int col) { - return; + float *x, float *y, const int row, const int col){ + const int global_x = blockIdx.x * blockDim.x + threadIdx.x; + const int global_y = blockIdx.y * blockDim.y + threadIdx.y; + const int local_x = threadIdx.x; + const int local_y = threadIdx.y; + __shared__ float tile[WARP_SIZE_S][WARP_SIZE_S * 4 + PAD]; + if(global_x * 4 + 3 < col + 3 && global_y < row) { + // load value from x to shared memory + float4 x_val = reinterpret_cast(x)[global_y * col / 4 + global_x]; + tile[local_y][local_x * 4 ] = x_val.x; + tile[local_y][local_x * 4 + 1] = x_val.y; + tile[local_y][local_x * 4 + 2] = x_val.z; + tile[local_y][local_x * 4 + 3] = x_val.w; + __syncthreads(); + float4 smem_val; + // load value from shared memory to y. + // add STRIDE to satisfied different block size. + constexpr int STRIDE = WARP_SIZE_S / 4; + smem_val.x = tile[(local_y % STRIDE) * 4 ][local_x * 4 + local_y / STRIDE]; + smem_val.y = tile[(local_y % STRIDE) * 4 + 1][local_x * 4 + local_y / STRIDE]; + smem_val.z = tile[(local_y % STRIDE) * 4 + 2][local_x * 4 + local_y / STRIDE]; + smem_val.w = tile[(local_y % STRIDE) * 4 + 3][local_x * 4 + local_y / STRIDE]; + //map index n*n to (n/4)*(n*4) + const int bid_y = blockIdx.y * blockDim.y; + const int out_y = global_x * 4 + local_y / STRIDE; + const int out_x = (local_y % STRIDE) * 4 + bid_y; + reinterpret_cast(y)[(out_y * row + out_x) / 4] = FLOAT4(smem_val); + } } +__global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel( + float *x, float *y, const int row, const int col){ + const int global_x = blockIdx.x * blockDim.x + threadIdx.x; + const int global_y = blockIdx.y * blockDim.y + threadIdx.y; + const int local_x = threadIdx.x; + const int local_y = threadIdx.y; + __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S + PAD]; + if(global_y * 4 < row && global_x < col) { + // load value from x to shared memory + float4 x_val; + x_val.x = x[(global_y * 4) * col + global_x]; + x_val.y = x[(global_y * 4 + 1) * col + global_x]; + x_val.z = x[(global_y * 4 + 2) * col + global_x]; + x_val.w = x[(global_y * 4 + 3) * col + global_x]; + tile[local_y * 4 ][local_x] = x_val.x; + tile[local_y * 4 + 1][local_x] = x_val.y; + tile[local_y * 4 + 2][local_x] = x_val.z; + tile[local_y * 4 + 3][local_x] = x_val.w; + __syncthreads(); + float4 smem_val; + // load value from shared memory to y. + // add STRIDE to satisfied different block size. + //map index n*n to (n/4)*(n*4) + constexpr int STRIDE = WARP_SIZE_S / 4; + smem_val.x = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4]; + smem_val.y = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4 + 1]; + smem_val.z = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4 + 2]; + smem_val.w = tile[local_x * 4 + local_y / STRIDE][(local_y % STRIDE) * 4 + 3]; + const int bid_x = blockIdx.x * blockDim.x; + const int bid_y = blockIdx.y * blockDim.y; + + const int out_y = bid_x + (local_y % STRIDE) * 4; + const int out_x = bid_y * 4 + local_x * 4 + (local_y / STRIDE); + y[out_y * row + out_x] = smem_val.x; + y[(out_y + 1) * row + out_x] = smem_val.y; + y[(out_y + 2) * row + out_x] = smem_val.z; + y[(out_y + 3) * row + out_x] = smem_val.w; + } +} +// TODO: may support double buffer pipeline mat transpose ? // TODO: may support fp16 mat transpose ? // --------------------- PyTorch bindings for custom kernel ----------------------- @@ -201,9 +328,16 @@ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_col2row, torch::kFloat32, float, 1, 4) TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_row2col, torch::kFloat32, float, 4, 1) // diagonal index method. TORCH_BINDING_MAT_TRANSPOSE2D(f32_diagonal, torch::kFloat32, float, 1, 1) -// TODO: may support shared memory optimize ? +// shared memory +TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_col2row, torch::kFloat32, float, 1, 4) +TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_row2col, torch::kFloat32, float, 4, 1) +// shared memory with bcf +TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float, 1, 4) +TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32, float, 4, 1) +// TODO: may support double buffer pipeline mat transpose ? // TODO: may support fp16 mat transpose ? + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // 1d index TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_col2row) @@ -217,4 +351,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_row2col2d) // diagonal index method. TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_diagonal2d) + // shared memory optimize + TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_col2row2d) + TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_row2col2d) + //shared memory optimize with bcf + TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_col2row2d) + TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_row2col2d) } diff --git a/mat_transpose/mat_transpose.py b/mat_transpose/mat_transpose.py index 4a6026cc..4132368e 100644 --- a/mat_transpose/mat_transpose.py +++ b/mat_transpose/mat_transpose.py @@ -60,32 +60,36 @@ def run_benchmark( real_t = f"{out.T.equal(x)}" out_val = out[:2, :2].flatten().detach().cpu().numpy().tolist()[:3] out_val = [round(v, 8) for v in out_val] - print(f"{out_info:>30}: {out_val}, validate {real_t:<5}, time:{mean_time:.8f}ms") + print(f"{out_info:>35}: {out_val}, validate {real_t:<5}, time:{mean_time:.8f}ms") if show_all: print(out) return out, mean_time -Ss = [1024, 2048, 4096] -Ks = [1024, 2048, 4096] -SKs = [(S, K) for S in Ss for K in Ks] +Ms = [1024, 2048, 4096] +Ns = [1024, 2048, 4096] +MNs = [(M, N) for M in Ms for N in Ns] copy_x = lambda x: x # show the three elements x[0][0], x[0][1], x[1][0] -for S, K in SKs: - print("-" * 120) - print(" " * 50 + f"S={S}, K={K}") - x = torch.randn((S, K)).cuda().float().contiguous() - y = torch.randn((K, S)).cuda().float().contiguous() +for M, N in MNs: + print("-" * 130) + print(" " * 55 + f"M={M}, N={N}") + x = torch.randn((M, N)).cuda().float().contiguous() + y = torch.randn((N, M)).cuda().float().contiguous() run_benchmark(partial(copy_x), x, "original") run_benchmark(lib.mat_transpose_f32_col2row, x, "f32_col2row", y) run_benchmark(lib.mat_transpose_f32_row2col, x, "f32_row2col", y) run_benchmark(lib.mat_transpose_f32_col2row2d, x, "f32_col2row(2d)", y) run_benchmark(lib.mat_transpose_f32_row2col2d, x, "f32_row2col(2d)", y) - if S == K: + if M == N: run_benchmark(lib.mat_transpose_f32_diagonal2d, x, "f32_diagnonal", y) run_benchmark(lib.mat_transpose_f32x4_col2row, x, "f32x4_col2row", y) run_benchmark(lib.mat_transpose_f32x4_row2col, x, "f32x4_row2col", y) run_benchmark(lib.mat_transpose_f32x4_col2row2d, x, "f32x4_col2row(2d)", y) run_benchmark(lib.mat_transpose_f32x4_row2col2d, x, "f32x4_row2col(2d)", y) + run_benchmark(lib.mat_transpose_f32x4_shared_col2row2d, x, "f32x4_shared_col2row(2d)", y) + run_benchmark(lib.mat_transpose_f32x4_shared_row2col2d, x, "f32x4_shared_row2col(2d)", y) + run_benchmark(lib.mat_transpose_f32x4_shared_bcf_col2row2d, x, "f32x4_shared_bcf_col2row(2d)", y) + run_benchmark(lib.mat_transpose_f32x4_shared_bcf_row2col2d, x, "f32x4_shared_bcf_row2col(2d)", y) run_benchmark(partial(torch.transpose_copy, dim0=0, dim1=1, out=y), x, "f32_th") - print("-" * 120) + print("-" * 130)