
Commit 118e3b1

[Relax][Frontend][ONNX] Error converting operator Expand: TVMError: broadcast_to expects the input tensor shape is broadcastable to the target shape (#18329)
1 parent 4c82c71 commit 118e3b1

File tree

2 files changed: +177 −8 lines changed
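
For context before the diffs: the rule this commit enforces is ONNX/NumPy-style multidirectional broadcasting. Shapes are right-aligned, each dimension pair must be equal or have a 1 on one side, and a -1 in the Expand target shape means "keep the input dimension". Below is a minimal standalone sketch of that check in plain Python (the helper name check_expand_shapes is illustrative, not part of the commit):

def check_expand_shapes(data_shape, target_shape):
    """Pairwise broadcasting check mirroring the rules applied in the fix below."""
    # Right-align by padding the input shape with leading 1s.
    data_shape = [1] * (len(target_shape) - len(data_shape)) + list(data_shape)
    out = []
    for i, (d, t) in enumerate(zip(data_shape, target_shape)):
        if t == -1:
            out.append(d)  # -1 preserves the input dimension
        elif d == 1 or t == d:
            out.append(t)  # broadcastable: 1 expands, equal dims pass through
        elif t == 1:
            out.append(d)  # would "squeeze"; the fix preserves the input dim instead
        else:
            raise ValueError(f"dim {i}: input {d} vs target {t} is not broadcastable")
    return out

print(check_expand_shapes((1, 25), (2, 25)))  # [2, 25]
try:
    check_expand_shapes((25,), (1, 1, 1, 56))  # the shapes from issue #17769
except ValueError as e:
    print(e)  # dim 3: input 25 vs target 56 is not broadcastable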

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 77 additions & 8 deletions
@@ -1910,15 +1910,47 @@ def _impl_v13(cls, bb, inputs, attr, params):
         if isinstance(shape, relax.ShapeExpr):
             data_shape = list(data.struct_info.shape)
             target_shape = list(shape.values)
+            original_data_shape = [
+                dim.value if hasattr(dim, "value") else str(dim) for dim in data_shape
+            ]
+            original_target_shape = [
+                dim.value if hasattr(dim, "value") else str(dim) for dim in target_shape
+            ]
             data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
             assert len(data_shape) == len(target_shape)
-            # Fix small target shapes or target shapes assigned to -1
+            # Apply ONNX v13 Expand broadcasting rules
             for i, s in enumerate(target_shape):
-                if isinstance(s, tvm.tir.IntImm) and (
-                    (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
-                    or s.value == -1
-                ):
-                    target_shape[i] = data_shape[i]
+                if isinstance(s, tvm.tir.IntImm):
+                    if s.value == -1:
+                        # -1 means preserve the input dimension
+                        target_shape[i] = data_shape[i]
+                    elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1:
+                        # Input dimension is 1, can broadcast to any target dimension >= 1
+                        if s.value < 1:
+                            raise ValueError(
+                                f"ONNX Expand: Invalid target dimension {s.value} "
+                                f"at position {i}. Target dimensions must be >= 1."
+                            )
+                    elif (
+                        isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value
+                    ):
+                        # Dimensions match, no change needed
+                        pass
+                    elif s.value == 1:
+                        # Target dimension is 1 but input dimension is not 1
+                        # This would "squeeze" the dimension - preserve input for safety
+                        target_shape[i] = data_shape[i]
+                    else:
+                        if isinstance(data_shape[i], tvm.tir.IntImm):
+                            raise ValueError(
+                                f"ONNX Expand: Cannot broadcast input shape {original_data_shape} "
+                                f"to target shape {original_target_shape}. "
+                                f"At dimension {i}: input size {data_shape[i].value} is "
+                                f"incompatible with target size {s.value}. "
+                                f"ONNX broadcasting requires corresponding dimensions to have "
+                                f"the same value or one of them to be 1."
+                            )
+                        # For dynamic shapes, let broadcast_to handle it
             if target_shape == data_shape:
                 return data
             return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
@@ -1929,6 +1961,8 @@ def _impl_v13(cls, bb, inputs, attr, params):
             # ONNX Expand operator requires preserving target rank and broadcasting
             # according to standard rules. Dimensions are right-aligned.
             data_shape = [dim.value for dim in data.struct_info.shape]
+            original_data_shape = data_shape.copy()
+            original_new_shape = new_shape.copy()
 
             # Right-align the shapes
             if len(new_shape) > len(data_shape):
@@ -1938,8 +1972,32 @@ def _impl_v13(cls, bb, inputs, attr, params):
             # Fix small target shapes - if target dim is smaller than input dim
             # use the input dim (ONNX-specific behavior).
             for i in range(len(new_shape)):
-                if new_shape[i] < data_shape[i]:
+                if new_shape[i] == -1:
+                    # -1 means preserve the input dimension
+                    new_shape[i] = data_shape[i]
+                elif data_shape[i] == 1:
+                    # Input dimension is 1, can broadcast to any target dimension >= 1
+                    if new_shape[i] < 1:
+                        raise ValueError(
+                            f"ONNX Expand: Invalid target dimension {new_shape[i]} "
+                            f"at position {i}. Target dimensions must be >= 1."
+                        )
+                elif new_shape[i] == data_shape[i]:
+                    # Dimensions match, no change needed
+                    pass
+                elif new_shape[i] == 1:
+                    # Target dimension is 1 but input dimension is not 1
+                    # This would "squeeze" the dimension - preserve input for safety
                     new_shape[i] = data_shape[i]
+                else:
+                    raise ValueError(
+                        f"ONNX Expand: Cannot broadcast input shape {original_data_shape} "
+                        f"to target shape {original_new_shape}. "
+                        f"At dimension {i}: input size {data_shape[i]} is incompatible "
+                        f"with target size {new_shape[i]}. "
+                        f"ONNX broadcasting requires corresponding dimensions to have the same "
+                        f"value or one of them to be 1."
+                    )
             return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
 
         # Otherwise handle dynamic shapes.
@@ -1956,7 +2014,18 @@ def _impl_v13(cls, bb, inputs, attr, params):
         for i in range(shape_ndim):
             shape_vars.append(tvm.tir.Var("x_%d" % i, "int64"))
         bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
-        return bb.normalize(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars)))
+
+        # Applying broadcasting rules for dynamic shapes
+        data_shape = list(data.struct_info.shape)
+        data_ndim = len(data_shape)
+        target_ndim = shape_ndim
+        padded_data = data
+
+        if target_ndim > data_ndim:
+            padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape
+            padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape)))
+
+        return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars)))
 
 
 class Attention(OnnxOpConverter):
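
A note on the dynamic-shape branch at the end of this diff: when the target rank exceeds the input rank, the fix reshapes the input with leading 1s before calling broadcast_to, so that dimensions line up right-aligned. The same padding step sketched in plain NumPy, with example shapes that are assumptions, not taken from the commit:

import numpy as np

data = np.arange(25, dtype=np.float32)  # rank 1, shape (25,)
target_ndim = 4

# Prepend 1s so the input rank matches the target rank.
padded = data.reshape((1,) * (target_ndim - data.ndim) + data.shape)
print(padded.shape)  # (1, 1, 1, 25)

# A right-aligned broadcast now succeeds for any compatible target.
out = np.broadcast_to(padded, (2, 3, 4, 25))
print(out.shape)  # (2, 3, 4, 25)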

tests/python/relax/test_frontend_onnx.py

Lines changed: 100 additions & 0 deletions
@@ -1909,6 +1909,106 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
     _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
 
 
+def test_expand_incompatible_broadcasting():
+    """
+    This test case reproduces the error where input tensor shape at dim 1 is 25
+    and target shape at dim 3 is 56, which violates ONNX broadcasting rules.
+    """
+
+    def _test_expand_error_case(name, data_shape, target_shape_vals):
+        data = np.random.uniform(size=data_shape).astype(np.float32)
+
+        shape_array = np.array(target_shape_vals, dtype=np.int64)
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT64,
+                dims=shape_array.shape,
+                vals=shape_array.flatten(),
+            ),
+        )
+
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_error_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
+        )
+
+        model = helper.make_model(graph, producer_name=name)
+
+        with pytest.raises(ValueError) as exc_info:
+            from_onnx(model, keep_params_in_input=True)
+
+        error_msg = str(exc_info.value)
+        assert (
+            "broadcast" in error_msg.lower() or "incompatible" in error_msg.lower()
+        ), f"Expected broadcasting error, but got: {error_msg}"
+
+    # Test case 1: Reproduce the exact error from issue #17769
+    # Input shape: (25,), target shape: (1, 1, 1, 56)
+    # This should fail because input dim 1 (25) != target dim 3 (56) and neither is 1
+    _test_expand_error_case(
+        "expand_incompatible_25_to_56",
+        data_shape=(25,),
+        target_shape_vals=(1, 1, 1, 56),
+    )
+
+    # Test case 2: Another incompatible case
+    # Input shape: (1, 25), target shape: (1, 1, 1, 56)
+    # After right-alignment, input (1, 1, 1, 25) vs. target (1, 1, 1, 56)
+    # This should fail because 25 != 56 and neither is 1
+    _test_expand_error_case(
+        "expand_incompatible_aligned_25_to_56",
+        data_shape=(1, 25),
+        target_shape_vals=(1, 1, 1, 56),
+    )
+
+    # Test case 3: Valid case for comparison - should not raise an error
+    def _test_expand_valid_case():
+        """Test a valid expand case to ensure the fix doesn't break valid operations."""
+        data_shape = (1, 25)
+        target_shape_vals = [2, 25]  # Valid: input (1, 25) can broadcast to (2, 25)
+
+        data = np.random.uniform(size=data_shape).astype(np.float32)
+        shape_array = np.array(target_shape_vals, dtype=np.int64)
+
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT64,
+                dims=shape_array.shape,
+                vals=shape_array.flatten(),
+            ),
+        )
+
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_valid_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
+        )
+
+        model = helper.make_model(graph, producer_name="expand_valid_test_case")
+
+        try:
+            tvm_model = from_onnx(model, keep_params_in_input=True)
+        except Exception as e:
+            pytest.fail(f"Valid expand case should not fail, but got error: {e}")
+
+    _test_expand_valid_case()
+
+
 # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
 @pytest.mark.skip("Produces ill-formed IR")
 def test_constantofshape():
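
As a sanity check on the expectations above, the incompatible case in the new test also fails under NumPy's own broadcasting, which follows the same pairwise rule; a short sketch, not part of the commit:

import numpy as np

print(np.broadcast_to(np.zeros((1, 25)), (2, 25)).shape)  # (2, 25): valid, like test case 3
try:
    np.broadcast_to(np.zeros((25,)), (1, 1, 1, 56))  # shapes from test case 1
except ValueError as e:
    print(e)  # NumPy rejects 25 vs 56 for the same reason the frontend now does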
