Skip to content

Commit a94f8b9

Browse files
committed
Merge HEAD, branch 'remove-optional-bias' of github-personal:whyvineet/onnxscript into remove-optional-bias
2 parents ce64fb7 + 79afb87 commit a94f8b9

File tree

7 files changed

+7
-1377
lines changed

7 files changed

+7
-1377
lines changed

.lintrunner.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ exclude_patterns = [
5757
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
5858
'onnxscript/tools/function_unittest_producer.py', # FIXME
5959
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
60-
'onnxscript/rewriter/generic_pattern.py', # FIXME
6160
]
6261
command = [
6362
'python',

examples/pattern_rewriting.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -141,28 +141,3 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
141141
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
142142

143143
rule.apply_to_model(ir_model)
144-
145-
# TODO(rama): Update the following, the trace-printed looks different now.
146-
147-
######################################
148-
# The logs shows every time the algorithm rejected a pattern.
149-
# We can see the following:
150-
#
151-
# ::
152-
#
153-
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
154-
# --hint--: BACKWARD: different node types
155-
# --pattern
156-
# ConcatTraining(transpose, transpose) -> (output, length)
157-
# -- model
158-
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
159-
# iteration=1
160-
# --marked-- #2
161-
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
162-
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
163-
# len(stacked)=0:[]
164-
#
165-
# Line 673 in file `generic_pattern.py`, the match was rejected.
166-
# It says while comparing two nodes in the backward direction,
167-
# node types do not match.
168-
# It also says that two nodes were actually matched.

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3814,11 +3814,15 @@ def aten_gather(
38143814
else:
38153815
return op.Expand(self, op.Shape(index))
38163816

3817-
if len(index.shape) == 0:
3818-
return op.Identity(self)
3817+
is_scalar_index = len(index.shape) == 0
3818+
if is_scalar_index:
3819+
index = op.Unsqueeze(index, [0])
38193820

38203821
index = op.Cast(index, to=INT64.dtype)
38213822
result = op.GatherElements(self, index, axis=dim)
3823+
3824+
if is_scalar_index:
3825+
result = op.Squeeze(result, [0])
38223826
return result
38233827

38243828

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,7 @@ def __init__(
8282
if isinstance(matcher, _matcher.PatternMatcher):
8383
self._matcher = matcher
8484
elif matcher is None:
85-
if target_pattern.has_single_output_node:
86-
self._matcher = _matcher.SimplePatternMatcher(self._target_pattern)
87-
else:
88-
import onnxscript.rewriter.generic_pattern as generic_pattern
89-
90-
self._matcher = generic_pattern.GenericPatternMatcher(self._target_pattern)
85+
self._matcher = _matcher.SimplePatternMatcher(self._target_pattern)
9186
else:
9287
self._matcher = matcher(self._target_pattern)
9388
self._verbose = verbose

0 commit comments

Comments
 (0)