diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index deaa77a204a7f..66858c6d8da73 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -322,23 +322,24 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte updateOutputShape(ctx, 2, present_shape); } } else if (use_max_past_present_buffer == -1) { + // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) + ONNX_NAMESPACE::TensorShapeProto present_shape; + *present_shape.add_dim() = past_dims[0]; // batch_size + *present_shape.add_dim() = past_dims[1]; // kv_num_heads if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) { // present_sequence_length = max(past_sequence_length, total_sequence_length) const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() ? total_sequence_length_value : past_dims[2].dim_value(); - - ONNX_NAMESPACE::TensorShapeProto present_shape; - for (auto& dim : past_dims) { - *present_shape.add_dim() = dim; - } - - // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) - present_shape.mutable_dim(2)->set_dim_value(present_sequence_length); - - updateOutputShape(ctx, 1, present_shape); - updateOutputShape(ctx, 2, present_shape); + present_shape.add_dim()->set_dim_value(present_sequence_length); + } else { + // Cannot compute exact present_sequence_length, copy from past_key (may be dynamic) + *present_shape.add_dim() = past_dims[2]; } + *present_shape.add_dim() = past_dims[3]; // head_size + + updateOutputShape(ctx, 1, present_shape); + updateOutputShape(ctx, 2, present_shape); } if (output_qk_index >= 0) { @@ -370,6 +371,52 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } } } + } else if (hasInputShape(ctx, 0)) { + // Handle the case when past_key/past_value is not provided (first token/prefill mode). + // We still need to infer present_key/present_value output shapes from query and attributes. + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0); + + if (num_heads > 0 && kv_num_heads > 0 && query_dims.size() == 3 && query_dims[2].has_dim_value()) { + int64_t hidden_size = query_dims[2].dim_value(); + int64_t head_size = 0; + + if (hasInputShape(ctx, 2)) { + // query shape is (batch_size, sequence_length, num_heads * head_size) + head_size = hidden_size / num_heads; + } else { + // Packed QKV: query shape is (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size) + head_size = hidden_size / (num_heads + 2 * kv_num_heads); + } + + if (head_size > 0) { + // Determine present_sequence_length from total_sequence_length or kv_sequence_length + int64_t present_sequence_length = 0; + if (total_sequence_length_value > 0) { + present_sequence_length = total_sequence_length_value; + } else if (kv_sequence_length > 0) { + present_sequence_length = kv_sequence_length; + } + + // present key/value shape is (batch_size, kv_num_heads, present_sequence_length, head_size) + ONNX_NAMESPACE::TensorShapeProto present_shape; + *present_shape.add_dim() = query_dims[0]; // batch_size + present_shape.add_dim()->set_dim_value(kv_num_heads); + if (present_sequence_length > 0) { + present_shape.add_dim()->set_dim_value(present_sequence_length); + } else { + // Fallback: use query sequence_length (dim 1) as present_sequence_length for prefill + *present_shape.add_dim() = query_dims[1]; + } + present_shape.add_dim()->set_dim_value(head_size); + + updateOutputShape(ctx, 1, present_shape); + updateOutputShape(ctx, 2, present_shape); + } + } } } }