Skip to content

Commit 1e6482c

Browse files
authored
[Relax] Expose name_hint field for BlockBuilder.match_cast (#16600)
* [Relax] Expose name_hint field for BlockBuilder.match_cast Prior to this commit, while a `relax.VarBinding` created using `BlockBuilder.emit` could have its name explicitly specified by the user, a `relax.MatchCast` created using `BlockBuilder.match_cast` could not. This commit updates `BlockBuilder.match_cast` to accept an optional `name_hint` parameter, which is then provided to the C++ `BlockBuilder::EmitMatchCast` method. * Fix lint error
1 parent 36ebcd0 commit 1e6482c

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

python/tvm/relax/block_builder.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32"))
534534
name_hint = kwargs.pop("name_hint", "")
535535
return self.emit(self.call_te(func, *args, **kwargs), name_hint=name_hint)
536536

537-
def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
537+
def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str = "") -> Var:
538538
"""Emit a MatchCast.
539539
540540
Parameters
@@ -545,12 +545,20 @@ def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
545545
struct_info : StructInfo
546546
The struct info to be matched.
547547
548+
name_hint : str
549+
The name of the match cast
550+
548551
Returns
549552
-------
550553
ret : tvm.relax.Var
551554
A newly created variable that get bounds to be the casted result.
552555
"""
553-
return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore
556+
return _ffi_api.BlockBuilderEmitMatchCast(
557+
self,
558+
value,
559+
struct_info,
560+
name_hint,
561+
) # type: ignore
554562

555563
def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = "") -> Var:
556564
"""Emit output for the current dataflow block or function.

src/relax/ir/block_builder.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,8 +1015,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit")
10151015
});
10161016

10171017
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
1018-
.set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) {
1019-
return builder->EmitMatchCast(value, struct_info);
1018+
.set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) {
1019+
return builder->EmitMatchCast(value, struct_info, name_hint);
10201020
});
10211021

10221022
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")

tests/python/relax/test_blockbuilder_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_emit_match_cast():
226226
assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32"))
227227

228228
# lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n]))
229-
lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]))
229+
lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name")
230230
assert lv1.struct_info == rx.ShapeStructInfo([m, n])
231231
gv0 = bb.emit_output(lv1)
232232

@@ -244,6 +244,7 @@ def test_emit_match_cast():
244244
assert b1.value == y
245245
assert b1.struct_info == rx.ShapeStructInfo([m, n])
246246
assert b1.var == lv1
247+
assert b1.var.name_hint == "var_name"
247248

248249

249250
def test_emit_match_cast_binding_in_dataflow_block():

0 commit comments

Comments
 (0)