19 changes: 19 additions & 0 deletions flashinfer/gdn_decode.py
@@ -121,6 +121,7 @@ def gated_delta_rule_decode_pretranspose(
use_qk_l2norm: bool = True,
initial_state: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
output_state_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Gated Delta Rule Decode kernel for single-token generation.

@@ -165,6 +166,10 @@ def gated_delta_rule_decode_pretranspose(
Per-batch indices of shape ``[B]`` (int32 or int64) mapping each batch
entry to its slot in ``initial_state``. Required when ``initial_state``
is provided.
output_state_indices (Optional[torch.Tensor]):
Per-batch indices of shape ``[B]`` (int32 or int64) specifying where to write the updated state for each batch entry in the pool.
Requires ``initial_state`` to be provided.
If None, the kernel will write the updated state back to the same slot it read from (i.e., ``initial_state_indices``).

Returns:
Tuple[torch.Tensor, torch.Tensor]:
@@ -191,6 +196,18 @@ def gated_delta_rule_decode_pretranspose(
assert use_pool == (initial_state_indices is not None), (
"initial_state and initial_state_indices must be provided together"
)
if output_state_indices is not None:
assert use_pool, (
"output_state_indices can only be used with initial_state (pool mode)"
)
assert output_state_indices.shape == (B,), (
f"Expected output_state_indices shape [{B}], "
f"got {output_state_indices.shape}"
)
assert output_state_indices.dtype in (torch.int32, torch.int64), (
f"output_state_indices must be int32 or int64, "
f"got {output_state_indices.dtype}"
)
Comment on lines +199 to +210 (Contributor)
⚠️ Potential issue | πŸ”΄ Critical

Reject in-place remaps that alias another batch item's source slot.

output_state_indices still writes back into the same initial_state buffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading via initial_state_indices, the final state becomes CTA-order dependent and no longer matches gather→compute→scatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.
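The safety check this comment asks for can be sketched on the host side in pure Python (illustrative names, not part of the PR; the real check would run over the index tensors before launch). An identity write-back is always safe, but duplicate write targets, or a write slot that aliases another row's read slot, must be rejected:

```python
def remap_is_safe(read_idx, write_idx):
    """Return True when a scatter remap cannot race within one launch."""
    # Duplicate write targets make the final state CTA-order dependent.
    if len(set(write_idx)) != len(write_idx):
        return False
    # A write slot another row still reads races with that read,
    # unless it is the same row writing back in place.
    for i, w in enumerate(write_idx):
        for j, r in enumerate(read_idx):
            if i != j and w == r:
                return False
    return True
```

For example, `remap_is_safe([0, 1, 2], [0, 1, 2])` (pure in-place update) passes, while `remap_is_safe([0, 1], [1, 2])` fails because row 0 writes the slot row 1 is still reading.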

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The output_state_indices
path currently allows in-place remaps that can alias other batch items' source
slots, making final state CTA-order dependent; in the block that checks
output_state_indices (and uses use_pool and initial_state /
initial_state_indices), validate that output_state_indices contains no duplicate
targets and that none of its target indices overlap any indices in
initial_state_indices (or raise a clear error); alternatively implement a staged
fallback: allocate a temporary buffer, gather sources into temp using
initial_state_indices, perform compute, then scatter results from temp to
initial_state using output_state_indices to avoid read/write races. Ensure
checks/reference to output_state_indices, initial_state, initial_state_indices
and use_pool are used so the change locates the remap logic.

⚠️ Potential issue | πŸ”΄ Critical

Validate output_state_indices against the pool before dispatch.

The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or >= pool_size write index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot 0 on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.

πŸ’‘ Suggested guard
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
         assert output_state_indices.shape == (B,), (
             f"Expected output_state_indices shape [{B}], "
             f"got {output_state_indices.shape}"
         )
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == initial_state.device, (
+            "output_state_indices must be on the same device as initial_state"
+        )
+        pool_size = int(initial_state.shape[0])
+        in_range = (output_state_indices >= 0) & (output_state_indices < pool_size)
+        assert in_range.all().item(), (
+            f"output_state_indices must be in [0, {pool_size})"
+        )
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The code currently only
checks shape/dtype of output_state_indices; add validation that
output_state_indices is on the same device as the pool (reject CPU/non-local
tensors) and that all values are within [0, pool_size-1] to prevent
out-of-bounds or aliasing when writing into the pool (when
use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.
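The staged fallback described above can be sketched in pure Python (hypothetical names; the real version would gather into a temporary device buffer). Computing every result before the first write makes arbitrary, even overlapping, remaps order-independent:

```python
def staged_scatter_update(pool, read_idx, write_idx, compute):
    """Gather -> compute -> scatter, immune to read/write aliasing."""
    # Gather and compute everything before the first write, so no row
    # can observe another row's updated state mid-update.
    staged = [compute(pool[r]) for r in read_idx]
    for w, new_state in zip(write_idx, staged):
        pool[w] = new_state
    return pool
```

Even when a write slot aliases another row's read slot, every read sees the pre-update pool.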


if use_pool:
pool_size = initial_state.shape[0]
@@ -253,6 +270,7 @@ def gated_delta_rule_decode_pretranspose(
b=b,
initial_state_source=initial_state if use_pool else state,
initial_state_indices=initial_state_indices,
output_state_indices=output_state_indices,
use_qk_l2norm_in_kernel=use_qk_l2norm,
scale=scale_val,
)
@@ -339,6 +357,7 @@ def gated_delta_rule_decode_pretranspose(
use_qk_l2norm,
use_pool_indexing=use_pool_indexing,
initial_state_indices=initial_state_indices,
output_state_indices=output_state_indices,
)

# Copy state back only if not using pool and state was not contiguous
61 changes: 51 additions & 10 deletions flashinfer/gdn_kernels/gdn_decode_bf16_state.py
@@ -1107,7 +1107,8 @@ def gdn_decode_bf16state_mtp_kernel(
v: cute.Tensor, # [B, T, HV, V]
b: cute.Tensor, # [B, T, HV]
o: cute.Tensor, # [B, T, HV, V] - output
h0_indices: cute.Tensor, # [B] - initial state indices
h0_indices: cute.Tensor, # [B] - initial state indices (read)
h0_out_indices: cute.Tensor, # [B] - output state indices (write)
softplus_beta: cutlass.Constexpr[float],
softplus_threshold: cutlass.Constexpr[float],
scale: cutlass.Constexpr[float],
@@ -1320,6 +1321,8 @@ def gdn_decode_bf16state_mtp_kernel(

# Each group handles tile_v/num_groups V rows, 8 at a time (ILP=8)
flat_state_idx = cache_idx * HV + i_hv
write_cache_idx = h0_out_indices[i_n]
flat_write_idx = write_cache_idx * HV + i_hv
rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups
eighth_rows: cutlass.Constexpr[int] = rows_per_group // MTP_ILP_ROWS

@@ -1975,14 +1978,38 @@ def gdn_decode_bf16state_mtp_kernel(
r_hb5[i] = cutlass.BFloat16(r_h[5, i])
r_hb6[i] = cutlass.BFloat16(r_h[6, i])
r_hb7[i] = cutlass.BFloat16(r_h[7, i])
cute.autovec_copy(r_hb0, ht0)
cute.autovec_copy(r_hb1, ht1)
cute.autovec_copy(r_hb2, ht2)
cute.autovec_copy(r_hb3, ht3)
cute.autovec_copy(r_hb4, ht4)
cute.autovec_copy(r_hb5, ht5)
cute.autovec_copy(r_hb6, ht6)
cute.autovec_copy(r_hb7, ht7)
wt0 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v0, lane_in_group)
)
wt1 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v1, lane_in_group)
)
wt2 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v2, lane_in_group)
)
wt3 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v3, lane_in_group)
)
wt4 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v4, lane_in_group)
)
wt5 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v5, lane_in_group)
)
wt6 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v6, lane_in_group)
)
wt7 = cute.local_tile(
h0_source, (1, 1, vec_size), (flat_write_idx, v7, lane_in_group)
)
cute.autovec_copy(r_hb0, wt0)
cute.autovec_copy(r_hb1, wt1)
cute.autovec_copy(r_hb2, wt2)
cute.autovec_copy(r_hb3, wt3)
cute.autovec_copy(r_hb4, wt4)
cute.autovec_copy(r_hb5, wt5)
cute.autovec_copy(r_hb6, wt6)
cute.autovec_copy(r_hb7, wt7)


# ==============================================================================
@@ -2003,6 +2030,7 @@ def run_gdn_decode_bf16state_mtp(
b: cute.Tensor,
o: cute.Tensor,
h0_indices: cute.Tensor,
h0_out_indices: cute.Tensor,
softplus_beta: cutlass.Constexpr[float],
softplus_threshold: cutlass.Constexpr[float],
scale: cutlass.Constexpr[float],
@@ -2056,6 +2084,7 @@ def run_gdn_decode_bf16state_mtp(
b,
o,
h0_indices,
h0_out_indices,
softplus_beta,
softplus_threshold,
scale,
@@ -2467,6 +2496,7 @@ def gated_delta_rule_mtp(
b: Optional[torch.Tensor] = None,
initial_state_source: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
output_state_indices: Optional[torch.Tensor] = None,
intermediate_states_buffer: Optional[torch.Tensor] = None,
disable_state_update: bool = False,
use_qk_l2norm_in_kernel: bool = True,
@@ -2487,7 +2517,9 @@ def gated_delta_rule_mtp(
v: [B, T, HV, V] bf16
b: [B, T, HV] bf16
initial_state_source: [pool_size, HV, V, K] bf16
initial_state_indices: [B] int32 - indices into state pool
initial_state_indices: [B] int32 - indices into state pool (read)
output_state_indices: Optional [B] int32 - indices for writing updated state.
Defaults to initial_state_indices when None.
intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16
disable_state_update: bool - if True, don't update initial state
scale: Optional, default 1/sqrt(K)
@@ -2514,6 +2546,12 @@ def gated_delta_rule_mtp(
if initial_state_indices is None:
initial_state_indices = torch.arange(B, dtype=torch.int32, device=q.device)

# Default output indices to read indices
if output_state_indices is None:
output_state_indices = initial_state_indices
elif output_state_indices.dtype != torch.int32:
output_state_indices = output_state_indices.to(torch.int32)
Comment on lines +2549 to +2553 (Contributor)

⚠️ Potential issue | πŸ”΄ Critical

Preserve padding/null-buffer semantics when defaulting output_state_indices.

This regresses the existing BF16 negative-index path: padded reads still come in as initial_state_indices == -1, but None now copies that -1 straight onto the write side. The kernel uses h0_out_indices for final writeback, so padded rows now write before h0_source instead of falling back to slot 0.

πŸ› Minimal fix
-    if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
-        output_state_indices = output_state_indices.to(torch.int32)
+    if output_state_indices is None:
+        # Preserve the existing slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp_min(0)
+    if output_state_indices.dtype != torch.int32:
+        output_state_indices = output_state_indices.to(torch.int32)
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices (when output_state_indices is None),
preserve padding/null-buffer semantics by cloning initial_state_indices but
mapping padded markers (-1) back to the fallback write slot (e.g., 0) before
use; specifically, in the block handling output_state_indices, set
output_state_indices = initial_state_indices.clone(), then replace any entries
equal to -1 with 0, and finally ensure dtype is torch.int32. This keeps the
kernel's h0_out_indices behavior correct (padded reads won't write to -1
locations) while keeping the int32 conversion logic.
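The defaulting behavior this fix asks for amounts to the following (pure-Python sketch with illustrative names; the real code would apply the same clamp to the int32 index tensor). Padded rows marked `-1` must be routed to the slot-0 null buffer instead of carrying `-1` onto the write side:

```python
def default_write_indices(read_idx, pad_marker=-1, null_slot=0):
    """Derive write indices from read indices, preserving padding semantics."""
    # A padded row reads nothing real; send its write to the null buffer
    # rather than letting the -1 marker become a write offset.
    return [null_slot if i == pad_marker else i for i in read_idx]
```

This keeps explicit, non-padded read slots as in-place write-backs while neutralizing padding markers.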


if output is None:
output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)

@@ -2552,6 +2590,7 @@ def gated_delta_rule_mtp(
dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
Contributor

⚠️ Potential issue | 🟑 Minor

Fix formatting to pass pre-commit checks.

The pipeline failure indicates this line needs reformatting per ruff format.

πŸ”§ Apply ruff formatting
-    h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
+    h0_out_idx_ = from_dlpack(
+        output_state_indices, assumed_align=32, enable_tvm_ffi=True
+    )
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(
output_state_indices, assumed_align=32, enable_tvm_ffi=True
)
🧰 Tools
πŸͺ› GitHub Actions: pre-commit

[error] 2590-2593: pre-commit failed: ruff format (hook id: ruff-format) reformatted files. Diff shows formatting change in gated_delta_rule_mtp() for h0_out_idx_ = from_dlpack(output_state_indices, ...).

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` at line 2593, The assignment
to h0_out_idx_ calling from_dlpack is misformatted; reformat that line to
satisfy ruff (apply ruff format or adjust spacing/punctuation) so it matches the
project's formatting rules (e.g., proper spacing around the = and within the
function call) in the h0_out_idx_ = from_dlpack(...) statement; keep the same
variable name h0_out_idx_ and function call from_dlpack with arguments
output_state_indices, assumed_align=32, enable_tvm_ffi=True.


stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

@@ -2590,6 +2629,7 @@ def gated_delta_rule_mtp(
b_,
o_,
h0_idx_,
h0_out_idx_,
softplus_beta,
softplus_threshold,
scale,
@@ -2620,6 +2660,7 @@ def gated_delta_rule_mtp(
b_,
o_,
h0_idx_,
h0_out_idx_,
stream,
)

46 changes: 39 additions & 7 deletions flashinfer/gdn_kernels/gdn_decode_pretranspose.py
@@ -55,7 +55,8 @@ def gdn_decode_kernel_small_batch_pretranspose(
v: cute.Tensor, # [B, T, HV, V]
b: cute.Tensor, # [B, T, HV]
o: cute.Tensor, # [B, T, HV, V] - output
h0_indices: cute.Tensor, # [B] - initial state indices
h0_indices: cute.Tensor, # [B] - initial state indices (read)
h0_out_indices: cute.Tensor, # [B] - output state indices (write)
cu_seqlens: cute.Tensor, # [B+1] - cumulative sequence lengths (for varlen)
softplus_beta: cutlass.Constexpr[float],
softplus_threshold: cutlass.Constexpr[float],
@@ -134,16 +135,18 @@ def gdn_decode_kernel_small_batch_pretranspose(
# Compute state index: use pool indexing if enabled.
if cutlass.const_expr(use_pool_indexing):
pool_idx = h0_indices[i_n]
out_pool_idx = h0_out_indices[i_n]
else:
pool_idx = 0
out_pool_idx = 0

if pool_idx >= 0:
# Get current state slice.
# Get current batch
if cutlass.const_expr(use_pool_indexing):
# h0_source layout: [pool_size, HV, V, K] (supports non-contiguous page stride)
gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] # (V, K)
gDst = cute.local_tile(
h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0)
h0_source, (1, 1, TILE_V, TILE_K), (out_pool_idx, i_hv, None, 0)
)
else:
# h0_source layout: [B*HV, V, K]
@@ -307,7 +310,7 @@ def gdn_decode_kernel_small_batch_pretranspose(
r_h[i] += r_k[i] * v_new
sum_hq += r_h[i] * r_q[i]

# Write h back to state.
# Write h to gDst using 4D local_tile + autovec_copy (contiguous in K)
if cutlass.const_expr(use_pool_indexing):
gDst_tile = cute.local_tile(
gDst,
@@ -361,7 +364,8 @@ def gdn_decode_kernel_big_batch_pretranspose(
v: cute.Tensor, # [B, T, HV, V]
b: cute.Tensor, # [B, T, HV]
o: cute.Tensor, # [B, T, HV, V] - output
h0_indices: cute.Tensor, # [B] - initial state indices
h0_indices: cute.Tensor, # [B] - initial state indices (read)
h0_out_indices: cute.Tensor, # [B] - output state indices (write)
cu_seqlens: cute.Tensor, # [B+1] - cumulative sequence lengths (for varlen)
softplus_beta: cutlass.Constexpr[float],
softplus_threshold: cutlass.Constexpr[float],
@@ -436,16 +440,18 @@ def gdn_decode_kernel_big_batch_pretranspose(
# Compute state index: use pool indexing if enabled.
if cutlass.const_expr(use_pool_indexing):
pool_idx = h0_indices[i_n]
out_pool_idx = h0_out_indices[i_n]
else:
pool_idx = 0
out_pool_idx = 0

if pool_idx >= 0:
# Get current state slice.
if cutlass.const_expr(use_pool_indexing):
# h0_source layout: [pool_size, HV, V, K] (supports non-contiguous page stride)
gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] # (V, K)
gDst = cute.local_tile(
h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0)
h0_source, (1, 1, TILE_V, TILE_K), (out_pool_idx, i_hv, None, 0)
)
else:
# h0_source layout: [B*HV, V, K]
@@ -657,6 +663,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose(
b: cute.Tensor,
o: cute.Tensor,
h0_indices: cute.Tensor,
h0_out_indices: cute.Tensor,
cu_seqlens: cute.Tensor,
softplus_beta: cutlass.Constexpr[float],
softplus_threshold: cutlass.Constexpr[float],
@@ -734,6 +741,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose(
b,
o,
h0_indices,
h0_out_indices,
cu_seqlens,
softplus_beta,
softplus_threshold,
@@ -768,6 +776,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose(
b: cute.Tensor,
o: cute.Tensor,
h0_indices: cute.Tensor,
h0_out_indices: cute.Tensor,
cu_seqlens: cute.Tensor,
softplus_beta: cutlass.Constexpr[float],
softplus_threshold: cutlass.Constexpr[float],
@@ -840,6 +849,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose(
b,
o,
h0_indices,
h0_out_indices,
cu_seqlens,
softplus_beta,
softplus_threshold,
@@ -910,6 +920,7 @@ def run_pretranspose_decode(
use_qk_l2norm: bool,
use_pool_indexing: bool = False,
initial_state_indices: Optional[torch.Tensor] = None,
output_state_indices: Optional[torch.Tensor] = None,
):
"""Compile and execute the pretranspose decode kernel.

@@ -924,6 +935,8 @@ def run_pretranspose_decode(
use_pool_indexing: Whether to use pool-based indirect state indexing.
initial_state_indices: Int32 indices into state pool, shape [B].
Negative values indicate padding (kernel writes zeros).
output_state_indices: Optional int32 indices for write destination, shape [B].
When None, writes go to the same slot as initial_state_indices.
"""
# Compile kernel with TVM FFI (cached)
if use_pool_indexing:
@@ -959,6 +972,11 @@ def run_pretranspose_decode(
h0_indices = initial_state_indices.to(torch.int32)
else:
h0_indices = cache["h0_indices"]
# Resolve output indices: default to same as read indices
if use_pool_indexing and output_state_indices is not None:
h0_out_indices = output_state_indices.to(torch.int32)
else:
h0_out_indices = h0_indices
Comment on lines +976 to +979 (Contributor)

Severity: medium

The use_pool_indexing check here is redundant. The public API gated_delta_rule_decode_pretranspose already asserts that output_state_indices can only be provided when use_pool_indexing is true.

You can simplify this logic for better readability.

Suggested change
if use_pool_indexing and output_state_indices is not None:
h0_out_indices = output_state_indices.to(torch.int32)
else:
h0_out_indices = h0_indices
if output_state_indices is not None:
h0_out_indices = output_state_indices.to(torch.int32)
else:
h0_out_indices = h0_indices

cu_seqlens = cache["cu_seqlens"]

if "compiled" not in cache:
@@ -976,6 +994,7 @@ def run_pretranspose_decode(
b_tensor = from_dlpack(b, assumed_align=16)
o_tensor = from_dlpack(output, assumed_align=16)
h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16)
h0_out_indices_tensor = from_dlpack(h0_out_indices, assumed_align=16)
cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16)

# Always use 8-CTA architecture (benchmarks show it's better for all batch sizes)
@@ -994,6 +1013,7 @@ def run_pretranspose_decode(
b_tensor,
o_tensor,
h0_indices_tensor,
h0_out_indices_tensor,
cu_seqlens_tensor,
softplus_beta=1.0,
softplus_threshold=20.0,
@@ -1018,5 +1038,17 @@ def run_pretranspose_decode(
# Run kernel directly with PyTorch tensors (no from_dlpack needed)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache["compiled"](
h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream
h0_source,
A_log,
a,
dt_bias,
q,
k,
v,
b,
output,
h0_indices,
h0_out_indices,
cu_seqlens,
stream,
)