Commit ce64fb7

Merge branch 'main' into remove-optional-bias
2 parents 153b4e7 + f529292 commit ce64fb7

16 files changed: +249 additions, -197 deletions


VERSION

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-0.5.0
+0.5.2

noxfile.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@
     "packaging",
     "protobuf",
 )
-ONNX_IR = "onnx_ir==0.1.7"
+ONNX_IR = "onnx_ir==0.1.9"
 ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 31 additions & 15 deletions

@@ -925,16 +925,21 @@ def aten_atan(self: TFloat) -> TFloat:
     return op.Atan(self)
 
 
-@torch_op("aten::atan2")
+@torch_op("aten::atan2", trace_only=True)
 def aten_atan2(self: TFloat, other: TFloat) -> TFloat:
     """atan2(Tensor self, Tensor other) -> Tensor"""
 
     # self is y, and other is x on coordinate
     slope = op.Div(self, other)
     atan = op.Atan(slope)
+    zero = common_ops.constant(0.0, dtype=self.dtype)
+    pi = common_ops.constant(_MATH_PI, dtype=self.dtype)
 
-    second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI)
-    result = op.Where(other < 0.0, second_third_quadrant, atan)
+    second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi)
+    result = op.Where(op.Less(other, zero), second_third_quadrant, atan)
+
+    # Map NaN to 0 to match PyTorch behavior
+    result = op.Where(op.IsNaN(result), zero, result)
 
     return result
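
The new graph computes atan2 from Atan plus an explicit quadrant correction, and maps the NaN produced by the 0/0 case to 0 as torch.atan2 does. A minimal NumPy sketch of the same logic (NumPy and the sample points are only for illustration, not part of the change):

import numpy as np

def atan2_via_atan(y, x):
    # atan(y/x), then shift second/third-quadrant results by +/- pi,
    # and map NaN (the 0/0 case) to 0 to match torch.atan2.
    with np.errstate(divide="ignore", invalid="ignore"):
        atan = np.arctan(y / x)
    second_third_quadrant = np.where(y > 0.0, atan + np.pi, atan - np.pi)
    result = np.where(x < 0.0, second_third_quadrant, atan)
    return np.where(np.isnan(result), 0.0, result)

y = np.array([1.0, 1.0, -1.0, -1.0, 0.0])
x = np.array([1.0, -1.0, -1.0, 1.0, 0.0])
print(atan2_via_atan(y, x))  # ~[0.785, 2.356, -2.356, -0.785, 0.0]
print(np.arctan2(y, x))      # same values for these sample points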

@@ -7327,16 +7332,25 @@ def aten_repeat_interleave_self_int(
     self_rank = len(self.shape)
     pos_dim = (dim + self_rank) % self_rank
     unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
-    tiles = [1] * (self_rank + 1)
-    tiles[pos_dim + 1] = repeats
-    tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
-    tiled = op.Tile(unsqueezed, tile_repeat)
+    if isinstance(repeats, int):
+        tiles = [1] * (self_rank + 1)
+        tiles[pos_dim + 1] = repeats
+        tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
+    else:
+        # repeats is a symbolic tensor
+        tile_repeat = op.Concat(
+            op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)),
+            op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
+            op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
+            axis=0,
+        )
+    tiled = op.Expand(unsqueezed, tile_repeat)
     if self_rank == 1:
         return op.Identity(tiled)
     final_shape = op.Concat(
         op.Shape(self, start=0, end=dim),
         op.Constant(value_ints=[-1]),
-        op.Shape(self, start=dim + 1),
+        op.Shape(self, start=pos_dim + 1),
         axis=0,
     )
     return op.Reshape(tiled, final_shape)
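
The Unsqueeze/Expand/Reshape pattern above is the standard way to express repeat_interleave with a known repeat count; when repeats is a symbolic tensor, the repeat vector is assembled at runtime with Concat instead of being baked into a Constant. A small NumPy sketch of the equivalent shape manipulation, using a hypothetical 2x2 input (illustration only):

import numpy as np

x = np.array([[1, 2], [3, 4]])                  # dim=1, repeats=3
unsqueezed = x[:, :, None]                      # Unsqueeze after dim -> (2, 2, 1)
tiled = np.broadcast_to(unsqueezed, (2, 2, 3))  # Expand along the inserted axis
out = tiled.reshape(2, -1)                      # fold dim and the new axis -> (2, 6)
assert (out == np.repeat(x, 3, axis=1)).all()   # [[1 1 1 2 2 2], [3 3 3 4 4 4]]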

@@ -7375,20 +7389,22 @@ def aten_repeat_interleave_Tensor(
     if dim is None:
         # flatten
         self = op.Reshape(self, [-1])
-        rk = 1
+        rank = 1
     else:
-        rk = len(self.shape)
+        rank = len(self.shape)
 
-    if rk > 2:
+    if rank > 2:
         shape_x0 = op.Shape(self, start=0, end=1)
         shape_x = op.Shape(self, start=1)
         self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
-    elif rk == 1:
+    elif rank == 1:
         shape_x = None
         self = op.Reshape(self, [-1, 1])
     else:
-        if rk != 2:
-            raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
+        if rank != 2:
+            raise NotImplementedError(
+                f"rank(self)={rank} not implemented for repeat_interleave"
+            )
         shape_x = None
 
     ci = op.CumSum(repeats, [0])

@@ -7401,7 +7417,7 @@ def aten_repeat_interleave_Tensor(
     )
     indices = op.Reshape(srows, [-1])
     values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
-    if rk == 2:
+    if rank == 2:
         return values
     # shape_x is None at this stage.
     assert shape_x is None  # for mypy
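
Beyond the rk -> rank rename, the surrounding implementation turns a per-row repeat tensor into gather indices via a cumulative sum (op.CumSum followed by GatherND in the visible context). A NumPy sketch of that index construction; searchsorted stands in here for the comparison-based gather the ONNX graph builds, so this is illustrative rather than a literal transcription:

import numpy as np

x = np.array([[1, 2], [3, 4], [5, 6]])  # rank 2, repeat along dim 0
repeats = np.array([2, 1, 3])
ends = np.cumsum(repeats)               # [2, 3, 6]  (op.CumSum in the graph)
indices = np.searchsorted(ends, np.arange(ends[-1]), side="right")
out = x[indices]                        # row gather, as GatherND does in the graph
assert (out == np.repeat(x, repeats, axis=0)).all()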

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 72 additions & 5 deletions

@@ -1741,6 +1741,64 @@ def _attention_scale(query: TFloat) -> TFloat:
     return scale
 
 
+def _attention_repeat_kv_for_group_query(
+    query: TFloat, key: TFloat, value: TFloat
+) -> Tuple[TFloat, TFloat]:
+    """Expand key and value for group query attention.
+
+    repeat_interleave is applied on key and value to match the number of heads in query.
+
+    Args:
+        query: Tensor of shape [B, q_num_heads, q_S, E]
+        key: Tensor of shape [B, k_num_heads, kv_S, E]
+        value: Tensor of shape [B, v_num_heads, kv_S, E]
+
+    Returns:
+        Tuple of (expanded_key, expanded_value) where:
+            - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
+            - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E]
+    """
+
+    assert (
+        query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
+    ), (
+        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
+    )
+
+    # NOTE: QKV are expected to be 4D tensors
+
+    batch_size = op.Shape(query, start=0, end=1)  # [B]
+    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
+    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
+    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
+    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
+    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]
+
+    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
+    two = op.Constant(value_int=2)
+    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
+    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hv, 1, T, Dv]
+
+    k_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
+    )
+    k_expand = op.Expand(k_unsqueezed, k_expand_shape)
+    v_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
+    )
+    v_expand = op.Expand(v_unsqueezed, v_expand_shape)
+
+    k_attention_shape = op.Concat(
+        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
+    )
+    v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0)
+
+    expanded_key = op.Reshape(k_expand, k_attention_shape)
+    expanded_value = op.Reshape(v_expand, v_attention_shape)
+
+    return expanded_key, expanded_value
+
+
 @torch_op("aten::scaled_dot_product_attention", trace_only=True)
 def aten_scaled_dot_product_attention(
     query: TFloat,
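
The helper performs repeat_interleave on the head axis purely with shape ops: insert a new axis after the KV head axis, Expand it to the group size Hq/Hk, then Reshape so the groups fold back into the head dimension. A PyTorch sketch of the same shape arithmetic, with hypothetical sizes, used only to illustrate the transformation:

import torch

B, Hq, Hk, T, D = 2, 8, 2, 5, 16      # Hq % Hk == 0, group size 4
key = torch.randn(B, Hk, T, D)
group = Hq // Hk
k = key.unsqueeze(2)                  # [B, Hk, 1, T, D]    (Unsqueeze)
k = k.expand(B, Hk, group, T, D)      # [B, Hk, group, T, D] (Expand)
k = k.reshape(B, Hq, T, D)            # [B, Hq, T, D]        (Reshape)
assert torch.equal(k, key.repeat_interleave(group, dim=1))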

@@ -1772,8 +1830,8 @@ def aten_scaled_dot_product_attention(
         "is_causal and attn_mask cannot be set at the same time"
     )
 
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

@@ -1784,6 +1842,13 @@ def aten_scaled_dot_product_attention(
     if is_causal:
         attn_mask = _causal_attention_mask(query, key)
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+    else:
+        assert query.shape[1] == key.shape[1] == value.shape[1], (
+            "SDPA (MHA) requires q_num_heads = kv_num_heads"
+        )
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p

@@ -1981,9 +2046,8 @@ def aten_scaled_dot_product_attention_bool_mask(
     assert (not is_causal) or (is_causal and attn_mask is None), (
         "is_causal and attn_mask cannot be set at the same time"
     )
-
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     if scale is None:

@@ -1997,6 +2061,9 @@ def aten_scaled_dot_product_attention_bool_mask(
             query, key, value, attn_mask, scale, dropout_p
         )
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p
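
With these changes, scaled_dot_product_attention calls that set enable_gqa=True are converted instead of being rejected, provided Q/K/V are 4D and the head counts satisfy the GQA constraint checked by the helper. A hypothetical call that now lands on the supported path (enable_gqa requires a recent PyTorch release):

import torch
import torch.nn.functional as F

B, T, D = 2, 5, 16
query = torch.randn(B, 8, T, D)  # 8 query heads
key = torch.randn(B, 2, T, D)    # 2 KV heads; 8 % 2 == 0, so GQA applies
value = torch.randn(B, 2, T, D)
out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
print(out.shape)                 # torch.Size([2, 8, 5, 16])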

onnxscript/function_libs/torch_lib/ops/prims.py

Lines changed: 23 additions & 2 deletions

@@ -176,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("prims::broadcast_in_dim", trace_only=True)
 def prims_broadcast_in_dim(
-    a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
+    a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int]
 ) -> TensorType:
     """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""
 
-    raise NotImplementedError()
+    target_rank = len(shape)
+
+    if not broadcast_dimensions:
+        # Special case: no broadcast dimensions - all target dims should be 1
+        return op.Expand(a, common_ops.merge_dims(shape))
+
+    # Create base shape of all 1s
+    ones = [1] * target_rank
+
+    # For each broadcast dimension, we'll replace the 1 with the actual input dimension
+    # Since broadcast_dimensions is compile-time known, we can do this with individual operations
+    intermediate_shape = ones
+
+    for i, broadcast_dim in enumerate(broadcast_dimensions):
+        # Get the input dimension value
+        input_dim_value = op.Shape(a, start=i, end=i + 1)
+        intermediate_shape[broadcast_dim] = input_dim_value
+
+    # Reshape input to intermediate shape and expand to target
+    reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape))
+    return op.Expand(reshaped, shape)
 
 
 def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType:
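
broadcast_in_dim places each input axis at the target position named in broadcast_dimensions, fills the remaining target axes with size 1, and then broadcasts; the new lowering expresses exactly that as a Reshape followed by an Expand. A NumPy sketch of the semantics (illustration only):

import numpy as np

def broadcast_in_dim(a, shape, broadcast_dimensions):
    # Step 1: intermediate shape of 1s with each input axis placed at its
    # target position (mirrors the Reshape in the lowering).
    intermediate = [1] * len(shape)
    for i, d in enumerate(broadcast_dimensions):
        intermediate[d] = a.shape[i]
    # Step 2: broadcast to the full target shape (mirrors the Expand).
    return np.broadcast_to(a.reshape(intermediate), shape)

a = np.arange(3)  # shape (3,)
out = broadcast_in_dim(a, (2, 3, 4), broadcast_dimensions=[1])
print(out.shape)  # (2, 3, 4): `a` lands on axis 1 and is broadcast elsewhere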
