@@ -285,29 +285,29 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
285285 CHECK_INPUT (pos_ids);
286286
287287 // Extract dimensions from tensor shapes (flexible)
288- uint32_t rope_dim = q_rope_in-> shape [q_rope_in-> ndim - 1 ] ;
289- uint32_t no_rope_dim = q_nope_in-> shape [q_nope_in-> ndim - 1 ] ;
288+ uint32_t rope_dim = q_rope_in. size (- 1 ) ;
289+ uint32_t no_rope_dim = q_nope_in. size (- 1 ) ;
290290
291291 // Validate rope and no_rope dimensions are consistent
292- TVM_FFI_ICHECK_EQ (k_rope_in-> shape [k_rope_in-> ndim - 1 ] , rope_dim);
293- TVM_FFI_ICHECK_EQ (k_nope_in-> shape [k_nope_in-> ndim - 1 ] , no_rope_dim);
294- TVM_FFI_ICHECK_EQ (q_rope_out-> shape [q_rope_out-> ndim - 1 ] , rope_dim);
295- TVM_FFI_ICHECK_EQ (k_rope_out-> shape [k_rope_out-> ndim - 1 ] , rope_dim);
296- TVM_FFI_ICHECK_EQ (q_nope_out-> shape [q_nope_out-> ndim - 1 ] , no_rope_dim);
297- TVM_FFI_ICHECK_EQ (k_nope_out-> shape [k_nope_out-> ndim - 1 ] , no_rope_dim);
298- TVM_FFI_ICHECK_EQ (q_rope_in-> dtype , k_rope_in-> dtype );
299- TVM_FFI_ICHECK_EQ (q_rope_in-> dtype , q_nope_in-> dtype );
300- TVM_FFI_ICHECK_EQ (q_rope_in-> dtype , k_nope_in-> dtype );
301- TVM_FFI_ICHECK_EQ (q_rope_out-> dtype , k_rope_out-> dtype );
302- TVM_FFI_ICHECK_EQ (q_rope_out-> dtype , q_nope_out-> dtype );
303- TVM_FFI_ICHECK_EQ (q_rope_out-> dtype , k_nope_out-> dtype );
292+ TVM_FFI_ICHECK_EQ (k_rope_in. size (- 1 ) , rope_dim);
293+ TVM_FFI_ICHECK_EQ (k_nope_in. size (- 1 ) , no_rope_dim);
294+ TVM_FFI_ICHECK_EQ (q_rope_out. size (- 1 ) , rope_dim);
295+ TVM_FFI_ICHECK_EQ (k_rope_out. size (- 1 ) , rope_dim);
296+ TVM_FFI_ICHECK_EQ (q_nope_out. size (- 1 ) , no_rope_dim);
297+ TVM_FFI_ICHECK_EQ (k_nope_out. size (- 1 ) , no_rope_dim);
298+ TVM_FFI_ICHECK_EQ (q_rope_in. dtype () , k_rope_in. dtype () );
299+ TVM_FFI_ICHECK_EQ (q_rope_in. dtype () , q_nope_in. dtype () );
300+ TVM_FFI_ICHECK_EQ (q_rope_in. dtype () , k_nope_in. dtype () );
301+ TVM_FFI_ICHECK_EQ (q_rope_out. dtype () , k_rope_out. dtype () );
302+ TVM_FFI_ICHECK_EQ (q_rope_out. dtype () , q_nope_out. dtype () );
303+ TVM_FFI_ICHECK_EQ (q_rope_out. dtype () , k_nope_out. dtype () );
304304
305305 // Validate supported input data types (float16 or bfloat16)
306- TVM_FFI_ICHECK (q_rope_in-> dtype == dl_float16 || q_rope_in-> dtype == dl_bfloat16)
306+ TVM_FFI_ICHECK (q_rope_in. dtype () == dl_float16 || q_rope_in. dtype () == dl_bfloat16)
307307 << " Input dtype must be float16 or bfloat16" ;
308308
309309 // Validate supported output quantization data types (float8_e4m3fn or float8_e5m2)
310- TVM_FFI_ICHECK (q_rope_out-> dtype == dl_float8_e4m3fn || q_rope_out-> dtype == dl_float8_e5m2)
310+ TVM_FFI_ICHECK (q_rope_out. dtype () == dl_float8_e4m3fn || q_rope_out. dtype () == dl_float8_e5m2)
311311 << " Output dtype must be float8_e4m3fn or float8_e5m2" ;
312312
313313 // Q tensors are always 3D: (nnz, num_qo_heads, rope_dim/no_rope_dim)
@@ -318,7 +318,7 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
318318
319319 // K tensors can be 2D (MLA) or 3D (GQA/MHA)
320320 uint32_t num_kv_heads;
321- if (k_rope_in-> ndim == 2 ) {
321+ if (k_rope_in. ndim () == 2 ) {
322322 // MLA case: k_rope_in: (nnz, rope_dim), k_nope_in: (nnz, no_rope_dim)
323323 CHECK_DIM (2 , k_rope_in);
324324 CHECK_DIM (2 , k_nope_in);
@@ -331,81 +331,82 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
331331 CHECK_DIM (3 , k_nope_in);
332332 CHECK_DIM (3 , k_rope_out);
333333 CHECK_DIM (3 , k_nope_out);
334- num_kv_heads = k_rope_in-> shape [ 1 ] ;
334+ num_kv_heads = k_rope_in. size ( 1 ) ;
335335 }
336- uint32_t nnz = q_rope_in-> shape [ 0 ] ;
337- uint32_t num_qo_heads = q_rope_in-> shape [ 1 ] ;
336+ uint32_t nnz = q_rope_in. size ( 0 ) ;
337+ uint32_t num_qo_heads = q_rope_in. size ( 1 ) ;
338338
339339 // Validate consistent dimensions across all tensors
340- TVM_FFI_ICHECK_EQ (q_nope_in-> shape [ 0 ] , nnz);
341- TVM_FFI_ICHECK_EQ (k_rope_in-> shape [ 0 ] , nnz);
342- TVM_FFI_ICHECK_EQ (k_nope_in-> shape [ 0 ] , nnz);
343- TVM_FFI_ICHECK_EQ (q_rope_out-> shape [ 0 ] , nnz);
344- TVM_FFI_ICHECK_EQ (k_rope_out-> shape [ 0 ] , nnz);
345- TVM_FFI_ICHECK_EQ (q_nope_out-> shape [ 0 ] , nnz);
346- TVM_FFI_ICHECK_EQ (k_nope_out-> shape [ 0 ] , nnz);
340+ TVM_FFI_ICHECK_EQ (q_nope_in. size ( 0 ) , nnz);
341+ TVM_FFI_ICHECK_EQ (k_rope_in. size ( 0 ) , nnz);
342+ TVM_FFI_ICHECK_EQ (k_nope_in. size ( 0 ) , nnz);
343+ TVM_FFI_ICHECK_EQ (q_rope_out. size ( 0 ) , nnz);
344+ TVM_FFI_ICHECK_EQ (k_rope_out. size ( 0 ) , nnz);
345+ TVM_FFI_ICHECK_EQ (q_nope_out. size ( 0 ) , nnz);
346+ TVM_FFI_ICHECK_EQ (k_nope_out. size ( 0 ) , nnz);
347347
348348 // Validate Q tensor head dimensions are consistent
349- TVM_FFI_ICHECK_EQ (q_nope_in-> shape [ 1 ] , num_qo_heads);
350- TVM_FFI_ICHECK_EQ (q_rope_out-> shape [ 1 ] , num_qo_heads);
351- TVM_FFI_ICHECK_EQ (q_nope_out-> shape [ 1 ] , num_qo_heads);
349+ TVM_FFI_ICHECK_EQ (q_nope_in. size ( 1 ) , num_qo_heads);
350+ TVM_FFI_ICHECK_EQ (q_rope_out. size ( 1 ) , num_qo_heads);
351+ TVM_FFI_ICHECK_EQ (q_nope_out. size ( 1 ) , num_qo_heads);
352352
353353 // Validate K tensor head dimensions (if 3D)
354- if (k_rope_in-> ndim == 3 ) {
355- TVM_FFI_ICHECK_EQ (k_nope_in-> shape [ 1 ] , num_kv_heads);
356- TVM_FFI_ICHECK_EQ (k_rope_out-> shape [ 1 ] , num_kv_heads);
357- TVM_FFI_ICHECK_EQ (k_nope_out-> shape [ 1 ] , num_kv_heads);
354+ if (k_rope_in. ndim () == 3 ) {
355+ TVM_FFI_ICHECK_EQ (k_nope_in. size ( 1 ) , num_kv_heads);
356+ TVM_FFI_ICHECK_EQ (k_rope_out. size ( 1 ) , num_kv_heads);
357+ TVM_FFI_ICHECK_EQ (k_nope_out. size ( 1 ) , num_kv_heads);
358358 }
359359
360- const uint32_t q_rope_in_stride_n = q_rope_in-> strides [ 0 ] ;
361- const uint32_t q_rope_in_stride_h = q_rope_in-> strides [ 1 ] ;
362- const uint32_t q_nope_in_stride_n = q_nope_in-> strides [ 0 ] ;
363- const uint32_t q_nope_in_stride_h = q_nope_in-> strides [ 1 ] ;
364- const uint32_t q_rope_out_stride_n = q_rope_out-> strides [ 0 ] ;
365- const uint32_t q_rope_out_stride_h = q_rope_out-> strides [ 1 ] ;
366- const uint32_t q_nope_out_stride_n = q_nope_out-> strides [ 0 ] ;
367- const uint32_t q_nope_out_stride_h = q_nope_out-> strides [ 1 ] ;
360+ const uint32_t q_rope_in_stride_n = q_rope_in. stride ( 0 ) ;
361+ const uint32_t q_rope_in_stride_h = q_rope_in. stride ( 1 ) ;
362+ const uint32_t q_nope_in_stride_n = q_nope_in. stride ( 0 ) ;
363+ const uint32_t q_nope_in_stride_h = q_nope_in. stride ( 1 ) ;
364+ const uint32_t q_rope_out_stride_n = q_rope_out. stride ( 0 ) ;
365+ const uint32_t q_rope_out_stride_h = q_rope_out. stride ( 1 ) ;
366+ const uint32_t q_nope_out_stride_n = q_nope_out. stride ( 0 ) ;
367+ const uint32_t q_nope_out_stride_h = q_nope_out. stride ( 1 ) ;
368368
369369 // K tensor strides depend on dimensionality
370370 uint32_t k_rope_in_stride, k_nope_in_stride, k_rope_out_stride, k_nope_out_stride;
371371 uint32_t k_rope_in_stride_h, k_nope_in_stride_h, k_rope_out_stride_h, k_nope_out_stride_h;
372372
373- if (k_rope_in-> ndim == 2 ) {
373+ if (k_rope_in. ndim () == 2 ) {
374374 // 2D K tensors (MLA): only have batch stride
375- k_rope_in_stride = k_rope_in-> strides [ 0 ] ;
376- k_nope_in_stride = k_nope_in-> strides [ 0 ] ;
377- k_rope_out_stride = k_rope_out-> strides [ 0 ] ;
378- k_nope_out_stride = k_nope_out-> strides [ 0 ] ;
375+ k_rope_in_stride = k_rope_in. stride ( 0 ) ;
376+ k_nope_in_stride = k_nope_in. stride ( 0 ) ;
377+ k_rope_out_stride = k_rope_out. stride ( 0 ) ;
378+ k_nope_out_stride = k_nope_out. stride ( 0 ) ;
379379 // For 2D tensors, head stride is the same as batch stride (shared K/V)
380380 k_rope_in_stride_h = k_rope_in_stride;
381381 k_nope_in_stride_h = k_nope_in_stride;
382382 k_rope_out_stride_h = k_rope_out_stride;
383383 k_nope_out_stride_h = k_nope_out_stride;
384384 } else {
385385 // 3D K tensors (GQA/MHA): have both batch and head strides
386- k_rope_in_stride = k_rope_in-> strides [ 0 ] ;
387- k_rope_in_stride_h = k_rope_in-> strides [ 1 ] ;
388- k_nope_in_stride = k_nope_in-> strides [ 0 ] ;
389- k_nope_in_stride_h = k_nope_in-> strides [ 1 ] ;
390- k_rope_out_stride = k_rope_out-> strides [ 0 ] ;
391- k_rope_out_stride_h = k_rope_out-> strides [ 1 ] ;
392- k_nope_out_stride = k_nope_out-> strides [ 0 ] ;
393- k_nope_out_stride_h = k_nope_out-> strides [ 1 ] ;
386+ k_rope_in_stride = k_rope_in. stride ( 0 ) ;
387+ k_rope_in_stride_h = k_rope_in. stride ( 1 ) ;
388+ k_nope_in_stride = k_nope_in. stride ( 0 ) ;
389+ k_nope_in_stride_h = k_nope_in. stride ( 1 ) ;
390+ k_rope_out_stride = k_rope_out. stride ( 0 ) ;
391+ k_rope_out_stride_h = k_rope_out. stride ( 1 ) ;
392+ k_nope_out_stride = k_nope_out. stride ( 0 ) ;
393+ k_nope_out_stride_h = k_nope_out. stride ( 1 ) ;
394394 }
395395
396- cudaSetDevice (q_rope_in-> device .device_id );
397- const cudaStream_t stream = get_stream (q_rope_in-> device );
398- DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16 (q_rope_in-> dtype , c_type, [&] {
399- return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 (q_rope_out-> dtype , c_quant_type, [&] {
400- return DISPATCH_DLPACK_IDTYPE_TO_CTYPE (pos_ids-> dtype , c_idtype, [&] {
396+ cudaSetDevice (q_rope_in. device () .device_id );
397+ const cudaStream_t stream = get_stream (q_rope_in. device () );
398+ DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16 (q_rope_in. dtype () , c_type, [&] {
399+ return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8 (q_rope_out. dtype () , c_quant_type, [&] {
400+ return DISPATCH_DLPACK_IDTYPE_TO_CTYPE (pos_ids. dtype () , c_idtype, [&] {
401401 cudaError_t status = RopeQuantize (
402- static_cast <c_type*>(q_rope_in->data ), static_cast <c_type*>(k_rope_in->data ),
403- static_cast <c_type*>(q_nope_in->data ), static_cast <c_type*>(k_nope_in->data ),
404- static_cast <c_quant_type*>(q_rope_out->data ),
405- static_cast <c_quant_type*>(k_rope_out->data ),
406- static_cast <c_quant_type*>(q_nope_out->data ),
407- static_cast <c_quant_type*>(k_nope_out->data ), static_cast <float *>(cos_sin_cache->data ),
408- static_cast <c_idtype*>(pos_ids->data ), nnz, num_qo_heads, num_kv_heads, rope_dim,
402+ static_cast <c_type*>(q_rope_in.data_ptr ()), static_cast <c_type*>(k_rope_in.data_ptr ()),
403+ static_cast <c_type*>(q_nope_in.data_ptr ()), static_cast <c_type*>(k_nope_in.data_ptr ()),
404+ static_cast <c_quant_type*>(q_rope_out.data_ptr ()),
405+ static_cast <c_quant_type*>(k_rope_out.data_ptr ()),
406+ static_cast <c_quant_type*>(q_nope_out.data_ptr ()),
407+ static_cast <c_quant_type*>(k_nope_out.data_ptr ()),
408+ static_cast <float *>(cos_sin_cache.data_ptr ()),
409+ static_cast <c_idtype*>(pos_ids.data_ptr ()), nnz, num_qo_heads, num_kv_heads, rope_dim,
409410 no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
410411 q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
411412 q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
0 commit comments