@@ -1910,15 +1910,47 @@ def _impl_v13(cls, bb, inputs, attr, params):
19101910 if isinstance (shape , relax .ShapeExpr ):
19111911 data_shape = list (data .struct_info .shape )
19121912 target_shape = list (shape .values )
1913+ original_data_shape = [
1914+ dim .value if hasattr (dim , "value" ) else str (dim ) for dim in data_shape
1915+ ]
1916+ original_target_shape = [
1917+ dim .value if hasattr (dim , "value" ) else str (dim ) for dim in target_shape
1918+ ]
19131919 data_shape = [1 ] * (len (target_shape ) - len (data_shape )) + data_shape
19141920 assert len (data_shape ) == len (target_shape )
1915- # Fix small target shapes or target shapes assigned to -1
1921+ # Apply ONNX v13 Expand broadcasting rules
19161922 for i , s in enumerate (target_shape ):
1917- if isinstance (s , tvm .tir .IntImm ) and (
1918- (isinstance (data_shape [i ], tvm .tir .IntImm ) and s < data_shape [i ])
1919- or s .value == - 1
1920- ):
1921- target_shape [i ] = data_shape [i ]
1923+ if isinstance (s , tvm .tir .IntImm ):
1924+ if s .value == - 1 :
1925+ # -1 means preserve the input dimension
1926+ target_shape [i ] = data_shape [i ]
1927+ elif isinstance (data_shape [i ], tvm .tir .IntImm ) and data_shape [i ].value == 1 :
1928+ # Input dimension is 1, can broadcast to any target dimension >= 1
1929+ if s .value < 1 :
1930+ raise ValueError (
1931+ f"ONNX Expand: Invalid target dimension { s .value } "
1932+ f"at possition { i } . Target dimensions must be >= 1."
1933+ )
1934+ elif (
1935+ isinstance (data_shape [i ], tvm .tir .IntImm ) and s .value == data_shape [i ].value
1936+ ):
1937+ # Dimensions match, no change needed
1938+ pass
1939+ elif s .value == 1 :
1940+ # Target dimension is 1 but input dimension is not 1
1941+ # This would "squeeze" the dimension - preserve input for safety
1942+ target_shape [i ] = data_shape [i ]
1943+ else :
1944+ if isinstance (data_shape [i ], tvm .tir .IntImm ):
1945+ raise ValueError (
1946+ f"ONNX Expand: Cannot broadcast input shape { original_data_shape } "
1947+ f"to target shape { original_target_shape } . "
1948+ f"At dimension { i } : input size { data_shape [i ].value } is "
1949+ f"incompatible with target size { s .value } . "
1950+ f"ONNX broadcasting requires corresponding dimensions to have "
1951+ f"the same value or one of them to be 1."
1952+ )
1953+ # For dynamic shapes, let broadcast_to handle it
19221954 if target_shape == data_shape :
19231955 return data
19241956 return relax .op .broadcast_to (data , relax .ShapeExpr (target_shape ))
@@ -1929,6 +1961,8 @@ def _impl_v13(cls, bb, inputs, attr, params):
19291961 # ONNX Expand operator requires preserving target rank and broadcasting
19301962 # according to standard rules. Dimensions are right-aligned.
19311963 data_shape = [dim .value for dim in data .struct_info .shape ]
1964+ original_data_shape = data_shape .copy ()
1965+ original_new_shape = new_shape .copy ()
19321966
19331967 # Right-align the shapes
19341968 if len (new_shape ) > len (data_shape ):
@@ -1938,8 +1972,32 @@ def _impl_v13(cls, bb, inputs, attr, params):
19381972 # Fix small target shapes - if target dim is smaller than input dim
19391973 # use the input dim (ONNX-specific behavior).
19401974 for i in range (len (new_shape )):
1941- if new_shape [i ] < data_shape [i ]:
1975+ if new_shape [i ] == - 1 :
1976+ # -1 means preserve the input dimension
1977+ new_shape [i ] = data_shape [i ]
1978+ elif data_shape [i ] == 1 :
1979+ # Input dimension is 1, can broadcast to any target dimension >= 1
1980+ if new_shape [i ] < 1 :
1981+ raise ValueError (
1982+ f"ONNX Expand: Invalid target dimension { new_shape [i ]} "
1983+ f"at possition { i } . Target dimensions must be >= 1."
1984+ )
1985+ elif new_shape [i ] == data_shape [i ]:
1986+ # Dimensions match, no change needed
1987+ pass
1988+ elif new_shape [i ] == 1 :
1989+ # Target dimension is 1 but input dimension is not 1
1990+ # This would "squeeze" the dimension - preserve input for safety
19421991 new_shape [i ] = data_shape [i ]
1992+ else :
1993+ raise ValueError (
1994+ f"ONNX Expand: Cannot broadcast input shape { original_data_shape } "
1995+ f"to target shape { original_new_shape } . "
1996+ f"At dimension { i } : input size { data_shape [i ]} is incompatible "
1997+ f"with target size { new_shape [i ]} . "
1998+ f"ONNX broadcasting requires corresponding dimensions to have the same "
1999+ f"value or one of them to be 1."
2000+ )
19432001 return relax .op .broadcast_to (data , relax .ShapeExpr (new_shape ))
19442002
19452003 # Otherwise handle dynamic shapes.
@@ -1956,7 +2014,18 @@ def _impl_v13(cls, bb, inputs, attr, params):
19562014 for i in range (shape_ndim ):
19572015 shape_vars .append (tvm .tir .Var ("x_%d" % i , "int64" ))
19582016 bb .match_cast (shape_dataflow_var , relax .ShapeStructInfo (shape_vars ))
1959- return bb .normalize (relax .op .broadcast_to (data , relax .ShapeExpr (shape_vars )))
2017+
2018+ # Applying broadcasting rules for dynamic shapes
2019+ data_shape = list (data .struct_info .shape )
2020+ data_ndim = len (data_shape )
2021+ target_ndim = shape_ndim
2022+ padded_data = data
2023+
2024+ if target_ndim > data_ndim :
2025+ padded_data_shape = [tir .IntImm ("int64" , 1 )] * (target_ndim - data_ndim ) + data_shape
2026+ padded_data = bb .normalize (relax .op .reshape (data , relax .ShapeExpr (padded_data_shape )))
2027+
2028+ return bb .normalize (relax .op .broadcast_to (padded_data , relax .ShapeExpr (shape_vars )))
19602029
19612030
19622031class Attention (OnnxOpConverter ):
0 commit comments