
Commit 2343f25

Commit message: upd
Parent: 29563c4

15 files changed (+185, −189 lines)

.github/workflows/nightly-release.yml

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install build twine wheel
-        pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b19
+        pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"
 
     - name: Build flashinfer-cubin wheel
       env:

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install build twine wheel
-        pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b19
+        pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"
 
     - name: Build flashinfer-cubin wheel
       run: |

csrc/batch_mla_config.jinja

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ using namespace flashinfer;
 #ifdef FLASHINFER_ENABLE_PROFILER
 #define ADDITIONAL_FUNC_PARAMS , Tensor profiler_buffer
 #define ADDITIONAL_PARAMS_SETTER \
-  params.profiler_buffer = static_cast<uint64_t*>(profiler_buffer->data);
+  params.profiler_buffer = static_cast<uint64_t*>(profiler_buffer.data_ptr());
 #else
 #define ADDITIONAL_FUNC_PARAMS
 #define ADDITIONAL_PARAMS_SETTER
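
The same accessor migration recurs throughout this commit: raw DLTensor-style field access (t->data, t->shape, t->ndim, t->dtype, t->strides) is replaced by TensorView method accessors. A minimal before/after sketch, assuming the tvm-ffi TensorView interface used in these files; describe_tensor is a hypothetical helper introduced here for illustration only:

// Sketch only: the accessor mapping applied across this commit.
// describe_tensor is not part of the codebase.
void describe_tensor(TensorView t) {
  auto shape = t.sizes();              // was: t.shape()
  int64_t last_dim = t.size(-1);       // was: t->shape[t->ndim - 1]
  int64_t batch_stride = t.stride(0);  // was: t->strides[0]
  auto dtype = t.dtype();              // was: t->dtype
  void* data = t.data_ptr();           // was: t->data
  (void)shape; (void)last_dim; (void)batch_stride; (void)dtype; (void)data;
}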

csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp

Lines changed: 4 additions & 4 deletions
@@ -149,7 +149,7 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal
   }
   CHECK_CONTIGUOUS(blockScale);
   CHECK_INPUT_TYPE(blockScale, dl_uint8);
-  auto blockScaleShape = blockScale.shape();
+  auto blockScaleShape = blockScale.sizes();
   TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3)
       << "Block Scale should be 2D or 3D tensor.";
   auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
@@ -204,7 +204,7 @@ void BlockScaleInterleaveReverse(TensorView const& blockScale, TensorView revers
   }
   CHECK_CONTIGUOUS(blockScale);
   CHECK_INPUT_TYPE(blockScale, dl_uint8);
-  auto blockScaleShape = blockScale.shape();
+  auto blockScaleShape = blockScale.sizes();
   TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3)
       << "Block Scale should be 2D or 3D tensor.";
   auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
@@ -251,8 +251,8 @@ void E2M1AndUFP8SFScaleToFloatV2(TensorView valueE2M1, TensorView scaleFP8SF,
                                  bool isSfSwizzledLayout = true) {
   CHECK_CPU_INPUT(valueE2M1, dl_uint8);
   CHECK_CPU_INPUT(scaleFP8SF, dl_uint8);
-  auto packedShape = valueE2M1.shape();
-  auto scaleShape = scaleFP8SF.shape();
+  auto packedShape = valueE2M1.sizes();
+  auto scaleShape = scaleFP8SF.sizes();
   TVM_FFI_ICHECK_EQ(packedShape.size(), 2) << "valueE2M1 should be 2D tensor.";
   TVM_FFI_ICHECK_EQ(scaleShape.size(), 1) << "scaleFP8SF should be 1D tensor.";

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 3 additions & 3 deletions
@@ -50,7 +50,7 @@ void fp4_quantize(TensorView self, Optional<TensorView> const& globalScale, Tens
     globalScalePtr = static_cast<float*>(globalScale.value().data_ptr());
   }
 
-  auto const& inputShape = self.shape();
+  auto const& inputShape = self.sizes();
   auto const& rank = inputShape.size();
 
   TVM_FFI_ICHECK_GE(rank, 2) << "Input should be >=2D tensor.";
@@ -140,7 +140,7 @@ void fp4_batched_quantize(TensorView self, Optional<TensorView> const& mask, Ten
   CHECK_INPUT_TYPE(globalScale, fp32_dtype);
   TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16";
 
-  auto const& inputShape = self.shape();
+  auto const& inputShape = self.sizes();
   auto const& rank = inputShape.size();
 
   TVM_FFI_ICHECK_EQ(rank, 3) << "Input should be 3D tensor.";
@@ -205,7 +205,7 @@ void silu_and_mul_nvfp4_batched_quantize(TensorView const& self, TensorView cons
   CHECK_INPUT_TYPE(globalScale, fp32_dtype);
   TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16";
 
-  auto const& inputShape = self.shape();
+  auto const& inputShape = self.sizes();
   auto const& rank = inputShape.size();
   auto const& mask_rank = mask.ndim();

csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp

Lines changed: 4 additions & 4 deletions
@@ -36,7 +36,7 @@ void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF
   TVM_FFI_ICHECK_EQ(alignment % SF_VEC_SIZE, 0)
       << "alignment must be divisible by SF_VEC_SIZE = 32";
 
-  auto const& inputShape = input.shape();
+  auto const& inputShape = input.sizes();
   auto const& rank = inputShape.size();
 
   TVM_FFI_ICHECK_GE(rank, 2) << "Input should be >=2D tensor.";
@@ -98,7 +98,7 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc
   int32_t const sf_vec_size = 32;
   auto fp32_dtype = DLDataType{kDLFloat, 32, 1};
   CHECK_INPUT_TYPE(x_fp32, fp32_dtype);
-  auto data_shape = x_fp32.shape();
+  auto data_shape = x_fp32.sizes();
   TVM_FFI_ICHECK_EQ(data_shape.size(), 2) << "x_fp32 should be 2D tensor.";
   int num_tokens = data_shape[0];
   int hidden_dim = data_shape[1];
@@ -145,8 +145,8 @@ void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf,
   int32_t const sf_vec_size = 32;
   CHECK_INPUT_TYPE(value_e4m3, dl_uint8);
   CHECK_INPUT_TYPE(scale_ue8m08sf, dl_uint8);
-  auto data_shape = value_e4m3.shape();
-  auto scale_shape = scale_ue8m08sf.shape();
+  auto data_shape = value_e4m3.sizes();
+  auto scale_shape = scale_ue8m08sf.sizes();
   TVM_FFI_ICHECK_EQ(data_shape.size(), 2) << "value_e4m3 should be 2D tensor.";
   TVM_FFI_ICHECK_EQ(scale_shape.size(), 1) << "scale_ue8m08sf should be 1D tensor.";
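
The quantize/dequantize entry points above share a small validation idiom: fetch the shape via .sizes(), check the rank with TVM_FFI_ICHECK_EQ, then index the shape like an array. A minimal sketch under those assumptions; check_2d_uint8 is a hypothetical name used only for illustration:

// Sketch of the rank/dtype validation idiom used in fp4Op.cpp, fp4Quantize.cpp
// and fp8Quantize.cpp above; check_2d_uint8 is illustration-only.
void check_2d_uint8(TensorView t) {
  CHECK_INPUT_TYPE(t, dl_uint8);     // dtype check, as in the hunks above
  auto shape = t.sizes();            // replaces the old t.shape()
  TVM_FFI_ICHECK_EQ(shape.size(), 2) << "expected a 2D tensor.";
  int64_t rows = shape[0];           // shape entries index like an array
  int64_t cols = shape[1];
  (void)rows; (void)cols;
}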

csrc/rope.cu

Lines changed: 68 additions & 67 deletions
@@ -285,29 +285,29 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
   CHECK_INPUT(pos_ids);
 
   // Extract dimensions from tensor shapes (flexible)
-  uint32_t rope_dim = q_rope_in->shape[q_rope_in->ndim - 1];
-  uint32_t no_rope_dim = q_nope_in->shape[q_nope_in->ndim - 1];
+  uint32_t rope_dim = q_rope_in.size(-1);
+  uint32_t no_rope_dim = q_nope_in.size(-1);
 
   // Validate rope and no_rope dimensions are consistent
-  TVM_FFI_ICHECK_EQ(k_rope_in->shape[k_rope_in->ndim - 1], rope_dim);
-  TVM_FFI_ICHECK_EQ(k_nope_in->shape[k_nope_in->ndim - 1], no_rope_dim);
-  TVM_FFI_ICHECK_EQ(q_rope_out->shape[q_rope_out->ndim - 1], rope_dim);
-  TVM_FFI_ICHECK_EQ(k_rope_out->shape[k_rope_out->ndim - 1], rope_dim);
-  TVM_FFI_ICHECK_EQ(q_nope_out->shape[q_nope_out->ndim - 1], no_rope_dim);
-  TVM_FFI_ICHECK_EQ(k_nope_out->shape[k_nope_out->ndim - 1], no_rope_dim);
-  TVM_FFI_ICHECK_EQ(q_rope_in->dtype, k_rope_in->dtype);
-  TVM_FFI_ICHECK_EQ(q_rope_in->dtype, q_nope_in->dtype);
-  TVM_FFI_ICHECK_EQ(q_rope_in->dtype, k_nope_in->dtype);
-  TVM_FFI_ICHECK_EQ(q_rope_out->dtype, k_rope_out->dtype);
-  TVM_FFI_ICHECK_EQ(q_rope_out->dtype, q_nope_out->dtype);
-  TVM_FFI_ICHECK_EQ(q_rope_out->dtype, k_nope_out->dtype);
+  TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim);
+  TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim);
+  TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim);
+  TVM_FFI_ICHECK_EQ(k_rope_out.size(-1), rope_dim);
+  TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim);
+  TVM_FFI_ICHECK_EQ(k_nope_out.size(-1), no_rope_dim);
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_out.dtype(), k_rope_out.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_out.dtype(), q_nope_out.dtype());
+  TVM_FFI_ICHECK_EQ(q_rope_out.dtype(), k_nope_out.dtype());
 
   // Validate supported input data types (float16 or bfloat16)
-  TVM_FFI_ICHECK(q_rope_in->dtype == dl_float16 || q_rope_in->dtype == dl_bfloat16)
+  TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16)
       << "Input dtype must be float16 or bfloat16";
 
   // Validate supported output quantization data types (float8_e4m3fn or float8_e5m2)
-  TVM_FFI_ICHECK(q_rope_out->dtype == dl_float8_e4m3fn || q_rope_out->dtype == dl_float8_e5m2)
+  TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2)
       << "Output dtype must be float8_e4m3fn or float8_e5m2";
 
   // Q tensors are always 3D: (nnz, num_qo_heads, rope_dim/no_rope_dim)
@@ -318,7 +318,7 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
 
   // K tensors can be 2D (MLA) or 3D (GQA/MHA)
   uint32_t num_kv_heads;
-  if (k_rope_in->ndim == 2) {
+  if (k_rope_in.ndim() == 2) {
     // MLA case: k_rope_in: (nnz, rope_dim), k_nope_in: (nnz, no_rope_dim)
     CHECK_DIM(2, k_rope_in);
     CHECK_DIM(2, k_nope_in);
@@ -331,81 +331,82 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
     CHECK_DIM(3, k_nope_in);
     CHECK_DIM(3, k_rope_out);
     CHECK_DIM(3, k_nope_out);
-    num_kv_heads = k_rope_in->shape[1];
+    num_kv_heads = k_rope_in.size(1);
   }
-  uint32_t nnz = q_rope_in->shape[0];
-  uint32_t num_qo_heads = q_rope_in->shape[1];
+  uint32_t nnz = q_rope_in.size(0);
+  uint32_t num_qo_heads = q_rope_in.size(1);
 
   // Validate consistent dimensions across all tensors
-  TVM_FFI_ICHECK_EQ(q_nope_in->shape[0], nnz);
-  TVM_FFI_ICHECK_EQ(k_rope_in->shape[0], nnz);
-  TVM_FFI_ICHECK_EQ(k_nope_in->shape[0], nnz);
-  TVM_FFI_ICHECK_EQ(q_rope_out->shape[0], nnz);
-  TVM_FFI_ICHECK_EQ(k_rope_out->shape[0], nnz);
-  TVM_FFI_ICHECK_EQ(q_nope_out->shape[0], nnz);
-  TVM_FFI_ICHECK_EQ(k_nope_out->shape[0], nnz);
+  TVM_FFI_ICHECK_EQ(q_nope_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(k_rope_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(k_nope_in.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(q_rope_out.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(k_rope_out.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(q_nope_out.size(0), nnz);
+  TVM_FFI_ICHECK_EQ(k_nope_out.size(0), nnz);
 
   // Validate Q tensor head dimensions are consistent
-  TVM_FFI_ICHECK_EQ(q_nope_in->shape[1], num_qo_heads);
-  TVM_FFI_ICHECK_EQ(q_rope_out->shape[1], num_qo_heads);
-  TVM_FFI_ICHECK_EQ(q_nope_out->shape[1], num_qo_heads);
+  TVM_FFI_ICHECK_EQ(q_nope_in.size(1), num_qo_heads);
+  TVM_FFI_ICHECK_EQ(q_rope_out.size(1), num_qo_heads);
+  TVM_FFI_ICHECK_EQ(q_nope_out.size(1), num_qo_heads);
 
   // Validate K tensor head dimensions (if 3D)
-  if (k_rope_in->ndim == 3) {
-    TVM_FFI_ICHECK_EQ(k_nope_in->shape[1], num_kv_heads);
-    TVM_FFI_ICHECK_EQ(k_rope_out->shape[1], num_kv_heads);
-    TVM_FFI_ICHECK_EQ(k_nope_out->shape[1], num_kv_heads);
+  if (k_rope_in.ndim() == 3) {
+    TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads);
+    TVM_FFI_ICHECK_EQ(k_rope_out.size(1), num_kv_heads);
+    TVM_FFI_ICHECK_EQ(k_nope_out.size(1), num_kv_heads);
   }
 
-  const uint32_t q_rope_in_stride_n = q_rope_in->strides[0];
-  const uint32_t q_rope_in_stride_h = q_rope_in->strides[1];
-  const uint32_t q_nope_in_stride_n = q_nope_in->strides[0];
-  const uint32_t q_nope_in_stride_h = q_nope_in->strides[1];
-  const uint32_t q_rope_out_stride_n = q_rope_out->strides[0];
-  const uint32_t q_rope_out_stride_h = q_rope_out->strides[1];
-  const uint32_t q_nope_out_stride_n = q_nope_out->strides[0];
-  const uint32_t q_nope_out_stride_h = q_nope_out->strides[1];
+  const uint32_t q_rope_in_stride_n = q_rope_in.stride(0);
+  const uint32_t q_rope_in_stride_h = q_rope_in.stride(1);
+  const uint32_t q_nope_in_stride_n = q_nope_in.stride(0);
+  const uint32_t q_nope_in_stride_h = q_nope_in.stride(1);
+  const uint32_t q_rope_out_stride_n = q_rope_out.stride(0);
+  const uint32_t q_rope_out_stride_h = q_rope_out.stride(1);
+  const uint32_t q_nope_out_stride_n = q_nope_out.stride(0);
+  const uint32_t q_nope_out_stride_h = q_nope_out.stride(1);
 
   // K tensor strides depend on dimensionality
   uint32_t k_rope_in_stride, k_nope_in_stride, k_rope_out_stride, k_nope_out_stride;
   uint32_t k_rope_in_stride_h, k_nope_in_stride_h, k_rope_out_stride_h, k_nope_out_stride_h;
 
-  if (k_rope_in->ndim == 2) {
+  if (k_rope_in.ndim() == 2) {
     // 2D K tensors (MLA): only have batch stride
-    k_rope_in_stride = k_rope_in->strides[0];
-    k_nope_in_stride = k_nope_in->strides[0];
-    k_rope_out_stride = k_rope_out->strides[0];
-    k_nope_out_stride = k_nope_out->strides[0];
+    k_rope_in_stride = k_rope_in.stride(0);
+    k_nope_in_stride = k_nope_in.stride(0);
+    k_rope_out_stride = k_rope_out.stride(0);
+    k_nope_out_stride = k_nope_out.stride(0);
     // For 2D tensors, head stride is the same as batch stride (shared K/V)
     k_rope_in_stride_h = k_rope_in_stride;
     k_nope_in_stride_h = k_nope_in_stride;
     k_rope_out_stride_h = k_rope_out_stride;
     k_nope_out_stride_h = k_nope_out_stride;
   } else {
     // 3D K tensors (GQA/MHA): have both batch and head strides
-    k_rope_in_stride = k_rope_in->strides[0];
-    k_rope_in_stride_h = k_rope_in->strides[1];
-    k_nope_in_stride = k_nope_in->strides[0];
-    k_nope_in_stride_h = k_nope_in->strides[1];
-    k_rope_out_stride = k_rope_out->strides[0];
-    k_rope_out_stride_h = k_rope_out->strides[1];
-    k_nope_out_stride = k_nope_out->strides[0];
-    k_nope_out_stride_h = k_nope_out->strides[1];
+    k_rope_in_stride = k_rope_in.stride(0);
+    k_rope_in_stride_h = k_rope_in.stride(1);
+    k_nope_in_stride = k_nope_in.stride(0);
+    k_nope_in_stride_h = k_nope_in.stride(1);
+    k_rope_out_stride = k_rope_out.stride(0);
+    k_rope_out_stride_h = k_rope_out.stride(1);
+    k_nope_out_stride = k_nope_out.stride(0);
+    k_nope_out_stride_h = k_nope_out.stride(1);
   }
 
-  cudaSetDevice(q_rope_in->device.device_id);
-  const cudaStream_t stream = get_stream(q_rope_in->device);
-  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in->dtype, c_type, [&] {
-    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out->dtype, c_quant_type, [&] {
-      return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids->dtype, c_idtype, [&] {
+  cudaSetDevice(q_rope_in.device().device_id);
+  const cudaStream_t stream = get_stream(q_rope_in.device());
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
+    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] {
+      return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] {
         cudaError_t status = RopeQuantize(
-            static_cast<c_type*>(q_rope_in->data), static_cast<c_type*>(k_rope_in->data),
-            static_cast<c_type*>(q_nope_in->data), static_cast<c_type*>(k_nope_in->data),
-            static_cast<c_quant_type*>(q_rope_out->data),
-            static_cast<c_quant_type*>(k_rope_out->data),
-            static_cast<c_quant_type*>(q_nope_out->data),
-            static_cast<c_quant_type*>(k_nope_out->data), static_cast<float*>(cos_sin_cache->data),
-            static_cast<c_idtype*>(pos_ids->data), nnz, num_qo_heads, num_kv_heads, rope_dim,
+            static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
+            static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
+            static_cast<c_quant_type*>(q_rope_out.data_ptr()),
+            static_cast<c_quant_type*>(k_rope_out.data_ptr()),
+            static_cast<c_quant_type*>(q_nope_out.data_ptr()),
+            static_cast<c_quant_type*>(k_nope_out.data_ptr()),
+            static_cast<float*>(cos_sin_cache.data_ptr()),
+            static_cast<c_idtype*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
             no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
             q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
             q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
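
rope_quantize accepts K tensors that are either 2D (MLA) or 3D (GQA/MHA), so in the 2D case the head stride falls back to the batch stride. A minimal sketch of that branch, assuming the same TensorView accessors; the KStrides struct and k_strides helper are hypothetical names for illustration only:

// Sketch of the 2D-vs-3D stride selection in rope_quantize above.
// KStrides / k_strides are illustration-only names.
struct KStrides { uint32_t batch; uint32_t head; };

KStrides k_strides(const TensorView& k) {
  KStrides s;
  s.batch = static_cast<uint32_t>(k.stride(0));  // leading (batch) stride in both layouts
  // 2D K tensors (MLA) have no separate head axis, so head stride == batch stride;
  // 3D K tensors (GQA/MHA) carry a separate head stride.
  s.head = (k.ndim() == 2) ? s.batch : static_cast<uint32_t>(k.stride(1));
  return s;
}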
