Closed
79 commits
d17d622
implementation
fzyzcjy Apr 24, 2026
d40ca83
fix index_topk
Fridge003 Apr 24, 2026
0d6856b
Revert "fix index_topk"
Fridge003 Apr 24, 2026
927e149
hisparse scheduling fix
xiezhq-hermann Apr 24, 2026
4807f6c
fix: topk 1024
DarkSharpness Apr 24, 2026
5c59a71
feat: support 1024 topk
DarkSharpness Apr 24, 2026
f5d03db
Reapply "fix index_topk"
Fridge003 Apr 24, 2026
a74b25f
update Dockerfile for B300
Fridge003 Apr 24, 2026
5031406
fix dockerfile
Fridge003 Apr 24, 2026
2045dc0
add gb dockerfile
Fridge003 Apr 24, 2026
c48efaf
add h200/b200 dockerfile
hnyls2002 Apr 24, 2026
0d4735b
fix gb dockerfile
Fridge003 Apr 24, 2026
0f94b5d
update b300 dockerfile
Fridge003 Apr 24, 2026
ca21ebe
fix gb dockerfile
Fridge003 Apr 24, 2026
5e483b7
fix b300
Qiaolin-Yu Apr 25, 2026
8756f36
fix b300 dockerfile
Qiaolin-Yu Apr 25, 2026
dc2b507
SGLANG_FIX_DSV4_BASE_MODEL_LOAD
fzyzcjy Apr 25, 2026
6c396d5
Merge remote-tracking branch 'upstream/deepseek_v4' into deepseek_v4
fzyzcjy Apr 25, 2026
cb591d3
Support dsv4 task / latest_reminder / content parts in OpenAI chat AP…
JustinTong0323 Apr 25, 2026
4bf81c9
[NSA] Fall back to fast_hadamard_transform when sgl_kernel lacks the …
Fridge003 Apr 25, 2026
02451ff
route ignore_eos+disable_radix_cache path through prefill_delayer (wa…
fzyzcjy Apr 25, 2026
00651d8
Merge remote-tracking branch 'upstream/deepseek_v4' into deepseek_v4
fzyzcjy Apr 25, 2026
7f58083
fix: fix fast ep masked
DarkSharpness Apr 25, 2026
2777a6f
feat: new flags
DarkSharpness Apr 25, 2026
2e43d2b
swa split leaf on insert
ispobock Apr 25, 2026
05ab33b
rm token_usage call in prefill delayer
fzyzcjy Apr 26, 2026
b55c725
DEBUG 0425_015 (re-applied on deepseek_v4): prefill_delayer print all…
fzyzcjy Apr 26, 2026
6a5a127
hack bench_one_batch_server_internal.py
fzyzcjy Apr 26, 2026
914bc4d
fix swa batch full
ispobock Apr 25, 2026
97d73a1
opt swa mem
ispobock Apr 26, 2026
97d1a67
release lock after window
ispobock Apr 26, 2026
7544dc3
Add mooncake cu13 to some dockerfiles (#23694)
Fridge003 Apr 26, 2026
901c389
defensive programming as is suggested by liangsheng
fzyzcjy Apr 27, 2026
ab718a0
fix page over estimation
hnyls2002 Apr 27, 2026
a7e27be
[DeepSeek V4] Fix meaningless numbers in chat output by adding swiglu…
GaoYusong Apr 27, 2026
d245d52
[fix] nixl: transport SWA/NSA/Mamba state buffer (#23773)
YAMY1234 Apr 26, 2026
74cd920
SGLANG_HACK_DEBUG_DUMP_CREATE_PAGED_COMPRESS_DATA
fzyzcjy Apr 27, 2026
9efbd4b
Merge remote-tracking branch 'upstream/deepseek_v4' into deepseek_v4
fzyzcjy Apr 27, 2026
af2e549
enhance SGLANG_HACK_DEBUG_DUMP_CREATE_PAGED_COMPRESS_DATA
fzyzcjy Apr 27, 2026
e27e225
more enhance SGLANG_HACK_DEBUG_DUMP_CREATE_PAGED_COMPRESS_DATA
fzyzcjy Apr 27, 2026
c0788f3
chore: tests
fzyzcjy Apr 27, 2026
8a38621
chore: test path
fzyzcjy Apr 27, 2026
6068b34
more dumps
fzyzcjy Apr 27, 2026
a832f38
indexer logical memory space
yushengsu-thu Apr 27, 2026
873cd56
Add benchmarking scripts for deepseek v4 (#23810)
Fridge003 Apr 27, 2026
98d2236
feat: online softmax compress 128
DarkSharpness Apr 27, 2026
33b07da
Add return_routed_experts kwarg to async_generate (#23822)
hnyls2002 Apr 27, 2026
46584ae
fix: fix i32 index out of bound (i hate triton)
DarkSharpness Apr 27, 2026
4e7925e
fix(profiler): warm CUPTI on first profile to avoid losing GPU events
fzyzcjy Apr 27, 2026
7642335
feat(profiler): gate kineto warmup behind SGLANG_HACK_WARMUP_KINETO
fzyzcjy Apr 27, 2026
f718570
fix(profiler): apply kineto warmup at the right call site
fzyzcjy Apr 27, 2026
46c3c96
fix: online c128
DarkSharpness Apr 27, 2026
c409f44
feat: port SGLANG_JIT_DEEPGEMM_FAST_WARMUP to deepseek_v4 branch (#23…
parrot18 Apr 27, 2026
76e39ff
simple rename file
Fridge003 Apr 28, 2026
d737cae
Update deepgemm hash for gb image
Fridge003 Apr 28, 2026
31b2ab7
set decoder stream
yushengsu-thu Apr 28, 2026
5e8862e
Update deepseek_v2.py
zhangxiaolei123456 Apr 28, 2026
0e417d2
Update environ.py
zhangxiaolei123456 Apr 28, 2026
11fb9c8
Update model_runner.py
zhangxiaolei123456 Apr 28, 2026
2ef8310
Deepseek_v4 support w4(mxfp4)a16 on hopper (#23686)
zhangxiaolei123456 Apr 28, 2026
2682190
fix pd-mtp metadata buffer hidden size (#23919)
hnyls2002 Apr 28, 2026
78343ed
[Misc] Replace BF16->FP32 cuBLAS JIT with torch.mm (#23917)
sihan-zzz Apr 28, 2026
2f59ad3
upd
Fridge003 Apr 28, 2026
0927331
clean unneeded sgl-kernel cmakelist for hopper
Fridge003 Apr 28, 2026
3ee6090
Merge branch 'deepseek_v4' into deepseek_v4_share_expert_tp1
zhangxiaolei123456 Apr 28, 2026
97d93ec
fix sgl-kernel CMakeLists for hopper build after dependency cleanup
yhyang201 Apr 28, 2026
37e6304
Merge branch 'deepseek_v4' into deepseek_v4_share_expert_tp1
zhangxiaolei123456 Apr 28, 2026
dbdd494
Fix dsv4 self-closing invoke tags + accept reasoning_effort=max (#23915)
JustinTong0323 Apr 28, 2026
79281e8
fix: restore FP4 deepgemm path for Blackwell broken by #23686 (#23948)
yhyang201 Apr 28, 2026
c13abdc
add --weight-loader-drop-cache-after-load arg to save CPU mem for RL
yueming-yuan Apr 29, 2026
e899288
Merge branch 'deepseek_v4' into deepseek_v4_share_expert_tp1
zhangxiaolei123456 Apr 29, 2026
cdf5fdb
Update environ.py
zhangxiaolei123456 Apr 29, 2026
f353a07
Update deepseek_v2.py
zhangxiaolei123456 Apr 29, 2026
a7ffb22
Update model_runner.py
zhangxiaolei123456 Apr 29, 2026
23d6f04
Merge pull request #2 from zhangxiaolei123456/zhangxiaolei123456-patch-2
zhangxiaolei123456 Apr 29, 2026
6d0858d
Fix DeepGEMM JIT warmup crash by passing m_indices positionally (#24017)
Fridge003 Apr 29, 2026
9af2828
Merge branch 'deepseek_v4' into deepseek_v4_share_expert_tp1
zhangxiaolei123456 Apr 30, 2026
1b497c7
little fix deepgemm hash on gb image
Fridge003 Apr 30, 2026
d38bfad
Merge branch 'deepseek_v4' into deepseek_v4_share_expert_tp1
zhangxiaolei123456 May 9, 2026
34 changes: 34 additions & 0 deletions docker/deepseek_v4_b200.Dockerfile
@@ -0,0 +1,34 @@
FROM lmsysorg/sglang:v0.5.7

# need: cu12.9, x86_64 docker
# Same dependency set as H200 (preset.py treats H200/B200 as one flavor).

RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git

# tilelang 0.1.8 pinned: mhc.py uses T.gemm(wg_wait=0), removed in 0.1.9.
RUN pip install tilelang==0.1.8

RUN pip install flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu129

RUN cd /tmp && rm -rf flash-mla && \
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
cd flash-mla && git submodule update --init --recursive && \
pip install --no-build-isolation -v . && \
cd /tmp && rm -rf flash-mla

RUN pip install -e /workspace/sglang/python/

# DeepGEMM must come after sglang install: sglang pyproject pulls
# cuda-python / sgl-kernel / quack-kernels / nvidia-cutlass-dsl, which
# DeepGEMM depends on at the resolved versions.
RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
cd /tmp && rm -rf DeepGEMM && \
git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
cd DeepGEMM && git checkout 7f2a70 && \
git submodule update --init --recursive && \
bash install.sh

# DeepGEMM install.sh bumps apache-tvm-ffi to 0.1.10, which breaks tilelang
# 0.1.8 ABI. Re-pin to 0.1.9 (--no-deps so pip does not touch deep-gemm).
RUN pip install --no-deps apache-tvm-ffi==0.1.9
48 changes: 48 additions & 0 deletions docker/deepseek_v4_b300.Dockerfile
@@ -0,0 +1,48 @@
FROM lmsysorg/sglang:v0.5.7-cu130-runtime

ENV PIP_BREAK_SYSTEM_PACKAGES=1

# tilelang's bundled libtvm.so depends on libz3.so (no version suffix).
# Base image ships nothing matching, and apt's libz3-4 only installs libz3.so.4.
RUN apt-get update && apt-get install -y --no-install-recommends libz3-4 && \
ln -sf /usr/lib/x86_64-linux-gnu/libz3.so.4 /usr/lib/x86_64-linux-gnu/libz3.so && \
ldconfig && rm -rf /var/lib/apt/lists/*

RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git

RUN pip install cuda-python --upgrade
RUN pip install flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu130


RUN pip install https://github.com/sgl-project/whl/releases/download/v0.3.21/sgl_kernel-0.3.21+cu130-cp310-abi3-manylinux2014_x86_64.whl

RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
cd /tmp && rm -rf DeepGEMM && \
git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
cd DeepGEMM && git checkout 7f2a70 && \
git submodule update --init --recursive && \
ln -sf $(pwd)/third-party/cutlass/include/cutlass $(pwd)/deep_gemm/include/cutlass && \
ln -sf $(pwd)/third-party/cutlass/include/cute $(pwd)/deep_gemm/include/cute && \
bash install.sh

RUN pip install -e /workspace/sglang/python/


RUN pip install --force-reinstall --no-deps tilelang==0.1.8

RUN pip install nvidia-cuda-cccl && \
CCCL_INC=$(find /usr/local/lib -path "*/include/cccl/cuda/std" -type d 2>/dev/null | head -1 | sed 's|/cuda/std$||') && \
ln -sf $CCCL_INC/cuda /usr/local/cuda/include/cuda && \
mv /usr/local/cuda/targets/x86_64-linux/include/cccl /usr/local/cuda/targets/x86_64-linux/include/cccl.bak && \
ln -sf $CCCL_INC /usr/local/cuda/targets/x86_64-linux/include/cccl
# FlashMLA — required by deepseek_v4_backend_radix.py
RUN cd /tmp && rm -rf flash-mla && \
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
cd flash-mla && git submodule update --init --recursive && \
pip install --no-build-isolation .
# fast_hadamard_transform — sgl_kernel 0.3.21 lacks hadamard_transform on B300
RUN pip install --no-build-isolation git+https://github.com/Dao-AILab/fast-hadamard-transform.git

# Install mooncake
RUN pip install mooncake-transfer-engine-cuda13
28 changes: 28 additions & 0 deletions docker/deepseek_v4_grace_blackwell.Dockerfile
@@ -0,0 +1,28 @@
FROM lmsysorg/sglang:v0.5.7-cu130-runtime

# need: cu13, arm docker
RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git

RUN pip install https://github.com/sgl-project/whl/releases/download/v0.3.21/sgl_kernel-0.3.21+cu130-cp310-abi3-manylinux2014_aarch64.whl
RUN pip install cuda-python --upgrade
RUN cd /tmp && rm -rf flash-mla && git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && cd flash-mla && ln -s /usr/local/cuda/include/cccl/cuda /usr/local/cuda/include/cuda && git submodule update --init --recursive && pip install --no-build-isolation -v .

RUN pip install flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu130
RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
cd /tmp && rm -rf DeepGEMM && git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
cd DeepGEMM && git checkout 003ed71 && \
git submodule update --init --recursive && \
ln -sf $(pwd)/third-party/cutlass/include/cutlass $(pwd)/deep_gemm/include/cutlass && \
ln -sf $(pwd)/third-party/cutlass/include/cute $(pwd)/deep_gemm/include/cute && \
bash install.sh
RUN pip install -e /workspace/sglang/python/

# Install TileLang for arm
RUN pip install https://github.com/tile-ai/tilelang/releases/download/v0.1.8/tilelang-0.1.8-cp38-abi3-manylinux_2_34_aarch64.whl

# Install hadamard transform
RUN pip install --no-build-isolation git+https://github.com/Dao-AILab/fast-hadamard-transform.git

# Install mooncake
RUN pip install mooncake-transfer-engine-cuda13
36 changes: 36 additions & 0 deletions docker/deepseek_v4_h200.Dockerfile
@@ -0,0 +1,36 @@
FROM lmsysorg/sglang:v0.5.7

# need: cu12.9, x86_64 docker

RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git

# tilelang 0.1.8 pinned: mhc.py uses T.gemm(wg_wait=0), removed in 0.1.9.
RUN pip install tilelang==0.1.8

RUN pip install flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu129

RUN cd /tmp && rm -rf flash-mla && \
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
cd flash-mla && git submodule update --init --recursive && \
pip install --no-build-isolation -v . && \
cd /tmp && rm -rf flash-mla

RUN pip install -e /workspace/sglang/python/

# Build kernel for w4a16 marlin
RUN cd /workspace/sglang/sgl-kernel && make build

# DeepGEMM must come after sglang install: sglang pyproject pulls
# cuda-python / sgl-kernel / quack-kernels / nvidia-cutlass-dsl, which
# DeepGEMM depends on at the resolved versions.
RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
cd /tmp && rm -rf DeepGEMM && \
git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
cd DeepGEMM && git checkout 7f2a70 && \
git submodule update --init --recursive && \
bash install.sh

# DeepGEMM install.sh bumps apache-tvm-ffi to 0.1.10, which breaks tilelang
# 0.1.8 ABI. Re-pin to 0.1.9 (--no-deps so pip does not touch deep-gemm).
RUN pip install --no-deps apache-tvm-ffi==0.1.9
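The Dockerfile comments above explain two pins: `tilelang==0.1.8` and a re-pin of `apache-tvm-ffi` to `0.1.9` after DeepGEMM's `install.sh` bumps it. A minimal sketch of a sanity check one could run inside the built image follows. The helper names `installed_versions` and `pin_mismatches` are hypothetical and not part of this PR; only the package names and versions come from the Dockerfile.

```python
from importlib import metadata

# Pins taken from the Dockerfile comments; everything else here is illustrative.
EXPECTED_PINS = {"tilelang": "0.1.8", "apache-tvm-ffi": "0.1.9"}


def installed_versions(names):
    """Return {package: installed version or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def pin_mismatches(expected, installed):
    """Return {package: (wanted, found)} for every pin that is not satisfied."""
    return {
        name: (want, installed.get(name))
        for name, want in expected.items()
        if installed.get(name) != want
    }


if __name__ == "__main__":
    # An empty dict means both pins survived the later install steps.
    print(pin_mismatches(EXPECTED_PINS, installed_versions(EXPECTED_PINS)))
```

Running this as the last step of an image build would surface the case the comment warns about, where a later install silently moves `apache-tvm-ffi` off `0.1.9`.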
6 changes: 3 additions & 3 deletions python/pyproject.toml
@@ -28,8 +28,8 @@ dependencies = [
"datasets",
"einops",
"fastapi",
-"flashinfer_python==0.6.2", # keep it aligned with jit-cache version in Dockerfile
-"flashinfer_cubin==0.6.2",
+"flashinfer_python==0.6.8", # keep it aligned with jit-cache version in Dockerfile
+"flashinfer_cubin==0.6.8",
"gguf",
"hf_transfer",
"huggingface_hub",
@@ -55,7 +55,7 @@ dependencies = [
"pydantic",
"python-multipart",
"pyzmq>=25.1.2",
-"quack-kernels==0.2.4",
+# "quack-kernels==0.2.4", # conflicts with flashinfer 0.6.8 on nvidia-cutlass-dsl (<4.4.0 vs >=4.4.2); not used by current bench flows
"requests",
"scipy",
"sentencepiece",
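The comment in this hunk asks that the `flashinfer_python` / `flashinfer_cubin` pins stay aligned with the `flashinfer-jit-cache` version in the Dockerfiles. A minimal sketch of how that could be checked mechanically is below. The functions `pinned_version`, `flashinfer_versions` and `is_aligned` are hypothetical helpers, not something this PR adds, and the sample strings are copied from the diff.

```python
import re


def pinned_version(text, package):
    """Return the version in the first 'package==X.Y.Z' pin in text, or None."""
    match = re.search(rf"{re.escape(package)}==([0-9][0-9A-Za-z.]*)", text)
    return match.group(1) if match else None


def flashinfer_versions(pyproject_text, dockerfile_text):
    """Collect the three flashinfer pins that the comment says must agree."""
    return {
        "flashinfer_python": pinned_version(pyproject_text, "flashinfer_python"),
        "flashinfer_cubin": pinned_version(pyproject_text, "flashinfer_cubin"),
        "flashinfer-jit-cache": pinned_version(dockerfile_text, "flashinfer-jit-cache"),
    }


def is_aligned(versions):
    """True only when every pin was found and all of them are identical."""
    values = set(versions.values())
    return None not in values and len(values) == 1


# Sample inputs copied from the diff (new side of the hunk and a Dockerfile line).
PYPROJECT = '"flashinfer_python==0.6.8", # keep it aligned\n"flashinfer_cubin==0.6.8",\n'
DOCKERFILE = "RUN pip install flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu129\n"

if __name__ == "__main__":
    print(is_aligned(flashinfer_versions(PYPROJECT, DOCKERFILE)))
```

With the pre-PR pins (`0.6.2` in `pyproject.toml`) and the Dockerfile's `0.6.8`, `is_aligned` would return `False`, which is the drift the inline comment is meant to prevent.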
2 changes: 1 addition & 1 deletion python/sglang/jit_kernel/.clang-format
@@ -17,7 +17,7 @@ PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
IncludeCategories:
- Regex: '^<sgl_kernel/.*\.h>$'
Priority: 0
-- Regex: '^<sgl_kernel/impl/.*>$'
+- Regex: '^<sgl_kernel/.*/.*>$'
Priority: 2
- Regex: '^<sgl_kernel/.*\.cuh>$'
Priority: 1
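clang-format assigns an include the priority of the first `IncludeCategories` regex that matches it, in the order listed. The sketch below replays that rule in Python for the three categories visible in this hunk (any later entries are not shown in the diff and are not modeled). The header names are hypothetical examples, and Python's `re` is used here only because these particular patterns behave the same way in it.

```python
import re

# (regex, priority) in file order, using the post-change second entry.
CATEGORIES = [
    (r"^<sgl_kernel/.*\.h>$", 0),
    (r"^<sgl_kernel/.*/.*>$", 2),
    (r"^<sgl_kernel/.*\.cuh>$", 1),
]


def include_priority(include, categories=CATEGORIES):
    """Return the priority of the first matching category, or None."""
    for pattern, priority in categories:
        if re.search(pattern, include):
            return priority
    return None


if __name__ == "__main__":
    for name in (
        "<sgl_kernel/tensor.h>",            # hypothetical header names
        "<sgl_kernel/utils.cuh>",
        "<sgl_kernel/impl/norm.cuh>",
        "<sgl_kernel/distributed/common.cuh>",
    ):
        print(name, include_priority(name))
```

Under the old pattern `'^<sgl_kernel/impl/.*>$'`, a nested `.cuh` header outside `impl/` would have fallen through to the `.cuh` category; with the broadened pattern, any header in a subdirectory of `sgl_kernel/` that does not end in `.h` lands in priority 2.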
198 changes: 198 additions & 0 deletions python/sglang/jit_kernel/all_reduce.py
@@ -0,0 +1,198 @@
from __future__ import annotations

import enum
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, cast

import torch
import tvm_ffi
from tvm_ffi import Module

from sglang.jit_kernel.utils import (
    cache_once,
    is_arch_support_pdl,
    load_jit,
    make_cpp_args,
)


class ConfigResult(NamedTuple):
    num_blocks: int
    num_threads: int


class AllReduceAlgo(enum.Enum):
    ONE_SHOT_PUSH = enum.auto()
    ONE_SHOT_PULL = enum.auto()
    TWO_SHOT_PULL = enum.auto()

    def is_push(self) -> bool:
        return self == AllReduceAlgo.ONE_SHOT_PUSH

    @property
    def shot(self) -> int:
        return 2 if self == AllReduceAlgo.TWO_SHOT_PULL else 1


if TYPE_CHECKING:
    CUSTOM_AR_HANDLE = List[int]
    CUSTOM_AR_PAIR = Tuple[int, CUSTOM_AR_HANDLE]

    class CustomAllReduceObj:
        def __init__(
            self,
            rank: int,
            world_size: int,
            pull_buffer_bytes: int,
            push_buffer_bytes: int,
            graph_input_count: int,
            *,
            max_pull_blocks: Optional[int] = None,
            max_push_blocks: Optional[int] = None,
        ) -> None:
            """
            Create a CustomAllReduceObj instance.

            :param rank: The rank of the current process.
            :param world_size: The total number of processes in the group.
            :param pull_buffer_bytes: The size of the buffer (in bytes) used for pull-based all-reduce.
            :param push_buffer_bytes: The size of the buffer (in bytes) used for push-based all-reduce.
            :param graph_input_count: The maximum number of inputs in all CUDA graphs.
            :param max_pull_blocks: The maximum number of thread blocks to launch for pull-based all-reduce.
                If None, it will be determined by the implementation.
            :param max_push_blocks: The maximum number of thread blocks to launch for push-based all-reduce.
                If None, it will be determined by the implementation.
            """

        @property
        def world_size(self) -> int: ...
        def share_storage(self) -> CUSTOM_AR_HANDLE: ...
        def share_graph_inputs(self) -> List[CUSTOM_AR_PAIR]: ...
        def post_init(self, handles: List[CUSTOM_AR_HANDLE]) -> None: ...
        def register_inputs(self, handles: List[List[CUSTOM_AR_PAIR]]) -> None: ...
        def set_cuda_graph_capture(self, is_capturing: bool) -> None: ...
        def free(self, tp_cpu_group: torch.distributed.ProcessGroup) -> None: ...
        def all_reduce(
            self, input: torch.Tensor, algo: AllReduceAlgo
        ) -> tvm_ffi.Tensor: ...
        def config_pull(
            self, num_blocks: int = -1, num_threads: int = -1
        ) -> ConfigResult:
            """
            Configure the CUDA kernel's grid and block dimensions.
            This provides only the upper bound of the configuration,
            and the actual launch configuration may be determined by implementation.
            Note that push-based all-reduce can not be configured currently.

            :param num_blocks: The maximum number of thread blocks to launch. -1 means no limit.
            :param num_threads: The maximum number of threads per block. -1 means no limit.

            :return: The previous configuration as a ConfigResult named tuple.
            """
            ...


@cache_once
def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int) -> Module:
    args = make_cpp_args(dtype, world_size, is_arch_support_pdl())
    return load_jit(
        "custom_all_reduce_pull",
        *args,
        extra_ldflags=["-lcuda"],
        cuda_files=["distributed/custom_all_reduce_pull.cuh"],
        cuda_wrappers=[("all_reduce", f"custom_all_reduce<{args}>")],
    )


@cache_once
def _jit_custom_all_reduce_push_module(dtype: torch.dtype, world_size: int) -> Module:
    args = make_cpp_args(dtype, world_size, is_arch_support_pdl())
    return load_jit(
        "custom_all_reduce_push",
        *args,
        extra_ldflags=["-lcuda"],
        cuda_files=["distributed/custom_all_reduce_push.cuh"],
        cuda_wrappers=[("all_reduce", f"custom_all_reduce<{args}>")],
    )


@cache_once
def get_custom_all_reduce_cls() -> type[CustomAllReduceObj]:
    module = load_jit(
        "custom_all_reduce_base",
        extra_ldflags=["-lcuda"],
        cuda_files=["distributed/custom_all_reduce_base.cuh"],
        cuda_wrappers=[("register_once", "register_custom_all_reduce")],
    )
    module.register_once()
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    NUM_CTA = props.multi_processor_count
    MAX_THREADS = 512

    @tvm_ffi.register_object("sgl.CustomAllReduce")
    class CustomAllReduceObjReal(tvm_ffi.Object):
        __slots__ = ("__dict__",)

        def __init__(
            self,
            rank: int,
            world_size: int,
            pull_buffer_bytes: int,
            push_buffer_bytes: int,
            graph_input_count: int,
            *,
            max_pull_blocks: Optional[int] = None,
            max_push_blocks: Optional[int] = None,
        ) -> None:
            max_pull_blocks = NUM_CTA if max_pull_blocks is None else max_pull_blocks
            max_push_blocks = NUM_CTA if max_push_blocks is None else max_push_blocks
            self.__ffi_init__(
                rank,
                world_size,
                max_pull_blocks,
                max_push_blocks,
                pull_buffer_bytes,
                push_buffer_bytes,
                graph_input_count,
            )
            self._world_size = world_size
            self._pull_config = ConfigResult(min(NUM_CTA, max_pull_blocks), MAX_THREADS)
            if max_pull_blocks > 0: # special case: cannot configure 0 blocks
                self.configure_pull(*self._pull_config) # type: ignore

        @property
        def world_size(self) -> int:
            return self._world_size

        def all_reduce(
            self,
            input: torch.Tensor,
            algo: AllReduceAlgo,
        ) -> tvm_ffi.Tensor:
            compile_fn = (
                _jit_custom_all_reduce_push_module
                if algo.is_push()
                else _jit_custom_all_reduce_pull_module
            )
            module = compile_fn(input.dtype, self._world_size)
            return module.all_reduce(self, input, algo.shot)

        def config_pull(
            self, num_blocks: int = -1, num_threads: int = -1
        ) -> ConfigResult:
            old_config = self._pull_config
            num_blocks = num_blocks if num_blocks != -1 else old_config.num_blocks
            num_threads = num_threads if num_threads != -1 else old_config.num_threads
            new_config = ConfigResult(num_blocks, num_threads)
            if new_config != old_config:
                result = ConfigResult(*self.configure_pull(*new_config)) # type: ignore
                assert result == self._pull_config
                self._pull_config = new_config
            return old_config

        def free(self, tp_cpu_group: torch.distributed.ProcessGroup) -> None:
            self.free_ipc_handles() # type: ignore
            torch.distributed.barrier(group=tp_cpu_group)
            self.free_storage() # type: ignore

    return cast(type["CustomAllReduceObj"], CustomAllReduceObjReal)
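The `config_pull` method above returns the previous configuration, which makes a set-then-restore pattern possible. The following pure-Python sketch mirrors that bookkeeping without any GPU or FFI dependency. `PullConfigSketch` and `_configure_backend` are hypothetical stand-ins (the latter for the FFI `configure_pull` call), and the numbers are arbitrary. Note that, in the implementation shown in this diff, passing `-1` keeps the currently stored value, although the stub docstring describes `-1` as "no limit".

```python
from typing import NamedTuple


class ConfigResult(NamedTuple):
    num_blocks: int
    num_threads: int


class PullConfigSketch:
    """Stand-in that reproduces only the config_pull bookkeeping."""

    def __init__(self, num_blocks: int, num_threads: int) -> None:
        self._pull_config = ConfigResult(num_blocks, num_threads)
        self.backend_calls = 0  # counts simulated FFI round trips

    def _configure_backend(self, num_blocks: int, num_threads: int) -> ConfigResult:
        # Hypothetical replacement for the FFI call; returns the old config.
        self.backend_calls += 1
        return self._pull_config

    def config_pull(self, num_blocks: int = -1, num_threads: int = -1) -> ConfigResult:
        old = self._pull_config
        new = ConfigResult(
            old.num_blocks if num_blocks == -1 else num_blocks,
            old.num_threads if num_threads == -1 else num_threads,
        )
        if new != old:  # skip the backend call when nothing changes
            assert self._configure_backend(*new) == old
            self._pull_config = new
        return old


if __name__ == "__main__":
    cfg = PullConfigSketch(132, 512)
    previous = cfg.config_pull(num_blocks=64)  # temporarily cap the grid size
    print(previous, cfg.config_pull())         # no-arg call only reads the state
    cfg.config_pull(*previous)                 # restore the earlier limits
```

Returning the old tuple rather than `None` lets a caller restore the prior limits with a single `config_pull(*previous)` call, as in the usage lines above.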