@@ -1745,8 +1745,9 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
+template <typename T>
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const T * src0_as_f16, const T * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
         int64_t ne12, int64_t ne13,
         int64_t ne23,
@@ -1774,7 +1775,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
 
     // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
     // As long as dst is contiguous this does not matter though.
@@ -1788,64 +1789,153 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    const half * src0_f16 = (const half *) src0->data;
-    float * dst_ddf = (float *) dst->data;
+    const ggml_type src0_type = src0->type;
+    const bool use_f32_path = src0_type == GGML_TYPE_F32;
+    const bool use_bf16_path = src0_type == GGML_TYPE_BF16;
 
-    const half * src1_f16 = (const half *) src1->data;
+    float * dst_ddf = (float *) dst->data;
     const size_t ts_src1 = ggml_type_size(src1->type);
     GGML_ASSERT(nb10 == ts_src1);
     int64_t s11 = nb11 / ts_src1;
     int64_t s12 = nb12 / ts_src1;
    int64_t s13 = nb13 / ts_src1;
+
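+    // Typed views of src0/src1 for the three supported paths (F16, BF16, F32); only one pair is populated.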
+    const half * src0_f16 = nullptr;
+    const half * src1_f16 = nullptr;
+    const nv_bfloat16 * src0_bf16 = nullptr;
+    const nv_bfloat16 * src1_bf16 = nullptr;
+    const float * src0_f32 = nullptr;
+    const float * src1_f32 = nullptr;
+
+    ggml_cuda_pool_alloc<half> src0_f16_alloc(ctx.pool());
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src0_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> src1_bf16_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src0_f32_alloc(ctx.pool());
+    ggml_cuda_pool_alloc<float> src1_f32_alloc(ctx.pool());
+
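+    // src0 is used in place; src1 is converted on the fly to src0's dtype whenever the types differ.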
+    if (use_f32_path) {
+        // F32 path
+        src0_f32 = (const float *) src0->data;
+        if (src1->type == GGML_TYPE_F32) {
+            src1_f32 = (const float *) src1->data;
+        } else {
+            // Convert src1 to F32
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f32_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
 
-    // convert src1 to fp16
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
+            to_fp32_cuda((const void *)((const char *)src1->data), src1_f32_alloc.get(), ne_src1, main_stream);
+            src1_f32 = src1_f32_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else if (use_bf16_path) {
+        // BF16 path
+        src0_bf16 = (const nv_bfloat16 *) src0->data;
+        if (src1->type == GGML_TYPE_BF16) {
+            src1_bf16 = (const nv_bfloat16 *) src1->data;
+        } else {
+            // Convert src1 to BF16
+            const to_bf16_nc_cuda_t to_bf16_cuda = ggml_get_to_bf16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_bf16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
 
-        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            to_bf16_cuda((const void *)((const char *)src1->data), src1_bf16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_bf16 = src1_bf16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
+    } else {
+        // F16 path (default)
+        src0_f16 = (const half *) src0->data;
+        if (src1->type == GGML_TYPE_F16) {
+            src1_f16 = (const half *) src1->data;
+        } else {
+            // Convert src1 to F16
+            const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
+            const int64_t ne_src1 = ggml_nelements(src1);
+            src1_f16_alloc.alloc(ne_src1);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        src1_f16 = src1_f16_alloc.get();
-        s11 = ne10;
-        s12 = ne11*s11;
-        s13 = ne12*s12;
+            to_fp16_cuda((const void *)((const char *)src1->data), src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            src1_f16 = src1_f16_alloc.get();
+            s11 = ne10;
+            s12 = ne11*s11;
+            s13 = ne12*s12;
+        }
     }
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool());
     char * dst_t;
 
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t cu_data_type = CUDA_R_16F;
+    cublasComputeType_t cu_compute_type;
+    cudaDataType_t cu_data_type;
+    cudaDataType_t cu_data_type_a;
+    cudaDataType_t cu_data_type_b;
+
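+    // BF16 has no dedicated cuBLAS compute type, so both the F32 and BF16 paths accumulate in FP32;
+    // the default F16 path keeps F16 accumulation as before.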
+    if (use_f32_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_32F;
+        cu_data_type_a = CUDA_R_32F;
+        cu_data_type_b = CUDA_R_32F;
+    } else if (use_bf16_path) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type = CUDA_R_16BF;
+        cu_data_type_a = CUDA_R_16BF;
+        cu_data_type_b = CUDA_R_16BF;
+    } else {
+        cu_compute_type = CUBLAS_COMPUTE_16F;
+        cu_data_type = CUDA_R_16F;
+        cu_data_type_a = CUDA_R_16F;
+        cu_data_type_b = CUDA_R_16F;
+    }
 
-    // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const half alpha_f16 = 1.0f;
     const half beta_f16 = 0.0f;
-
     const float alpha_f32 = 1.0f;
     const float beta_f32 = 0.0f;
 
-    const void * alpha = &alpha_f16;
-    const void * beta = &beta_f16;
+    const void * alpha;
+    const void * beta;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+    if (use_f32_path || cu_compute_type == CUBLAS_COMPUTE_32F) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else if (use_bf16_path) {
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    } else {
+        alpha = &alpha_f16;
+        beta = &beta_f16;
+    }
 
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
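+    // With GGML_PREC_DEFAULT the GEMM writes dst in the source dtype (rescaling the dst byte strides below)
+    // and the result is converted back to F32 at the end; the F32 path writes dst_ddf directly.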
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            dst_t = (char *) dst_ddf; // Direct F32 output
+        } else if (use_bf16_path) {
+            dst_t = (char *) dst_bf16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(nv_bfloat16);
+            nbd3 /= sizeof(float) / sizeof(nv_bfloat16);
+        } else {
+            dst_t = (char *) dst_f16.alloc(ne_dst);
+            nbd2 /= sizeof(float) / sizeof(half);
+            nbd3 /= sizeof(float) / sizeof(half);
+        }
     } else {
         dst_t = (char *) dst_ddf;
-
         cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-
+        cu_data_type = CUDA_R_32F;
         alpha = &alpha_f32;
-        beta = &beta_f32;
+        beta = &beta_f32;
     }
 
     int id = ggml_cuda_get_device();
@@ -1886,11 +1976,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
+        const void * src0_ptr = use_f32_path ? (const void *)src0_f32 :
+                                use_bf16_path ? (const void *)src0_bf16 : (const void *)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void *)src1_f32 :
+                                use_bf16_path ? (const void *)src1_bf16 : (const void *)src1_f16;
+
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
-                       src1_f16, CUDA_R_16F, s11, s12, // strideB
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+                       src1_ptr, cu_data_type_b, s11, s12, // strideB
                 beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
@@ -1902,34 +1997,74 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
 
+        const void * src0_ptr = use_f32_path ? (const void *)src0_f32 :
+                                use_bf16_path ? (const void *)src0_bf16 : (const void *)src0_f16;
+        const void * src1_ptr = use_f32_path ? (const void *)src1_f32 :
+                                use_bf16_path ? (const void *)src1_bf16 : (const void *)src1_f16;
+
+        size_t src1_stride_size = use_f32_path ? sizeof(float) :
+                                  use_bf16_path ? sizeof(nv_bfloat16) : sizeof(half);
+
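+        // Fill the per-matrix pointer tables on the device; the kernel template parameter matches the dtype of the active path.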
         dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
-                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
-                nbd2, nbd3,
-                r2, r3);
+        if (use_f32_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                (const float *)src0_ptr, (const float *)src1_ptr, dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                nbd2, nbd3,
+                r2, r3);
+        } else if (use_bf16_path) {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                (const nv_bfloat16*)src0_ptr, (const nv_bfloat16*)src1_ptr, dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                nbd2, nbd3,
+                r2, r3);
+        } else {
+            k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                (const half*)src0_ptr, (const half*)src1_ptr, dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+                nbd2, nbd3,
+                r2, r3);
+        }
+
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                 beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
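+    // Convert the F16/BF16 GEMM output back to F32; the F32 path has already written dst_ddf directly.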
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        if (use_f32_path) {
+            // already in f32
+        } else if (use_bf16_path && cu_data_type == CUDA_R_16BF) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+            to_fp32_cuda(dst_bf16.get(), dst_ddf, ne_dst, main_stream);
+        } else if (cu_data_type == CUDA_R_16F) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+        }
     }
 }
 
@@ -1989,8 +2124,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-        !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32)
+        && (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)
+        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {