
Commit a6adaae

[Unity][DistIR] LowerDistIR (#16169)
* lower distir
* format
* format
* add whitespace
* fix lint
* fix warning
1 parent 8a6184c commit a6adaae

File tree

5 files changed (+690 lines, −1 line)


include/tvm/relax/distributed/transform.h

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,12 @@ TVM_DLL Pass LowerGlobalViewToLocalView();
  */
 TVM_DLL Pass LegalizeRedistribute();
 
+/*!
+ * \brief Lower DistIR to Relax
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass LowerDistIR();
 }  // namespace transform
 }  // namespace distributed
 }  // namespace relax

python/tvm/relax/distributed/transform/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -16,4 +16,9 @@
 # under the License.
 """Relax distributed-related transformations. """
 
-from .transform import PropagateSharding, LowerGlobalViewToLocalView, LegalizeRedistribute
+from .transform import (
+    PropagateSharding,
+    LowerGlobalViewToLocalView,
+    LegalizeRedistribute,
+    LowerDistIR,
+)

python/tvm/relax/distributed/transform/transform.py

Lines changed: 11 additions & 0 deletions
@@ -54,3 +54,14 @@ def LegalizeRedistribute() -> tvm.ir.transform.Pass:
         The registered pass
     """
     return _ffi_api.LegalizeRedistribute()  # type: ignore
+
+
+def LowerDistIR() -> tvm.ir.transform.Pass:
+    """Lower DistIR to Relax
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass
+    """
+    return _ffi_api.LowerDistIR()  # type: ignore
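
For reference, a minimal usage sketch of the new Python entry point, assuming a TVM build that includes this commit. The empty module here is a placeholder: the pass skips functions without DistIR annotations, so it round-trips unchanged.

import tvm
from tvm.relax.distributed.transform import LowerDistIR

# LowerDistIR() constructs a module-level pass; calling it on an IRModule
# rewrites every DistIR function into plain Relax with local (sharded)
# tensor shapes. Non-DistIR functions are left untouched.
mod = tvm.ir.IRModule()
mod = LowerDistIR()(mod)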
src/relax/distributed/transform/lower_distir.cc (new file)

Lines changed: 271 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/distributed/transform/lower_distir.cc
 * \brief Pass for lowering DistIR into Relax.
 *   This pass assumes all the TensorIR functions are in local view,
 *   so it only handles sharding Relax tensor shapes and inserting the
 *   necessary broadcast and scatter collectives for inputs.
 */

#include <tvm/relax/attrs/ccl.h>
#include <tvm/relax/distributed/axis_group_graph.h>
#include <tvm/relax/distributed/transform.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/stmt_functor.h>

#include "../../../tir/schedule/transform.h"
#include "../../op/ccl/ccl.h"
#include "../../op/tensor/manipulate.h"
#include "utils.h"

namespace tvm {
namespace relax {
namespace distributed {

class DistIRSharder : public ExprMutator {
 public:
  static IRModule LowerDistIR(IRModule mod) { return DistIRSharder(mod).Lower(); }

 private:
  explicit DistIRSharder(IRModule mod) : ExprMutator(mod) {}

  IRModule Lower() {
    auto mod = builder_->GetContextIRModule();
    for (const auto& [gv, base_func] : mod->functions) {
      const auto* func_ = base_func.as<FunctionNode>();
      if (func_ == nullptr || !IsDistIRFunc(GetRef<Function>(func_))) {
        continue;
      }
      Function func = RewriteFunction(GetRef<Function>(func_));
      builder_->UpdateFunction(gv, func);
    }
    return builder_->GetContextIRModule();
  }

  ShapeExpr ShardShape(ShapeExpr orig_shape, DeviceMesh device_mesh, Placement placement) {
    ShapeTuple device_mesh_shape = device_mesh->shape;
    Array<PrimExpr> new_tensor_shape_value = orig_shape->values;
    for (int i = 0; i < static_cast<int>(device_mesh_shape.size()); i++) {
      if (placement->dim_specs[i]->kind == PlacementSpecKind::kSharding) {
        int shard_size = device_mesh_shape[i];
        int axis = placement->dim_specs[i]->axis;
        new_tensor_shape_value.Set(axis, floordiv(orig_shape->values[axis], shard_size));
      }
    }
    return ShapeExpr(new_tensor_shape_value);
  }

  TensorStructInfo ShardDTensorSinfo(DTensorStructInfo orig_sinfo) {
    TensorStructInfo tensor_sinfo = orig_sinfo->tensor_sinfo;
    ICHECK(tensor_sinfo->shape);
    const auto* orig_shape = tensor_sinfo->shape.as<ShapeExprNode>();
    auto new_tensor_sinfo = make_object<TensorStructInfoNode>(*tensor_sinfo.get());
    new_tensor_sinfo->shape =
        ShardShape(GetRef<ShapeExpr>(orig_shape), orig_sinfo->device_mesh, orig_sinfo->placement);
    return TensorStructInfo(new_tensor_sinfo);
  }

  StructInfo ConvertSinfo(StructInfo orig_sinfo, bool shard_shape) {
    if (const auto* dtensor_sinfo = orig_sinfo.as<DTensorStructInfoNode>()) {
      if (shard_shape) {
        return ShardDTensorSinfo(GetRef<DTensorStructInfo>(dtensor_sinfo));
      } else {
        return dtensor_sinfo->tensor_sinfo;
      }
    } else if (const auto* tuple_sinfo = orig_sinfo.as<TupleStructInfoNode>()) {
      Array<StructInfo> new_fields;
      for (const auto& field_sinfo : tuple_sinfo->fields) {
        if (const auto* dtensor_sinfo = field_sinfo.as<DTensorStructInfoNode>()) {
          if (shard_shape) {
            new_fields.push_back(ShardDTensorSinfo(GetRef<DTensorStructInfo>(dtensor_sinfo)));
          } else {
            new_fields.push_back(dtensor_sinfo->tensor_sinfo);
          }
        } else {
          new_fields.push_back(field_sinfo);
        }
      }
      return TupleStructInfo(new_fields);
    } else {
      return orig_sinfo;
    }
  }

  Expr ShardInputParamTensorAndConstant(Expr input) {
    ICHECK(input->struct_info_);
    StructInfo old_sinfo = GetStructInfo(input);
    StructInfo new_sinfo = ConvertSinfo(old_sinfo, false);
    if (const auto* var = input.as<VarNode>()) {
      Var new_param(var->name_hint(), new_sinfo);
      return new_param;
    } else if (const auto* constant = input.as<ConstantNode>()) {
      for (const auto& spec : Downcast<DTensorStructInfo>(old_sinfo)->placement->dim_specs) {
        ICHECK(spec->kind == PlacementSpecKind::kReplica);
      }
      Constant new_constant(constant->data, new_sinfo);
      return new_constant;
    } else {
      LOG(FATAL) << "Cannot shard tensor which is not Var or Constant: " << input;
      throw;
    }
  }

  void EmitBroadcastOrScatter(Expr old_expr, Expr new_expr, DTensorStructInfo dtensor_sinfo) {
    // FIXME: this is a hack that only works for 1d device mesh
    ICHECK(dtensor_sinfo->device_mesh->shape.size() == 1);
    PlacementSpec sharding_spec = dtensor_sinfo->placement->dim_specs[0];
    if (sharding_spec->kind == PlacementSpecKind::kReplica) {
      Var new_var = builder_->Emit(broadcast_from_worker0(new_expr));
      if (const auto* var = old_expr.as<VarNode>()) {
        var_remap_[var->vid] = new_var;
      } else {
        tuple_getitem_remap_[Downcast<TupleGetItem>(old_expr)] = new_var;
      }
    } else if (sharding_spec->kind == PlacementSpecKind::kSharding) {
      Var scatter_var = builder_->Emit(scatter_from_worker0(
          new_expr, dtensor_sinfo->device_mesh->shape[0], sharding_spec->axis));
      if (const auto* var = old_expr.as<VarNode>()) {
        var_remap_[var->vid] = scatter_var;
      } else {
        tuple_getitem_remap_[Downcast<TupleGetItem>(old_expr)] = scatter_var;
      }
    } else {
      LOG(FATAL) << "Unsupported placement spec";
    }
  }

  void InputPreprocessing() {
    for (int i = 0; i < static_cast<int>(func_->params.size()); i++) {
      Var param = func_->params[i];
      if (const auto* dtensor_sinfo = GetStructInfoAs<DTensorStructInfoNode>(param)) {
        EmitBroadcastOrScatter(param, new_params_[i], GetRef<DTensorStructInfo>(dtensor_sinfo));
      } else if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(param)) {
        for (int j = 0; j < static_cast<int>(tuple_sinfo->fields.size()); j++) {
          if (const auto* dtensor_sinfo = tuple_sinfo->fields[j].as<DTensorStructInfoNode>()) {
            EmitBroadcastOrScatter(TupleGetItem(param, j), TupleGetItem(new_params_[i], j),
                                   GetRef<DTensorStructInfo>(dtensor_sinfo));
          }
        }
      }
    }
  }

  Function RewriteFunction(Function func) {
    Array<Var> new_params;
    for (const Var& var : func->params) {
      Var new_param = Downcast<Var>(ShardInputParamTensorAndConstant(var));
      var_remap_[var->vid] = new_param;
      new_params.push_back(new_param);
    }
    func_ = func;
    new_params_ = new_params;
    auto new_body = VisitWithNewScope(func->body, new_params);
    Function new_func(new_params, new_body, NullOpt, func->is_pure, func->attrs);
    return new_func;
  }

  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) {
    if (tuple_getitem_remap_.count(GetRef<TupleGetItem>(val))) {
      var_remap_[binding->var->vid] = tuple_getitem_remap_[GetRef<TupleGetItem>(val)];
    } else {
      ExprMutator::VisitBinding_(binding, val);
    }
  }

  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) {
    builder_->BeginBindingBlock();
    InputPreprocessing();
    for (Binding binding : block->bindings) {
      this->VisitBinding(binding);
    }
    return builder_->EndBlock();
  }

  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
    builder_->BeginDataflowBlock();
    InputPreprocessing();
    for (auto binding : block->bindings) {
      this->VisitBinding(binding);
    }
    return builder_->EndBlock();
  }

  Call HandleSpecialCaseinDTensorLowering(const CallNode* call, Var binding_var) {
    static Op reshape_op = Op::Get("relax.reshape");
    static Op call_tir_op = Op::Get("relax.call_tir");
    static Op call_tir_local_view_op = Op::Get("relax.dist.call_tir_local_view");
    if (call->op.same_as(reshape_op)) {
      ICHECK(call->args[1].as<ShapeExprNode>());
      const auto* out_sinfo = GetStructInfoAs<DTensorStructInfoNode>(binding_var);
      ICHECK(out_sinfo);
      auto new_call_node = make_object<CallNode>(*call);
      new_call_node->args.Set(1, ShardShape(Downcast<ShapeExpr>(call->args[1]),
                                            out_sinfo->device_mesh, out_sinfo->placement));
      return Call(new_call_node);
    } else if (call->op.same_as(call_tir_local_view_op)) {
      auto new_call_node = make_object<CallNode>(*call);
      new_call_node->op = call_tir_op;
      new_call_node->sinfo_args = {ConvertSinfo(GetStructInfo(binding_var), true)};
      return Call(new_call_node);
    } else if (call->op.same_as(call_tir_op)) {
      LOG(FATAL) << "call_tir should be lowered to call_tir_local_view before lowering to relax";
    } else if (const auto* extern_func = call->op.as<ExternFuncNode>()) {
      auto new_call_node = make_object<CallNode>(*call);
      if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_append") {
        new_call_node->op = ExternFunc("vm.builtin.attention_kv_cache_append");
      } else if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") {
        new_call_node->op = ExternFunc("vm.builtin.attention_kv_cache_view");
        auto orig_shape = Downcast<ShapeExpr>(call->args[1]);
        const auto* out_sinfo = GetStructInfoAs<DTensorStructInfoNode>(binding_var);
        ICHECK(out_sinfo);
        ShapeExpr new_shape = ShardShape(orig_shape, out_sinfo->device_mesh, out_sinfo->placement);
        new_call_node->args.Set(1, new_shape);
        new_call_node->sinfo_args = {TensorStructInfo(new_shape, out_sinfo->tensor_sinfo->dtype)};
      }
      return Call(new_call_node);
    }
    return GetRef<Call>(call);
  }

  void VisitBinding_(const VarBindingNode* binding, const CallNode* val) {
    Call new_call =
        Downcast<Call>(this->VisitExpr(HandleSpecialCaseinDTensorLowering(val, binding->var)));
    ReEmitBinding(binding, builder_->Normalize(new_call));
  }

  Function func_;
  Array<Var> new_params_;
  std::unordered_map<TupleGetItem, Var, StructuralHash, StructuralEqual> tuple_getitem_remap_;
};

namespace transform {

Pass LowerDistIR() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule m, PassContext pc) { return DistIRSharder::LowerDistIR(m); };
  return CreateModulePass(pass_func, 1, "LowerDistIR", {});
}
TVM_REGISTER_GLOBAL("relax.distributed.transform.LowerDistIR").set_body_typed(LowerDistIR);
}  // namespace transform

}  // namespace distributed
}  // namespace relax
}  // namespace tvm
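
For intuition, a plain-Python sketch (illustrative names, not TVM API) of the shape arithmetic ShardShape performs: every device-mesh dimension whose placement is kSharding divides the mapped tensor axis by that mesh dimension's size. Here placement is encoded per mesh dimension as the sharded tensor axis, or None for replication.

from typing import List, Optional

def shard_shape(shape: List[int],
                mesh_shape: List[int],
                placement: List[Optional[int]]) -> List[int]:
    """Mirror of DistIRSharder::ShardShape: for each sharding mesh
    dimension, divide the mapped tensor axis by the mesh size."""
    out = list(shape)
    for mesh_dim, axis in enumerate(placement):
        if axis is not None:  # PlacementSpecKind::kSharding on `axis`
            out[axis] //= mesh_shape[mesh_dim]
    return out

# A [128, 64] tensor sharded on axis 0 over a 1-D mesh of 4 devices
# becomes a per-device [32, 64] local shard.
assert shard_shape([128, 64], [4], [0]) == [32, 64]
# A fully replicated placement leaves the shape unchanged.
assert shard_shape([128, 64], [4], [None]) == [128, 64]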

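Similarly, a small numpy sketch (again illustrative, not TVM API) of what the collectives inserted by EmitBroadcastOrScatter leave on each worker of a 1-D mesh: replicated parameters are broadcast from worker 0, sharded parameters are scattered from worker 0 along the sharded axis.

import numpy as np

def preprocess_input(x, mesh_size, shard_axis=None):
    """Per-worker view after input preprocessing: kReplica maps to
    broadcast_from_worker0 (every worker holds the full tensor);
    kSharding maps to scatter_from_worker0 (worker i holds slice i)."""
    if shard_axis is None:  # PlacementSpecKind::kReplica
        return [x.copy() for _ in range(mesh_size)]
    return np.split(x, mesh_size, axis=shard_axis)  # kSharding

x = np.arange(8.0).reshape(4, 2)
assert all(np.array_equal(r, x) for r in preprocess_input(x, mesh_size=2))
assert preprocess_input(x, mesh_size=2, shard_axis=0)[1].shape == (2, 2)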