diff --git a/src/models/audio_features.cpp b/src/models/audio_features.cpp index d9ef1c6dbe..91372643cf 100644 --- a/src/models/audio_features.cpp +++ b/src/models/audio_features.cpp @@ -12,7 +12,7 @@ AudioFeatures::AudioFeatures(State& state, const std::string& name, const std::v name_{name} { // Get audio features for (const auto& [input_name, value] : extra_inputs) { - if (input_name == Config::Defaults::AudioFeaturesName) { + if (input_name == name) { audio_features_ = model_.ExpandInputs(value->ort_tensor_, state_.params_->search.num_beams); } } diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index fd8897a261..903b227c14 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -54,10 +54,11 @@ DeviceSpan AudioEncoderState::Run(int current_length, DeviceSpan WhisperDecoderState::WhisperDecoderState(const WhisperModel& model, const GeneratorParams& params, const int num_frames) : State{params, model}, model_{model}, + kv_cache_(CreateKeyValueCache(*this)), num_frames_{num_frames} { input_ids_.Add(); logits_.Add(); - kv_cache_.Add(); + kv_cache_->Add(); // Add past sequence length if (HasPastSequenceLengthInput()) { @@ -117,7 +118,7 @@ void WhisperDecoderState::UpdateInputsOutputs(DeviceSpan& next_tokens, int batch_size = static_cast(input_ids_.GetShape()[0]); size_t new_length = next_tokens.size() / batch_size; input_ids_.Update(next_tokens); - kv_cache_.Update(beam_indices, current_length); + kv_cache_->Update(beam_indices, current_length); logits_.Update(next_tokens, first_run_ ? current_length : new_length); // Return early if this method is just initializing the above OrtValue objects and not updating them @@ -171,17 +172,24 @@ WhisperState::WhisperState(const WhisperModel& model, const GeneratorParams& par : State{params, model}, model_{model} { encoder_state_ = std::make_unique(model, params); - cross_cache_ = std::make_unique(*this, encoder_state_->GetNumFrames() / 2); - encoder_state_->AddCrossCache(cross_cache_); decoder_state_ = std::make_unique(model, params, encoder_state_->GetNumFrames()); - decoder_state_->AddCrossCache(cross_cache_); - transpose_k_cache_buffer_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), cross_cache_->GetShape(), cross_cache_->GetType()); + if (encoder_state_->HasCrossKVCacheOutputs()) { + cross_cache_ = std::make_unique(*this, encoder_state_->GetNumFrames() / 2); + encoder_state_->AddCrossCache(cross_cache_); + decoder_state_->AddCrossCache(cross_cache_); + transpose_k_cache_buffer_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), cross_cache_->GetShape(), cross_cache_->GetType()); + } } void WhisperState::SetExtraInputs(const std::vector& extra_inputs) { encoder_state_->SetExtraInputs(extra_inputs); + if (!encoder_state_->HasCrossKVCacheOutputs()) { + decoder_state_->inputs_.push_back(encoder_state_->hidden_states_.get()); + decoder_state_->input_names_.push_back(model_.config_->model.decoder.inputs.encoder_hidden_states.c_str()); + } + // Check if alignment heads input exists void* alignment_heads_input = nullptr; for (const auto& [name, value] : extra_inputs) { @@ -337,7 +345,12 @@ DeviceSpan WhisperState::Run(int current_length, DeviceSpan& nex // Transpose the K caches only when the else branch is run for the first time. // Otherwise the GetOutput(present_key_{self/cross}_{i}) method returns transposed K caches. TransposeKCaches(cross_cache_->GetValues()); - TransposeKCaches(decoder_state_->kv_cache_.GetPresents()); + + auto default_kv_cache_ptr = dynamic_cast(decoder_state_->kv_cache_.get()); + if (!default_kv_cache_ptr) { + throw std::runtime_error("Unable to convert KeyValueCache to DefaultKeyValueCache"); + } + TransposeKCaches(default_kv_cache_ptr->GetPresents()); } // Update inputs and outputs for decoder diff --git a/src/models/whisper.h b/src/models/whisper.h index 34e5208f22..ab490eced0 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -31,6 +31,8 @@ struct AudioEncoderState : State { int GetNumFrames() { return num_frames_; } + bool HasCrossKVCacheOutputs() { return model_.session_info_.HasOutput(ComposeKeyValueName(model_.config_->model.encoder.outputs.cross_present_key_names, 0)); } + private: friend struct WhisperState; @@ -62,7 +64,7 @@ struct WhisperDecoderState : State { const WhisperModel& model_; DefaultInputIDs input_ids_{*this}; // Model input - DefaultKeyValueCache kv_cache_{*this}; // Model input and output + std::unique_ptr kv_cache_; // Inputs for beam search attention std::unique_ptr past_sequence_length_; // Model input