diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 9126884ff7ac8..9c33a032e3e58 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -194,7 +194,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { } }; - if (A_has_shape && B_has_shape && Y_has_shape) { + if (A_has_shape && B_has_shape && Y_has_shape && + A_shape.size() >= 2 && B_shape.size() >= 2) { std::vector shared_attributes; shared_attributes.push_back(MakeAttribute("beta", float(0))); AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));