diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc
index b9d47a27e8e83..d97d5fcbb0b5b 100755
--- a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc
+++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc
@@ -51,7 +51,6 @@ template
 GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
   opset_start_version_ = info.node().SinceVersion();
-  std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
   std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros");
   align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0));
diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
index ed067a1362663..743bf50f6c608 100644
--- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
+++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
@@ -110,9 +110,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
             )
             return
         else:
-            # Shape inference failed. Use default skip_index=1 (no broadcasting) since both
-            # Add inputs have already been verified as non-initializer dynamic tensors above.
-            logger.debug("symbolic shape inference failed, using default skip_index for SkipLayerNormalization")
+            logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
+            return

         gather_path = self.model.match_parent_path(add, ["Gather"], [None])
         if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc
index 7fd22cc59745f..2423d7f120b20 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc
@@ -37,7 +37,7 @@ TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bound
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput("Y", Y_shape, Y_data);
-  RunTests(test, GetExecutionProviders(20));
+  RunTests(test, GetExecutionProviders());
 }
 
 TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds_left_top) {
@@ -69,6 +69,5 @@ TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bound
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput("Y", Y_shape, Y_data);
-  RunTests(test, GetExecutionProviders(20));
+  RunTests(test, GetExecutionProviders());
 }
-
diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py
index d25432173a8f0..caaaa1aa628cf 100644
--- a/onnxruntime/test/python/transformers/test_attention_fusion.py
+++ b/onnxruntime/test/python/transformers/test_attention_fusion.py
@@ -395,17 +395,17 @@ def test_qwen3_normalization_fusion(self):
         ssln_count = sum(1 for n in nodes if n.op_type == "SkipSimplifiedLayerNormalization")
         # 4 RMSNorm patterns: pre-attn, Q-norm, K-norm, post-attn.
-        # Post-attn RMSNorm has an Add parent (residual) → fused as SkipSimplifiedLayerNormalization.
-        # Remaining 3 stay as SimplifiedLayerNormalization.
+        # Fallback for SkipLayerNormalization is disabled, so post-attn RMSNorm does not fuse.
+        # All 4 stay as SimplifiedLayerNormalization.
         self.assertEqual(
             sln_count,
-            3,
-            f"Expected 3 SimplifiedLayerNormalization (pre-attn + Q-norm + K-norm), got {sln_count}",
+            4,
+            f"Expected 4 SimplifiedLayerNormalization (pre-attn + Q-norm + K-norm + post-attn), got {sln_count}",
         )
         self.assertEqual(
             ssln_count,
-            1,
-            f"Expected 1 SkipSimplifiedLayerNormalization (residual + post-attn RMSNorm), got {ssln_count}",
+            0,
+            f"Expected 0 SkipSimplifiedLayerNormalization (residual + post-attn RMSNorm failed to fuse), got {ssln_count}",
         )