aiter/dist/communication_op.py (2 additions, 2 deletions)
@@ -31,9 +31,9 @@ def tensor_model_parallel_all_reduce(


def tensor_model_parallel_fused_allreduce_rmsnorm(
-    input_: torch.Tensor, weight_: torch.Tensor, eps: float
+    input_: torch.Tensor, residual_inp_: torch.Tensor, weight_: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
-    return get_tp_group().fused_allreduce_rmsnorm(input_, weight_, eps)
+    return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)


def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor:
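The public wrapper now threads a residual input through to the tensor-parallel group and returns the updated residual stream alongside the normalized output. A minimal usage sketch, assuming an initialized TP group; the `hidden`/`residual`/`weight` tensors and their shapes are illustrative, not from this PR:

```python
import torch

from aiter.dist.communication_op import (
    tensor_model_parallel_fused_allreduce_rmsnorm,
)

# Hypothetical activations; shapes/dtypes chosen only for illustration.
hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(hidden)  # must match hidden in dtype and numel
weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

# New signature: the residual rides along with the all-reduced activations.
# The op returns (residual_out, normalized_out).
res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(
    hidden, residual, weight, 1e-6
)
```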
aiter/dist/device_communicators/communicator_cuda.py (2 additions, 2 deletions)
@@ -159,7 +159,7 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor:
return out

def fused_allreduce_rmsnorm(
-        self, input_, weight_, eps
+        self, input_, res_inp_, weight_, eps
) -> tuple[torch.Tensor, torch.Tensor]:
n = input_.shape[-1]
can_use_fuse_ar_rms = (
@@ -174,7 +174,7 @@ def fused_allreduce_rmsnorm(
and ca_comm.should_custom_ar(input_)
and can_use_fuse_ar_rms
):
-            res_out, out = ca_comm.custom_fused_ar_rms(input_, weight_, eps)
+            res_out, out = ca_comm.custom_fused_ar_rms(input_, res_inp_, weight_, eps)
assert out is not None
assert res_out is not None
return res_out, out
aiter/dist/device_communicators/custom_all_reduce.py (13 additions, 3 deletions)
@@ -339,6 +339,7 @@ def custom_all_gather(self, inp: torch.Tensor) -> Optional[torch.Tensor]:
def fused_ar_rms(
self,
inp: torch.Tensor,
+        res_inp: torch.Tensor,
*,
res_out: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
@@ -353,6 +354,7 @@ def fused_ar_rms(
ops.fused_allreduce_rmsnorm(
self._ptr,
inp,
+            res_inp,
res_out,
out,
w,
@@ -362,18 +364,26 @@ def fused_ar_rms(
return res_out, out

def custom_fused_ar_rms(
-        self, input: torch.Tensor, weight: torch.Tensor, eps: float
+        self,
+        input: torch.Tensor,
+        residual_inp: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float,
) -> Optional[torch.Tensor]:
# when custom allreduce is disabled, this will be None
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
-                return self.fused_ar_rms(input, w=weight, eps=eps, registered=True)
+                return self.fused_ar_rms(
+                    input, residual_inp, w=weight, eps=eps, registered=True
+                )
else:
return torch.zeros_like(input), torch.zeros_like(input)
else:
-            return self.fused_ar_rms(input, w=weight, eps=eps, registered=False)
+            return self.fused_ar_rms(
+                input, residual_inp, w=weight, eps=eps, registered=False
+            )

def close(self):
if not self.disabled and self._ptr:
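Easy to miss in the hunk above: `custom_fused_ar_rms` stays capture-aware. During CUDA-graph warm-up it returns shape-correct zero placeholders instead of running the collective, and only the registered-buffer path is recorded into the graph. A simplified restatement of the three branches, with hypothetical names standing in for the instance state:

```python
import torch

def dispatch_fused_ar_rms(op, inp, residual_inp, *, capturing, stream_capturing):
    # `op` stands in for fused_ar_rms; the flags mirror _IS_CAPTURING and
    # torch.cuda.is_current_stream_capturing(). Hypothetical sketch only.
    if capturing:
        if stream_capturing:
            # Recording into a CUDA graph: use the registered-buffer path.
            return op(inp, residual_inp, registered=True)
        # Warm-up passes before capture: placeholders keep downstream shapes
        # consistent without touching the communicator.
        return torch.zeros_like(inp), torch.zeros_like(inp)
    # Eager execution outside any capture.
    return op(inp, residual_inp, registered=False)
```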
aiter/dist/parallel_state.py (25 additions, 7 deletions)
@@ -120,20 +120,28 @@ def all_reduce_(


def fused_allreduce_rmsnorm_fake(
-    inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str
+    inp: torch.Tensor,
+    res_inp: torch.Tensor,
+    w: torch.Tensor,
+    eps: float,
+    group_name: str,
) -> torch.Tensor:
return torch.empty_like(inp)


@torch_compile_guard(gen_fake=fused_allreduce_rmsnorm_fake)
def fused_allreduce_rmsnorm_(
-    inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str
+    inp: torch.Tensor,
+    res_inp: torch.Tensor,
+    w: torch.Tensor,
+    eps: float,
+    group_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
-    return group._fused_allreduce_rmsnorm_out_place(inp, w, eps)
+    return group._fused_allreduce_rmsnorm_out_place(inp, res_inp, w, eps)


if supports_custom_op():
@@ -347,18 +355,28 @@ def _all_reduce_out_place(
return self.device_communicator.all_reduce(input_, ca_fp8_quant)

def fused_allreduce_rmsnorm(
-        self, input_: torch.Tensor, weight_: torch.Tensor, eps: float
+        self,
+        input_: torch.Tensor,
+        residual_inp_: torch.Tensor,
+        weight_: torch.Tensor,
+        eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return fused_allreduce_rmsnorm_(
-            input_, weight_, eps, group_name=self.unique_name
+            input_, residual_inp_, weight_, eps, group_name=self.unique_name
)

def _fused_allreduce_rmsnorm_out_place(
-        self, input_: torch.Tensor, weight_: torch.Tensor, eps: float
+        self,
+        input_: torch.Tensor,
+        residual_inp_: torch.Tensor,
+        weight_: torch.Tensor,
+        eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is None:
raise ValueError("No device communicator found")
-        return self.device_communicator.fused_allreduce_rmsnorm(input_, weight_, eps)
+        return self.device_communicator.fused_allreduce_rmsnorm(
+            input_, residual_inp_, weight_, eps
+        )

def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor:
ca_comm = self.device_communicator.ca_comm
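`fused_allreduce_rmsnorm_fake` exists so torch.compile can trace the custom op without a device kernel: it only has to produce output with the right shape and dtype. A generic sketch of the real/fake pairing under that contract, using hypothetical op names rather than the exact signatures above:

```python
import torch

# Real op (simplified): residual add plus RMSNorm, returning
# (residual_out, normalized_out).
def fused_op(inp: torch.Tensor, res_inp: torch.Tensor):
    s = inp + res_inp
    inv_rms = torch.rsqrt(s.float().pow(2).mean(-1, keepdim=True) + 1e-6)
    return s, (s.float() * inv_rms).to(s.dtype)

# Fake/meta twin: same signature, no computation; shapes and dtypes are
# all the compiler needs to trace through the call site.
def fused_op_fake(inp: torch.Tensor, res_inp: torch.Tensor):
    return torch.empty_like(inp), torch.empty_like(inp)
```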
aiter/ops/custom_all_reduce.py (1 addition, 0 deletions)
@@ -45,6 +45,7 @@ def all_gather_unreg(
def fused_allreduce_rmsnorm(
_fa: int,
inp: torch.Tensor,
+    res_inp: torch.Tensor,
res_out: torch.Tensor,
out: torch.Tensor,
w: torch.Tensor,
csrc/include/custom_all_reduce.cuh (33 additions, 18 deletions)
@@ -511,6 +511,7 @@ namespace aiter
}
tmp_out[idx - start] = write_reg;
}
+            __syncthreads();
}
end_sync<ngpus>(sg, self_sg, rank);

@@ -935,9 +936,8 @@ namespace aiter
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
-                float res_x = ck_tile::type_convert<float>(input_reg.data[i]);
                float sum_x = *(reinterpret_cast<float*>(&tmp_smem[0]) + lane_id * pack_size + i);
-                rslt.data[i] = ck_tile::type_convert<T>(res_x + sum_x);
+                rslt.data[i] = ck_tile::type_convert<T>(sum_x);
}
tmps[warp_id][(rank * part + bid) * tnum_gpu + lane_id] = rslt;
}
@@ -976,6 +976,7 @@ namespace aiter
template <typename T, int tnum, int n_loop>
__global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm_naive(
RankSignals sg,
+        T* __restrict__ residual_inp,
T* __restrict__ residual_out,
T* __restrict__ results,
T* __restrict__ weight,
@@ -1000,14 +1001,18 @@
for (int n_iter = 0; n_iter < n_loop; ++n_iter)
{
int read_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x;
-            rmsnorm_inp[n_iter] = tmps[read_idx];
+            P reduce_out_pack = tmps[read_idx];
+            P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
w_arr[n_iter] = *(reinterpret_cast<P*>(weight) + n_iter * blockDim.x + threadIdx.x);
A reduce_pack;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
-                float ar_elem = ck_tile::type_convert<float>(rmsnorm_inp[n_iter].data[i]);
-                reduce_pack.data[i] = ar_elem * ar_elem;
+                float res_inp = ck_tile::type_convert<float>(residual_inp_pack.data[i]);
+                float ar_out = ck_tile::type_convert<float>(reduce_out_pack.data[i]);
+                float rms_inp = res_inp + ar_out;
+                rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert<T>(rms_inp);
+                reduce_pack.data[i] = rms_inp * rms_inp;
}
square_sum += packReduce<AddFunctor, float, pack_size>(reduce_pack);
}
@@ -1041,6 +1046,7 @@ namespace aiter
template <typename T, int tnum, int n_loop>
__global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm(
RankSignals sg,
+        T* __restrict__ residual_inp,
T* __restrict__ residual_out,
T* __restrict__ results,
T* __restrict__ weight,
@@ -1067,14 +1073,18 @@
if (n_iter * tnum + threadIdx.x < (n / pack_size))
{
int read_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x;
-                rmsnorm_inp[n_iter] = tmps[read_idx];
+                P reduce_out_pack = tmps[read_idx];
+                P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
w_arr[n_iter] = *(reinterpret_cast<P*>(weight) + n_iter * tnum + threadIdx.x);
A reduce_pack;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
-                    float ar_elem = ck_tile::type_convert<float>(rmsnorm_inp[n_iter].data[i]);
-                    reduce_pack.data[i] = ar_elem * ar_elem;
+                    float ar_out = ck_tile::type_convert<float>(reduce_out_pack.data[i]);
+                    float res_inp = ck_tile::type_convert<float>(residual_inp_pack.data[i]);
+                    float rms_inp = ar_out + res_inp;
+                    rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert<T>(rms_inp);
+                    reduce_pack.data[i] = rms_inp * rms_inp;
}
square_sum += packReduce<AddFunctor, float, pack_size>(reduce_pack);
}
@@ -1108,6 +1118,7 @@ namespace aiter
template <typename T, int n_loop>
__global__ void __launch_bounds__(256, 1) local_device_load_rmsnorm_512n(
RankSignals sg,
+        T* __restrict__ residual_inp,
T* __restrict__ residual_out,
T* __restrict__ results,
T* __restrict__ weight,
@@ -1134,14 +1145,18 @@
for (int n_iter = 0; n_iter < n_loop; ++n_iter)
{
int read_idx = bid * 64 * n_loop + n_iter * 64 + lane_id;
-            rmsnorm_inp[n_iter] = tmps[read_idx];
+            P reduce_out_pack = tmps[read_idx];
+            P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
w_arr[n_iter] = *(reinterpret_cast<P*>(weight) + n_iter * 64 + lane_id);
A reduce_pack;
#pragma unroll
for (int i = 0; i < pack_size; ++i)
{
-                float ar_elem = ck_tile::type_convert<float>(rmsnorm_inp[n_iter].data[i]);
-                reduce_pack.data[i] = ar_elem * ar_elem;
+                float ar_out = ck_tile::type_convert<float>(reduce_out_pack.data[i]);
+                float res_inp = ck_tile::type_convert<float>(residual_inp_pack.data[i]);
+                float rms_inp = ar_out + res_inp;
+                rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert<T>(rms_inp);
+                reduce_pack.data[i] = rms_inp * rms_inp;
}
float tmp_sum = packReduce<AddFunctor, float, pack_size>(reduce_pack);
square_sum += tmp_sum;
@@ -1587,7 +1602,7 @@
}

template <typename T>
-    void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_out, T* output, T* weight, float eps, int m, int n)
+    void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_inp, T* residual_out, T* output, T* weight, float eps, int m, int n)
{
auto d = packed_t<T>::P::size;
int size = m * n;
@@ -1633,12 +1648,12 @@
grid.x = naive_grid_size < num_cu * occupancy ? naive_grid_size : num_cu * occupancy;
};

-#define launch_fused_allreduce_rmsnorm(template_kernel) \
-    do \
-    { \
-        auto kernel_ptr = reinterpret_cast<const void*>(template_kernel); \
-        setGrid(naive_grid_size, kernel_ptr); \
-        template_kernel<<<grid, block, 0, stream>>>(sg_, residual_out, output, weight, eps, rank_, m, n); \
+#define launch_fused_allreduce_rmsnorm(template_kernel) \
+    do \
+    { \
+        auto kernel_ptr = reinterpret_cast<const void*>(template_kernel); \
+        setGrid(naive_grid_size, kernel_ptr); \
+        template_kernel<<<grid, block, 0, stream>>>(sg_, residual_inp, residual_out, output, weight, eps, rank_, m, n); \
} while (0)

if (n_bytes % 1024 == 0)
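Across all three kernel variants the change is the same: instead of treating the all-reduce result as the RMSNorm input directly, the kernel reads a separate residual tensor, adds it to the reduced activations, and feeds the sum to the norm. A single-process PyTorch sketch of the intended per-row math, assuming `residual_out` receives the pre-normalization sum (the parameter naming suggests this, though that store sits outside the visible hunks):

```python
import torch

def fused_allreduce_rmsnorm_reference(rank_inputs, residual_inp, weight, eps):
    # All-reduce: elementwise sum of the per-rank activations.
    ar_out = torch.stack(rank_inputs).sum(dim=0)
    # Residual add precedes normalization; the pre-norm sum is kept so the
    # next layer can reuse it as its residual stream.
    rms_inp = ar_out + residual_inp
    var = rms_inp.float().pow(2).mean(dim=-1, keepdim=True)
    out = rms_inp.float() * torch.rsqrt(var + eps) * weight.float()
    return rms_inp, out.to(residual_inp.dtype)
```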
csrc/include/custom_all_reduce.h (1 addition, 0 deletions)
@@ -40,6 +40,7 @@ void all_gather_unreg(fptr_t _fa,
torch::Tensor& out);
void fused_allreduce_rmsnorm(fptr_t _fa,
torch::Tensor& inp,
+    torch::Tensor& res_inp,
torch::Tensor& res_out,
torch::Tensor& out,
torch::Tensor& w,
csrc/include/rocm_ops.hpp (1 addition, 0 deletions)
@@ -327,6 +327,7 @@ namespace py = pybind11;
&aiter::fused_allreduce_rmsnorm, \
py::arg("_fa"), \
py::arg("inp"), \
py::arg("res_inp"), \
py::arg("res_out"), \
py::arg("out"), \
py::arg("w"), \
csrc/kernels/custom_all_reduce.cu (9 additions, 3 deletions)
@@ -217,7 +217,7 @@ void all_gather_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
}

void _fused_allreduce_rmsnorm(
-    fptr_t _fa, torch::Tensor& inp, torch::Tensor& residual_out, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream)
+    fptr_t _fa, torch::Tensor& inp, torch::Tensor& residual_inp, torch::Tensor& residual_out, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream)
{
auto fa = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
@@ -226,6 +226,7 @@ void _fused_allreduce_rmsnorm(
case at::ScalarType::Float: {
fa->dispatchFusedAllReduceRMSNorm<float>(stream,
reinterpret_cast<float*>(inp.data_ptr()),
+            reinterpret_cast<float*>(residual_inp.data_ptr()),
reinterpret_cast<float*>(residual_out.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
reinterpret_cast<float*>(w.data_ptr()),
@@ -235,6 +236,7 @@ void _fused_allreduce_rmsnorm(
case at::ScalarType::Half: {
fa->dispatchFusedAllReduceRMSNorm<half>(stream,
reinterpret_cast<half*>(inp.data_ptr()),
+            reinterpret_cast<half*>(residual_inp.data_ptr()),
reinterpret_cast<half*>(residual_out.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
reinterpret_cast<half*>(w.data_ptr()),
@@ -245,6 +247,7 @@ void _fused_allreduce_rmsnorm(
case at::ScalarType::BFloat16: {
fa->dispatchFusedAllReduceRMSNorm<__hip_bfloat16>(stream,
reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()),
+            reinterpret_cast<__hip_bfloat16*>(residual_inp.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(residual_out.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(out.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(w.data_ptr()),
@@ -259,6 +262,7 @@ void _fused_allreduce_rmsnorm(

void fused_allreduce_rmsnorm(fptr_t _fa,
torch::Tensor& inp,
+    torch::Tensor& res_inp,
torch::Tensor& res_out,
torch::Tensor& out,
torch::Tensor& w,
@@ -268,7 +272,9 @@ void fused_allreduce_rmsnorm(fptr_t _fa,
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+    TORCH_CHECK_EQ(inp.scalar_type(), res_inp.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
+    TORCH_CHECK_EQ(inp.numel(), res_inp.numel());
int n = w.numel();
int m = inp.numel() / n;

@@ -282,11 +288,11 @@ void fused_allreduce_rmsnorm(fptr_t _fa,
input_size,
hipMemcpyDeviceToDevice,
stream));
-        _fused_allreduce_rmsnorm(_fa, reg_buffer.value(), res_out, out, w, eps, m, n, stream);
+        _fused_allreduce_rmsnorm(_fa, reg_buffer.value(), res_inp, res_out, out, w, eps, m, n, stream);
}
else
{
-        _fused_allreduce_rmsnorm(_fa, inp, res_out, out, w, eps, m, n, stream);
+        _fused_allreduce_rmsnorm(_fa, inp, res_inp, res_out, out, w, eps, m, n, stream);
}
}

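The added TORCH_CHECK_EQ guards pin down the host-side contract: the residual input must match the all-reduce input in both dtype and element count, while `n` still comes from the weight and `m = inp.numel() / n`. A quick illustration with hypothetical tensors:

```python
import torch

inp = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
w = torch.ones(1024, dtype=torch.float16, device="cuda")  # n = 1024, m = 4

res_inp_ok = torch.empty_like(inp)  # same dtype and numel: passes both checks
res_inp_bad_dtype = inp.float()     # fp32 vs fp16: trips the scalar_type check
res_inp_bad_shape = inp[:2]         # 2048 vs 4096 elements: trips the numel check
```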