Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 4 additions & 48 deletions tests/v1/e2e/general/test_mamba_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def fake_execute_model_fn(

def get_fake_process_mamba_fn(
original_preprocess_mamba_fn: Callable,
original_post_process_mamba_fn: Callable,
original_copy_fn: Callable,
):
copy_info: tuple[list[int], list[int], list[int]] | None = None
Expand Down Expand Up @@ -361,45 +360,6 @@ def fake_preprocess_mamba_fn(
)
return ret

# Test stand-in for the original postprocess function.
# It clears the recorded copy info, forwards every argument unchanged to
# original_post_process_mamba_fn, and, when a step action is active, passes
# the expected postprocess copy index to check_copy_info.
# Returns whatever the original function returned.
def fake_post_process_mamba_fn(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
cache_config: CacheConfig,
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
mamba_state_idx: dict[str, int],
num_spec_tokens: int,
num_reqs: int,
*,
# The remaining parameters are keyword-only and default to None.
forward_context: dict[str, Any] | None = None,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...] | None = None,
copy_bufs: mamba_utils.MambaCopyBuffers | None = None,
):
# copy_info lives in the enclosing scope and is shared with the sibling
# wrappers defined there.
nonlocal copy_info
# Reset before delegating; presumably so that a copy performed during this
# call is recorded fresh (confirm against fake_copy_fn).
copy_info = None
# Delegate to the real implementation with the same positional and keyword
# arguments.
ret = original_post_process_mamba_fn(
scheduler_output,
kv_cache_config,
cache_config,
input_batch,
requests,
mamba_state_idx,
num_spec_tokens,
num_reqs,
forward_context=forward_context,
mamba_state_copy_funcs=mamba_state_copy_funcs,
copy_bufs=copy_bufs,
)
# Verification only runs when a step action is set; cur_step_action is a free
# name taken from an outer scope not shown in this block.
if cur_step_action is not None:
# check_copy_info is given the forward context, so it must be present here.
assert forward_context is not None
check_copy_info(
cur_step_action.postprocess_copy_idx,
kv_cache_config,
forward_context,
input_batch,
)
return ret

def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
nonlocal copy_info
assert copy_info is None
Expand All @@ -410,7 +370,7 @@ def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
copy_info = (src_state_list, dest_state_list, num_elements_list)
return original_copy_fn(copy_bufs)

return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn
return fake_preprocess_mamba_fn, fake_copy_fn


def run_ref_mamba_state_in_subprocess() -> None:
Expand Down Expand Up @@ -522,15 +482,11 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch):
fake_allocate_slots_fn = get_fake_allocate_slots_fn(KVCacheManager.allocate_slots)
monkeypatch.setattr(KVCacheManager, "allocate_slots", fake_allocate_slots_fn)

fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn = (
get_fake_process_mamba_fn(
mamba_utils.preprocess_mamba,
mamba_utils.postprocess_mamba,
mamba_utils.do_mamba_copy_block,
)
fake_preprocess_mamba_fn, fake_copy_fn = get_fake_process_mamba_fn(
mamba_utils.preprocess_mamba,
mamba_utils.do_mamba_copy_block,
)
monkeypatch.setattr(mamba_utils, "preprocess_mamba", fake_preprocess_mamba_fn)
monkeypatch.setattr(mamba_utils, "postprocess_mamba", fake_post_process_mamba_fn)
Comment thread
tdoublep marked this conversation as resolved.
monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn)


Expand Down
Loading
Loading