125 changes: 117 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -3,12 +3,15 @@

import { TensorView } from '../../tensor-view';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext } from '../types';
import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import { DataType } from '../../../wasm-common';

import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
export interface GroupQueryAttentionAttributes {
  numHeads: number;
  kvNumHeads: number;
@@ -24,9 +24,6 @@ export const validateInputs = (
  inputs: readonly TensorView[],
  attributes: GroupQueryAttentionAttributes,
): AttentionParameters => {
  if (attributes.doRotary) {
    throw new Error('GroupQuerryAttention do_rotary attribute is not supported');
  }
  if (attributes.doRotary && inputs.length <= 7) {
    throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
  }
@@ -35,6 +35,9 @@ export const validateInputs = (
  const value = inputs[2];
  const pastKey = inputs[3];
  const pastValue = inputs[4];
  if (attributes.doRotary !== 0 && inputs.length <= 7) {
    throw new Error('cos_cache and sin_cache are expected if do_rotary attribute is non-zero');
  }
  if (attributes.localWindowSize !== -1) {
    throw new Error('Local attention is not supported');
  }
@@ -238,6 +241,77 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
  return reshapedInput;
};

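// Builds a single-output program that computes the per-token position ids
// consumed by the RotaryEmbedding programs below. One int64 id is produced
// per (batch, token) pair, derived on the GPU from seq_lens and
// total_seq_lens so the values never round-trip through the CPU.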
const generatePositionIdsProgramInfo = (
  batchSize: number,
  sequenceLength: number,
  seqLens: TensorView,
  totalSeqLen: TensorView,
) => {
  const outputDataType = DataType.int64;
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  const outputShape = [batchSize * sequenceLength];
  const outputSize = batchSize * sequenceLength;
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: sequenceLength },
    { type: DataType.uint32, data: batchSize },
  ];
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
    const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
    const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);

    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'sequence_length', type: 'u32' },
      { name: 'batch_size', type: 'u32' },
    ];

    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
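    // Classify the request: a first prompt covers the whole sequence, a
    // subsequent (chunked) prompt appends several tokens to an existing
    // KV cache, and anything else is single-token generation.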
    let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
    let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
    let batch_idx = global_idx / uniforms.sequence_length;
    let sequence_idx = i32(global_idx % uniforms.sequence_length);
    var pos_id: i32 = 0;
    let seqlen = ${seqLensInputHelper.getByOffset('batch_idx')};
    let total_seqlen = seqlen + 1;
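    // seq_lens follows the GQA seqlens_k convention (total_sequence_length - 1
    // per batch), hence the +1 to recover the true per-batch total.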
    if (is_first_prompt) {
      if (sequence_idx < total_seqlen) {
        pos_id = sequence_idx;
      } else {
        pos_id = 1;
      }
      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
    } else if (is_subsequent_prompt) {
      let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
      if (past_seqlen + sequence_idx < total_seqlen) {
        pos_id = past_seqlen + sequence_idx;
      } else {
        pos_id = 1;
      }
      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
    } else if (global_idx < uniforms.batch_size) {
      ${positionIdsHelper.setByOffset('global_idx', 'seqlen')}
    };
  }
  `;
  };
  return {
    name: 'GeneratePositionIds',
    shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};

export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
  const params = validateInputs(context.inputs, attributes);
  if (context.inputs[0].dims.length === 5) {
@@ -268,22 +342,57 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
    !k && !v
      ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
      : [q, k!, v!];

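  // When do_rotary is set, apply rotary embedding to Q and K before the
  // BNSH transpose, using position ids generated on the GPU.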
  let qRotary: TensorView | undefined;
  let kRotary: TensorView | undefined;
  if (attributes.doRotary) {
    const posIds = context.compute(
      generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!),
      { inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] },
    )[0];
    const cosCache = context.inputs[7];
    const sinCache = context.inputs[8];
    const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
      interleaved: attributes.rotaryInterleaved !== 0,
      numHeads: params.numHeads,
      rotaryEmbeddingDim: 0,
      scale: attributes.scale,
    });
    const inputs = [query, posIds, cosCache, sinCache];
    const outputs = [-1];
    qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, qRotaryEmbeddingAttributes), {
      inputs,
      outputs,
    })[0];
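    // Reuse the same input list for K: swap query (index 0) for key and run
    // the rotary program again with the KV head count.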
    inputs.splice(0, 1, key);
    const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
      interleaved: attributes.rotaryInterleaved !== 0,
      numHeads: params.kvNumHeads!,
      rotaryEmbeddingDim: 0,
      scale: attributes.scale,
    });
    kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, kRotaryEmbeddingAttributes), {
      inputs,
      outputs,
    })[0];
  }
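  // Transpose to BNSH, substituting the rotated Q and K when do_rotary is set.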
  const Q = maybeTransposeToBNSHAndAddBias(
    context,
    params.batchSize,
    params.numHeads,
    params.sequenceLength,
    params.headSize,
    query,
    attributes.doRotary ? qRotary! : query,
    undefined,
    0,
  );
  const K = maybeTransposeToBNSH(context, attributes.doRotary ? kRotary! : key, params);
  const V = maybeTransposeToBNSH(context, value, params);

  applyAttention(
    context,
    Q,
    maybeTransposeToBNSH(context, key, params),
    maybeTransposeToBNSH(context, value, params),
    K,
    V,
    undefined,
    undefined,
    pastKey,
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
@@ -75,7 +75,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi
  }
};

const createRotaryEmbeddingProgramInfo = (
export const createRotaryEmbeddingProgramInfo = (
  inputs: readonly TensorView[],
  attributes: RotaryEmbeddingAttributes,
): ProgramInfo => {