Open
143 commits
d0a53b5
adopt reference implementation from sglang
ishovkun Jan 29, 2026
320a72c
Extract create_test_inputs to shared test_utils module
ishovkun Jan 29, 2026
4022f10
Rename test to reflect that it's a single-token test file
ishovkun Jan 29, 2026
a8bc286
Add multi-token support to the interface of selective_state_update
ishovkun Jan 29, 2026
2e70ea4
Refactor selective_state_update: add validation helpers and update param
ishovkun Jan 29, 2026
295ae56
Non-contiguous state
ishovkun Jan 29, 2026
5541624
Simplify code for template dispatching
ishovkun Jan 29, 2026
ab33cc1
Refactor dispatch logic in selective_state_update.cuh
ishovkun Jan 29, 2026
26271a9
Refactor pointer alignment checking away from the logic.
ishovkun Jan 29, 2026
f3f02f5
Support int32 and int64 state_batch_indices in selective_state_update
ishovkun Jan 29, 2026
1cb4ac7
Refactor Mamba selective state update kernel dispatch and add dtype
ishovkun Jan 30, 2026
3265bd5
Merge branch 'flashinfer-ai:main' into main
ishovkun Jan 30, 2026
9d6d35c
Fix simple stp kernel to only write state if a flag is provided
ishovkun Jan 30, 2026
5b5756d
Fix Triton kernel intermediate state caching to match CUDA behavior
ishovkun Jan 30, 2026
e3f751e
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Jan 31, 2026
fb693d0
Add Mamba2 SSD chunk scan test and reorganize Triton refs
ishovkun Feb 3, 2026
0ce5d47
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Feb 17, 2026
304fd59
Enable .jinja templates for mamba
ishovkun Feb 17, 2026
329bfd0
Remove SM100 module, unify SM90+ selective state update handling
ishovkun Feb 17, 2026
f464097
Add algorithm selection to selective_state_update kernels
ishovkun Feb 18, 2026
c65670c
Fix include order: config.inc before header in selective_state_update…
ishovkun Feb 18, 2026
44b6c25
Parallelize consumer warp loads in vertical SSU kernel
ishovkun Feb 18, 2026
eff403c
Reduce test combinations in SSU tests to base + independent deviations
ishovkun Feb 18, 2026
afc7c6a
Add algorithm parameter to selective_state_update tests
ishovkun Feb 19, 2026
74accb0
Merge branch 'flashinfer-ai:main' into main
ishovkun Feb 19, 2026
1d42007
Update selective_state_update instantiations to include SSUAlgorithm
ishovkun Feb 19, 2026
61d88bd
Clarify algorithm selection docstring in selective_state_update
ishovkun Feb 19, 2026
ead4943
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Feb 19, 2026
6f6a3d7
Remove chunk scan combined kernels as they are irrelevant to this PR
ishovkun Feb 19, 2026
de96dd5
Remove ssd_chunk_state.py Triton reference implementation (irrelevant to
ishovkun Feb 19, 2026
4c30f07
Delete test_utils.py
ishovkun Feb 19, 2026
1f1c2f4
Suppress mypy false positive for gen_selective_state_update calls
ishovkun Feb 19, 2026
157ecb5
Move Triton reference kernel to triton_reference subdir and update
ishovkun Feb 19, 2026
f32b63b
mark an unused variable with "_" in a test
ishovkun Feb 19, 2026
2656202
rename an unused test variable to _state_ref
ishovkun Feb 19, 2026
5580d28
Refactor Triton reference import for selective_state_update
ishovkun Feb 19, 2026
8738964
Add int16 state quantization with block scaling to
ishovkun Feb 19, 2026
02db096
Add int16 quantized state support to selective_state_update
ishovkun Feb 20, 2026
58f56cd
Fixes aot compilation of the gdn_prefill_sm90 module
ishovkun Feb 20, 2026
d4e33de
Merge branch 'main' into ssu_int16
ishovkun Feb 20, 2026
5d8184e
Substantially reduce the number of SSU aot compilation units. Limited to
ishovkun Feb 20, 2026
9775391
Merge branch 'main' into ssu_int16
ishovkun Feb 20, 2026
7f1173f
Add int16 support for block scaling in selective_state_update kernel
ishovkun Feb 20, 2026
35cc7ba
Add int16 block scaling support to selective_state_update MTP
ishovkun Feb 20, 2026
6cf61b7
Fix rNewState array size calculation for scaleState flag
ishovkun Feb 20, 2026
e9ab619
Refactor selective_state_update to use state_scale dtype
ishovkun Feb 23, 2026
60b627e
Add Philox-4x32 PRNG matching Triton tl.randint and tests
ishovkun Feb 24, 2026
b873d10
Refactor philox_randint to template and add rounding tests
ishovkun Feb 24, 2026
3292662
Stochastic rounding support for fp16 state update (plumbing)
ishovkun Feb 25, 2026
b206a5f
Implement stochastic rounding for fp16 state in selective_state_update
ishovkun Feb 25, 2026
c70efbd
Optimize Philox PRNG usage in selective_state_update kernel
ishovkun Feb 26, 2026
181d80d
Fix Philox random offset calculation for state updates
ishovkun Feb 26, 2026
fd1af7c
Remove .plans directory from .gitignore
ishovkun Feb 26, 2026
ff8dfde
Merge branch 'ssu_int16': int16 block-scaled state and stochastic rou…
ishovkun Feb 26, 2026
60bbb5d
Merge remote-tracking branch 'upstream/main'
ishovkun Feb 26, 2026
c01eced
Replace asserts with if checks in the python wrapper
ishovkun Feb 27, 2026
0bf77aa
Remove redundant dtype check for state_batch_indices and
ishovkun Feb 27, 2026
deb48a8
Use tuples instead of lists for parameter sets in tests
ishovkun Feb 27, 2026
e1d9dc3
Fix selective_state_update argument order for state_scale_dtype
ishovkun Feb 27, 2026
b30db63
Replace float pointer casts with __float_as_uint in conversion kernels
ishovkun Feb 27, 2026
317f6bb
Handle zero max value in state scaling calculations
ishovkun Feb 27, 2026
852b9a2
Add static_assert for fp16 state in SR branch to check that an edge case
ishovkun Feb 27, 2026
2ca355e
`if not philox_rounds > 0:` → `if philox_rounds <= 0:` — same semantics,
ishovkun Feb 27, 2026
8bd1779
Dummy algorithm support to MTP selective_state_update -- only
ishovkun Feb 27, 2026
69d4a3c
Sloppy first prototype without real TMA
ishovkun Feb 27, 2026
42b2d81
Refactor selective_state_update_mtp to use TMA tensor descriptors (1.99x
ishovkun Feb 28, 2026
647160a
Add vertical algorithm tests for intermediate states and update param
ishovkun Feb 28, 2026
c18548f
Refactor vertical MTP kernel to process full DIM×DSTATE tiles per head
ishovkun Feb 28, 2026
04e1ebc
Rename role_compute to role_update_state in vertical MTP kernel
ishovkun Mar 2, 2026
1a69b43
Add vertical kernel dtype check and fix state indexing
ishovkun Mar 9, 2026
a7960c4
Eliminate store warp in vertical MTP kernel and write states directly to
ishovkun Mar 9, 2026
d15d909
Refactor vertical MTP kernel to process 2 groups per CTA
ishovkun Mar 10, 2026
1b2b306
Add benchmark script for selective_state_update MTP mode
ishovkun Mar 10, 2026
a3a56f0
Refactor vertical SSU kernel to use 3 compute groups per CTA
ishovkun Mar 10, 2026
aff33ac
Overlap scalar gmem loads with barrier waits in role_update_state
ishovkun Mar 11, 2026
2c8209b
Refactor MTP selective_state_update: split kernel launcher
ishovkun Mar 11, 2026
4cb4978
Remove DIM <= 64 skip for vertical kernel tests
ishovkun Mar 11, 2026
f7f4dde
Add checks for unsupported features in vertical SSU kernel
ishovkun Mar 11, 2026
9d6e129
Specify the sm_100 for the "vertical" kernel
ishovkun Mar 11, 2026
73b1829
Merge remote-tracking branch 'upstream/main' into ssu_mtp_persistent
ishovkun Mar 11, 2026
329be73
Enable stochastic rounding in vertical kernel and update tests
ishovkun Mar 12, 2026
d1fdfed
Skip vertical MTP tests if SM100+ is not available
ishovkun Mar 12, 2026
226c494
Remove -lineinfo flag from selective_state_update kernels
ishovkun Mar 12, 2026
1a4d24a
Require DIM divisible by 16 for vertical SSU kernel
ishovkun Mar 12, 2026
8ba2ef7
Refactor test to parametrize and clarify scaleState rejection
ishovkun Mar 12, 2026
2040a3c
Improve error message for vertical kernel DIM alignment check
ishovkun Mar 12, 2026
a10a843
Restored issue-claim.yml
ishovkun Mar 13, 2026
2e527cb
WIP commit: the horizontal path is not working yet
ishovkun Mar 16, 2026
94e5456
Add horizontal MTP kernel support for selective_state_update
ishovkun Mar 18, 2026
056385f
Refactor MTP kernel dispatch to split vertical and horizontal paths
ishovkun Mar 18, 2026
083b2e6
Add horizontal MTP kernel for selective_state_update
ishovkun Mar 18, 2026
6bde9be
Optimize state update kernel with f32x2 packed SIMD ops
ishovkun Mar 18, 2026
b1a5421
Add horizontal_v2 algorithm support for selective_state_update MTP
ishovkun Mar 19, 2026
56a0bd7
Add horizontal v2 MTP kernel for selective_state_update
ishovkun Mar 19, 2026
b95a5c0
Support multiple heads per CTA in horizontal_v2 kernel
ishovkun Mar 19, 2026
2f1f0ac
Add tight-spin parity barrier helpers to horizontal_v2 kernel
ishovkun Mar 19, 2026
549e960
Refactor horizontal_v2 kernel to use TMA-level pipelining
ishovkun Mar 20, 2026
ecdff0b
Rename horizontal_v2 MTP to `horizontal` and remove the old `horizontal`
ishovkun Mar 20, 2026
7e6644f
Parametrize selective_state_update MTP tests for all algorithms
ishovkun Mar 20, 2026
c63ef73
Fix alignment checks and extend ngroups ratio coverage in tests
ishovkun Mar 20, 2026
2ff43a6
Support non-power-of-2 dstate in SSU MTP horizontal kernel
ishovkun Mar 20, 2026
0d55e5b
Parametrize TMA_STATE_ROWS in horizontal MTP kernel
ishovkun Mar 20, 2026
df0c2ed
Merge remote-tracking branch 'upstream/main' into ssu_mtp_persistent
ishovkun Mar 20, 2026
fa08e3b
Fix algorithm selection to use Horizontal kernel for batch >= 32
ishovkun Mar 20, 2026
75c91f8
Refactor state conversion with stochastic rounding helpers
ishovkun Mar 23, 2026
e54f3d8
Merge remote-tracking branch 'upstream/main' into ssu_mtp_persistent
ishovkun Mar 23, 2026
389f8a3
Fix auto algorithm selection to use Simple when varlen is set
ishovkun Mar 23, 2026
41433e6
Import re at top level and remove redundant import
ishovkun Mar 23, 2026
aaa28b7
Update vertical kernel DIM alignment check to use warpSize (32)
ishovkun Mar 23, 2026
e7133a0
Pass rand_ints array to convertAndStoreSRHorizontal
ishovkun Mar 23, 2026
0e132ef
Refactor selective_state_update kernels to use IS_PAD template param
ishovkun Mar 23, 2026
e62aa6e
Add checks for varlen support in selective state update MTP kernels
ishovkun Mar 23, 2026
2882a97
Fix z_ptr offset calculation in selective state update kernels
ishovkun Mar 23, 2026
60bc48a
Remove inline comments from role_update_state_horizontal constants
ishovkun Mar 23, 2026
b087ff6
Merge remote-tracking branch 'upstream/main' into ssu_mtp_horizontal
ishovkun Mar 24, 2026
e3c46d1
Remove bar_input_full barrier and merge B/C/X loads with state_in
ishovkun Mar 24, 2026
8dcc4d4
Add toFloat2 overloads for packed and pointer types
ishovkun Mar 24, 2026
61fe779
Strength-reduce per-step indexing in state update kernel
ishovkun Mar 24, 2026
21d0fd7
Add async horizontal SSU MTP kernel (cp.async, SM80+)
ishovkun Mar 24, 2026
4e93292
Support CTAS_PER_HEAD=4 and occupancy-based CTA scaling
ishovkun Mar 24, 2026
86edbab
Fix bank conflict at dstate=96
ishovkun Mar 25, 2026
08e0569
Merge branch 'ssu_mtp_horizontal_async' into ssu_mtp_horizontal
ishovkun Mar 25, 2026
25187a6
Restore cutlass submodule pointer to match upstream/main
ishovkun Mar 25, 2026
4ca73db
Vectorize loads in selective state update kernel
ishovkun Mar 25, 2026
72daf79
Add varlen and scaled-state support to async_horizontal kernel
ishovkun Mar 30, 2026
c6e9996
Deduplicate encode scale computation in state update kernel
ishovkun Mar 30, 2026
e6d8699
Refactor async-horizontal SSU kernel state write path
ishovkun Mar 31, 2026
0f10861
Add state_in smem buffer with cp.async prefetch for MTP kernel
ishovkun Mar 31, 2026
5876429
Refactor async state load into reusable helper function
ishovkun Mar 31, 2026
bcc62ac
Use mul_f32x2 for state decode scale in async horizontal kernel
ishovkun Apr 1, 2026
0f42548
Move dst_slot prefetch earlier to hide LDS latency
ishovkun Apr 1, 2026
2050593
Remove smem zero-fill padding; zero OOB in registers instead
ishovkun Apr 1, 2026
1c09bc5
Remove async_horizontal kernel and merge into simple
ishovkun Apr 2, 2026
6e30877
Add SOL benchmark script for selective_state_update MTP mode
ishovkun Apr 2, 2026
2fbb60a
Merge remote-tracking branch 'upstream/main' into ssu_mtp_horizontal_…
ishovkun Apr 2, 2026
5cf114b
Refactor CTAS dispatch into reusable helper function
ishovkun Apr 2, 2026
5151ddb
Add static_assert to PackedAligned for N > 0 and add missing headers
ishovkun Apr 2, 2026
41c9e36
Fix horizontal kernel DIM alignment check to use TMA_STATE_ROWS
ishovkun Apr 2, 2026
ba6fc7c
Add pad slot support for B/C/x loads in MTP kernels
ishovkun Apr 2, 2026
a614a5a
Add check for DSTATE divisibility in vertical kernel
ishovkun Apr 2, 2026
0e69e21
Fix `max` ambiguity by using `std::max`
ishovkun Apr 2, 2026
b0b4ee9
Remove unused iostream include
ishovkun Apr 2, 2026
fa5f441
Add missing `#pragma once` to MTP vertical kernel header
ishovkun Apr 2, 2026
524 changes: 524 additions & 0 deletions benchmarks/bench_ssu_sweep_mtp.py

Large diffs are not rendered by default.

682 changes: 682 additions & 0 deletions benchmarks/bench_ssu_sweep_sol.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion flashinfer/jit/core.py
@@ -456,7 +456,7 @@ def gen_jit_spec(
cuda_cflags += ["-DNDEBUG", "-O3"]
cflags += ["-O3"]

# useful for ncu
# useful for ncu source correlation
if os.environ.get("FLASHINFER_JIT_LINEINFO", "0") == "1":
cuda_cflags += ["-lineinfo"]

2 changes: 2 additions & 0 deletions flashinfer/jit/mamba/__init__.py
@@ -16,12 +16,14 @@

from .selective_state_update import (
gen_selective_state_update_module,
gen_selective_state_update_sm100_module,
gen_selective_state_update_sm90_module,
)
from .seq_chunk_cumsum import gen_seq_chunk_cumsum_module

__all__ = [
"gen_selective_state_update_module",
"gen_selective_state_update_sm90_module",
"gen_selective_state_update_sm100_module",
"gen_seq_chunk_cumsum_module",
]
55 changes: 54 additions & 1 deletion flashinfer/jit/mamba/selective_state_update.py
@@ -182,7 +182,6 @@ def gen_selective_state_update_module(
cu_seqlens_dtype,
num_accepted_tokens_dtype,
philox_rounds=philox_rounds,
extra_cuda_cflags=["-lineinfo"],
)


@@ -238,3 +237,57 @@ def gen_selective_state_update_sm90_module(
philox_rounds=philox_rounds,
extra_cuda_cflags=nvcc_flags,
)


def gen_selective_state_update_sm100_module(
state_dtype: torch.dtype,
input_dtype: torch.dtype,
weight_dtype: torch.dtype,
matrixA_dtype: torch.dtype,
stateIndex_dtype: torch.dtype,
state_scale_dtype: Optional[torch.dtype],
dim: int,
dstate: int,
ntokens_mtp: int,
cu_seqlens_dtype: torch.dtype,
num_accepted_tokens_dtype: torch.dtype,
philox_rounds: int = 0,
) -> JitSpec:
uri = (
get_selective_state_update_uri(
state_dtype,
input_dtype,
weight_dtype,
matrixA_dtype,
stateIndex_dtype,
state_scale_dtype,
dim,
dstate,
ntokens_mtp,
cu_seqlens_dtype,
num_accepted_tokens_dtype,
philox_rounds,
)
+ "_sm100"
)
compilation_context = CompilationContext()
nvcc_flags = compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11, 12]
)
nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM90", "-DFLASHINFER_MAMBA_ENABLE_SM100"]
return _gen_module(
uri,
state_dtype,
input_dtype,
weight_dtype,
matrixA_dtype,
stateIndex_dtype,
state_scale_dtype,
dim,
dstate,
ntokens_mtp,
cu_seqlens_dtype,
num_accepted_tokens_dtype,
philox_rounds=philox_rounds,
extra_cuda_cflags=nvcc_flags,
)
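For context, the wrapper selects among these three generators by SM major version (see the `_get_module` hunk in the next file). A minimal Python sketch of that dispatch; the returned strings are illustrative labels for the module generators, not real identifiers from the codebase:

```python
def select_module_variant(sm_major: int) -> str:
    """Mirror of the capability dispatch in _get_module: SM100+ gets the
    dedicated sm100 build, SM90 its own, anything older the generic module."""
    if sm_major >= 10:
        return "selective_state_update_sm100"
    elif sm_major >= 9:
        return "selective_state_update_sm90"
    else:
        return "selective_state_update"
```

Note the SM100 URI also appends a `_sm100` suffix to the base URI, so its cached JIT artifacts never collide with the generic build.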
13 changes: 12 additions & 1 deletion flashinfer/mamba/selective_state_update.py
@@ -22,6 +22,7 @@
from ..api_logging import flashinfer_api
from ..jit.mamba import (
gen_selective_state_update_module,
gen_selective_state_update_sm100_module,
gen_selective_state_update_sm90_module,
)
from ..utils import get_compute_capability, register_custom_op, register_fake_op
@@ -57,7 +58,9 @@ def _get_module(
num_accepted_tokens_dtype,
philox_rounds,
)
if sm_major >= 9:
if sm_major >= 10:
return gen_selective_state_update_sm100_module(*args).build_and_load()
elif sm_major >= 9:
return gen_selective_state_update_sm90_module(*args).build_and_load()
else:
return gen_selective_state_update_module(*args).build_and_load()
@@ -266,6 +269,11 @@ def selective_state_update(
# No stochastic rounding when rand_seed is None
philox_rounds = 0

if intermediate_states_buffer is not None and dst_state_batch_indices is not None:
raise ValueError(
"intermediate_states_buffer and dst_state_batch_indices are mutually exclusive"
)

if out is None:
output = torch.empty_like(x)
else:
@@ -298,6 +306,9 @@
algorithm_int = 2
elif algorithm == "horizontal":
algorithm_int = 3
elif algorithm == "async_horizontal":
# Backward compat: async_horizontal is now merged into simple
algorithm_int = 1
else:
raise ValueError(f"Unknown algorithm: {algorithm}")

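The visible branches of the name-to-code mapping can be collected into a lookup table. A sketch under stated assumptions: `horizontal` maps to 3 and the `async_horizontal` alias to 1 per the diff; `simple` mapping to 1 is inferred from the alias comment; `vertical` mapping to 2 is a guess, since that branch's condition is truncated above.

```python
# Illustrative mapping of algorithm names to kernel codes.
_ALGORITHM_CODES = {
    "simple": 1,            # inferred: the async_horizontal alias maps here
    "vertical": 2,          # assumed; branch condition not visible in the diff
    "horizontal": 3,
    "async_horizontal": 1,  # backward compat: merged into simple
}

def algorithm_to_int(name: str) -> int:
    try:
        return _ALGORITHM_CODES[name]
    except KeyError:
        raise ValueError(f"Unknown algorithm: {name}") from None
```

A table-driven mapping keeps the alias explicit and makes the error path a single lookup failure instead of a trailing `else`.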
12 changes: 10 additions & 2 deletions include/flashinfer/mamba/common.cuh
@@ -32,9 +32,17 @@ constexpr unsigned warpSize = 32;
// Common types and utilities
// =============================================================================

// Simple packed vector type for loading N elements of type T
// Largest power of 2 that divides v (i.e. v & -v). Returns 1 when v == 0.
inline constexpr unsigned largestPow2Divisor(unsigned v) { return v ? (v & (~v + 1)) : 1; }

// Simple packed vector type for loading N elements of type T.
// Alignment is the largest power-of-2 factor of the total byte size,
// so it is always valid even when N * sizeof(T) is not a power of 2 (e.g. 3 × 2 = 6).
template <typename T, int N = sizeof(float4) / sizeof(T)>
struct alignas(N * sizeof(T)) PackedAligned {
struct alignas(largestPow2Divisor(N * sizeof(T))) PackedAligned {
static_assert(N > 0,
"PackedAligned instantiated with N == 0; "
"ensure getVectorLoadSizeForFullUtilization() returns > 0");
T val[N];
static constexpr int count = N;
using dtype = T;
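The identity behind `largestPow2Divisor` is plain two's-complement arithmetic and can be sanity-checked outside CUDA. A Python sketch of the same computation (`v & -v` isolates the lowest set bit, which is exactly the largest power-of-2 divisor):

```python
def largest_pow2_divisor(v: int) -> int:
    """Largest power of 2 dividing v; mirrors the CUDA helper above.

    v & -v isolates the lowest set bit of v. Returns 1 for v == 0,
    matching the kernel's convention so alignas() never receives 0."""
    return (v & -v) if v else 1

# The PackedAligned case from the comment: N=3 halves -> 6 bytes.
# 6 is not a power of 2, but alignas(2) is still valid for it.
assert largest_pow2_divisor(6) == 2
assert largest_pow2_divisor(16) == 16
```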
33 changes: 33 additions & 0 deletions include/flashinfer/mamba/conversion.cuh
@@ -22,6 +22,39 @@ inline __device__ float toFloat(__nv_bfloat16 val) { return __bfloat162float(val
// (24-bit mantissa represents all integers up to 2^24 = 16M exactly).
inline __device__ float toFloat(int16_t val) { return static_cast<float>(val); }

// Packed 2-element conversion: convert a packed pair to float2.
// Uses native packed intrinsics for bf16/fp16 (fewer PRMT/SHF instructions).
inline __device__ float2 toFloat2(float2 packed) { return packed; }

inline __device__ float2 toFloat2(__half2 packed) { return __half22float2(packed); }

// Pointer-based overloads: read two consecutive elements and convert to float2.
// Dispatches to the packed intrinsic for bf16/fp16 via the overloads above.
inline __device__ float2 toFloat2(float const* ptr) { return {ptr[0], ptr[1]}; }

inline __device__ float2 toFloat2(__half const* ptr) {
return toFloat2(*reinterpret_cast<__half2 const*>(ptr));
}

#ifdef FLASHINFER_ENABLE_BF16
// inline __device__ float2 toFloat2(__nv_bfloat162 packed) { return __bfloat1622float2(packed); }
inline __device__ float2 toFloat2(__nv_bfloat162 packed) {
// bf16 is the upper 16 bits of f32 — shift/mask is cheaper than PRMT byte permutation.
// NOTE: this ignores denormals
uint32_t bits = reinterpret_cast<uint32_t const&>(packed);
float2 out;
out.x = __uint_as_float(bits << 16); // low bf16 → upper 16 bits of f32
out.y = __uint_as_float(bits & 0xFFFF0000u); // high bf16 already in upper 16 bits
return out;
}

inline __device__ float2 toFloat2(__nv_bfloat16 const* ptr) {
return toFloat2(*reinterpret_cast<__nv_bfloat162 const*>(ptr));
}
#endif

inline __device__ float2 toFloat2(int16_t const* ptr) { return {toFloat(ptr[0]), toFloat(ptr[1])}; }

inline __device__ void convertAndStore(float* output, float input) { *output = input; }

inline __device__ void convertAndStore(__half* output, float input) {
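The bf16 shift/mask trick can be verified on the host without CUDA. A Python sketch that reinterprets bits with `struct` (mimicking `__uint_as_float`) and unpacks two bf16 values from a 32-bit word exactly as the kernel does; truncating f32 to its upper 16 bits serves as the bf16 encoder here, which ignores rounding:

```python
import struct

def uint_as_float(bits: int) -> float:
    # Reinterpret a 32-bit pattern as an IEEE-754 float (__uint_as_float).
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def float_to_bf16_bits(x: float) -> int:
    # bf16 is simply the upper 16 bits of the f32 encoding (truncated).
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def to_float2_bf16(packed: int) -> tuple:
    # Mirrors the kernel: the low bf16 shifts into the upper 16 bits of
    # an f32; the high bf16 is already in the upper 16 bits after masking.
    x = uint_as_float((packed << 16) & 0xFFFFFFFF)
    y = uint_as_float(packed & 0xFFFF0000)
    return (x, y)

# Pack bf16(1.5) in the low half and bf16(-2.0) in the high half.
packed = float_to_bf16_bits(1.5) | (float_to_bf16_bits(-2.0) << 16)
assert to_float2_bf16(packed) == (1.5, -2.0)
```

Both 1.5 and -2.0 are exactly representable in bf16, so the round trip is bit-exact; values needing more than 8 mantissa bits would differ, which is fine for a conversion that only decodes existing bf16 storage.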
10 changes: 5 additions & 5 deletions include/flashinfer/mamba/create_tensor_map.cuh
@@ -12,10 +12,10 @@

namespace flashinfer::mamba::tma {

inline CUtensorMap buildNdDescriptor(std::type_info const& dtype,
std::vector<uint64_t> const& shapes,
std::vector<uint64_t> const& strides,
std::vector<int32_t> const& tileShapes, void* gmemAddr) {
inline CUtensorMap buildNdDescriptor(
std::type_info const& dtype, std::vector<uint64_t> const& shapes,
std::vector<uint64_t> const& strides, std::vector<int32_t> const& tileShapes, void* gmemAddr,
CUtensorMapFloatOOBfill oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE) {
// The multiplication factor of the data padding in SMEM.
CUtensorMap desc{};
CUtensorMapDataType tmaDataFormat;
@@ -85,7 +85,7 @@ inline CUtensorMap buildNdDescriptor(std::type_info const& dtype,
boxDim.data(), tileStrides.data(),
/*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, swizzleType,
/*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
/*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
/*oobFill=*/oobFill);

if (result != CUDA_SUCCESS) {
char const* errorString;
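The new trailing `oobFill` parameter defaults to the previous hard-coded value, so every existing call site compiles unchanged while new callers can opt into NaN fill for out-of-bounds reads. A toy Python sketch of the same backward-compatible pattern; `build_nd_descriptor` and its dict return value are illustrative stand-ins, not the real driver API, though the enum names mirror `CUtensorMapFloatOOBfill`:

```python
from enum import Enum

class OobFill(Enum):
    # Mirrors CUtensorMapFloatOOBfill from the CUDA driver API.
    NONE = 0                  # CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    NAN_REQUEST_ZERO_FMA = 1  # CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA

def build_nd_descriptor(shapes, strides, tile_shapes, oob_fill=OobFill.NONE):
    """Toy stand-in for buildNdDescriptor: the new trailing parameter
    defaults to NONE, so pre-existing callers keep their old behavior."""
    return {"shapes": tuple(shapes), "tiles": tuple(tile_shapes),
            "oob_fill": oob_fill}

# Existing callers are untouched:
d = build_nd_descriptor([64, 128], [128, 1], [16, 32])
# New callers can request NaN fill for out-of-bounds tile elements:
d_nan = build_nd_descriptor([64, 128], [128, 1], [16, 32],
                            oob_fill=OobFill.NAN_REQUEST_ZERO_FMA)
```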