diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 85b8c4ca5a274..32b3c54f734dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -3,12 +3,15 @@ import { TensorView } from '../../tensor-view'; import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext } from '../types'; +import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { DataType } from '../../../wasm-common'; import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention'; import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; import { createSplitProgramInfo, SplitAttributes } from './split'; import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; +import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding'; +import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; export interface GroupQueryAttentionAttributes { numHeads: number; kvNumHeads: number; @@ -24,9 +27,6 @@ export const validateInputs = ( inputs: readonly TensorView[], attributes: GroupQueryAttentionAttributes, ): AttentionParameters => { - if (attributes.doRotary) { - throw new Error('GroupQuerryAttention do_rotary attribute is not supported'); - } if (attributes.doRotary && inputs.length <= 7) { throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified'); } @@ -35,6 +35,9 @@ export const validateInputs = ( const value = inputs[2]; const pastKey = inputs[3]; const pastValue = inputs[4]; + if (attributes.doRotary !== 0 && inputs.length <= 7) { + throw new Error('cos_cast and sin_cache are expected if do_rotary attribute is non-zero'); + } if (attributes.localWindowSize !== -1) { throw new Error('Local attention is not 
supported'); } @@ -238,6 +241,77 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params return reshapedInput; };
+const generatePositionIdsProgramInfo = ( // Program info for the GeneratePositionIds shader; output holds batchSize * sequenceLength ids.
+  batchSize: number, // forwarded to the shader as uniforms.batch_size
+  sequenceLength: number, // forwarded to the shader as uniforms.sequence_length
+  seqLens: TensorView, // shader input seq_lens; read at offset batch_idx
+  totalSeqLen: TensorView, // shader input total_seq_lens; only offset 0 is read
+) => {
+  const outputDataType = DataType.int64; // position ids are emitted as int64
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // NOTE(review): presumably keys the shader cache on input data types only; confirm
+  const outputShape = [batchSize * sequenceLength]; // flat 1-D output, one element per (batch, sequence position)
+  const outputSize = batchSize * sequenceLength;
+  const programUniforms: ProgramUniform[] = [ // listed in the same order as the uniforms registered in getShaderSource
+    { type: DataType.uint32, data: outputSize },
+    { type: DataType.uint32, data: sequenceLength },
+    { type: DataType.uint32, data: batchSize },
+  ];
+  const getShaderSource = (shaderHelper: ShaderHelper) => { // returns the shader source string built with the given helper
+    const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
+    const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
+    const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);
+    // Uniform names below are the ones the shader body reads through the uniforms struct.
+    const uniforms: UniformsArrayType = [
+      { name: 'output_size', type: 'u32' },
+      { name: 'sequence_length', type: 'u32' },
+      { name: 'batch_size', type: 'u32' },
+    ];
+    // Per invocation: sequence_idx (first prompt) or past_seqlen + sequence_idx (subsequent prompt), 1 once that reaches seqlen + 1; otherwise writes seqlen.
+  return `
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+    let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
+    let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
+    let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
+    let batch_idx = global_idx / uniforms.sequence_length;
+    let sequence_idx = i32(global_idx % uniforms.sequence_length);
+    var pos_id: i32 = 0;
+    let seqlen = ${seqLensInputHelper.getByOffset('batch_idx')};
+    let total_seqlen = seqlen + 1;
+    if (is_first_prompt) {
+      if (sequence_idx < total_seqlen) {
+        pos_id = sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+    } else if (is_subsequent_prompt) {
+      let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
+      if (past_seqlen + sequence_idx < total_seqlen) {
+        pos_id = past_seqlen + sequence_idx;
+      } else {
+        pos_id = 1;
+      }
+      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
+    } else if (global_idx < uniforms.batch_size) {
+      ${positionIdsHelper.setByOffset('global_idx', 'seqlen')}
+    };
+  }
+  `;
+  };
+  return {
+    name: 'GeneratePositionIds',
+    shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies }, // cache hint is built from batchSize and sequenceLength
+    getRunData: () => ({
+      outputs: [{ dims: outputShape, dataType: outputDataType }], // single int64 output of shape [batchSize * sequenceLength]
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};
+
 export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { @@ -268,22 +342,57 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu !k && !v ?
context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] }) : [q, k!, v!]; - + let qRotary: TensorView | undefined; + let kRotary: TensorView | undefined; + if (attributes.doRotary) { + const posIds = context.compute( + generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!), + { inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] }, + )[0]; + const cosCache = context.inputs[7]; + const sinCache = context.inputs[8]; + const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({ + interleaved: attributes.rotaryInterleaved !== 0, + numHeads: params.numHeads, + rotaryEmbeddingDim: 0, + scale: attributes.scale, + }); + const inputs = [query, posIds, cosCache, sinCache]; + const outputs = [-1]; + qRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, qRotaryEmbeddingAttributes), { + inputs, + outputs, + })[0]; + inputs.splice(0, 1, key); + const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({ + interleaved: attributes.rotaryInterleaved !== 0, + numHeads: params.kvNumHeads!, + rotaryEmbeddingDim: 0, + scale: attributes.scale, + }); + kRotary = context.compute(createRotaryEmbeddingProgramInfo(inputs, kRotaryEmbeddingAttributes), { + inputs, + outputs, + })[0]; + } const Q = maybeTransposeToBNSHAndAddBias( context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, - query, + attributes.doRotary ? qRotary! : query, undefined, 0, ); + const K = maybeTransposeToBNSH(context, attributes.doRotary ? kRotary! 
: key, params); + const V = maybeTransposeToBNSH(context, value, params); + applyAttention( context, Q, - maybeTransposeToBNSH(context, key, params), - maybeTransposeToBNSH(context, value, params), + K, + V, undefined, undefined, pastKey, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts index 8eb7a10ac91fa..fe2567e71d49a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts @@ -75,7 +75,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi } }; -const createRotaryEmbeddingProgramInfo = ( +export const createRotaryEmbeddingProgramInfo = ( inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes, ): ProgramInfo => {