making temperature_scale configurable from outside
KarelVesely84 committed Apr 22, 2024
1 parent 6ec96cd commit ce6e5b5
Showing 9 changed files with 77 additions and 29 deletions.
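The change promotes the decoders' hard-coded temperature of 2.0 to a temperature_scale field of OnlineRecognizerConfig, registers a --temperature-scale command-line option, and exposes a temperature_scale argument in the Python API, keeping 2.0 as the default. The temperature only affects how per-token confidences are computed: dividing the joiner logits by a value above 1 before the softmax flattens the distribution, so confidences become less extreme, while decoding itself is unchanged. A minimal NumPy sketch of the effect (illustrative values, not code from this repository):

    import numpy as np

    def softmax(x: np.ndarray) -> np.ndarray:
        e = np.exp(x - x.max())
        return e / e.sum()

    logits = np.array([6.0, 2.0, 1.0, -1.0])  # illustrative joiner outputs for one step
    print(softmax(logits).max())              # ~0.97: sharp, over-confident
    print(softmax(logits / 2.0).max())        # ~0.80: softened by temperature_scale = 2.0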
16 changes: 13 additions & 3 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     }
 
     decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
-        model_.get(), lm_.get(), config_.max_active_paths,
-        config_.lm_config.scale, unk_id_, config_.blank_penalty);
+        model_.get(),
+        lm_.get(),
+        config_.max_active_paths,
+        config_.lm_config.scale,
+        unk_id_,
+        config_.blank_penalty,
+        config_.temperature_scale);
+
   } else if (config.decoding_method == "greedy_search") {
     decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
-        model_.get(), unk_id_, config_.blank_penalty);
+        model_.get(),
+        unk_id_,
+        config_.blank_penalty,
+        config_.temperature_scale);
+
   } else {
     SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                      config.decoding_method.c_str());
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/online-recognizer.cc
@@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
   po->Register("decoding-method", &decoding_method,
                "decoding method,"
                "now support greedy_search and modified_beam_search.");
+  po->Register("temperature-scale", &temperature_scale,
+               "Temperature scale for confidence computation in decoding.");
 }
 
 bool OnlineRecognizerConfig::Validate() const {
@@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "hotwords_score=" << hotwords_score << ", ";
   os << "hotwords_file=\"" << hotwords_file << "\", ";
   os << "decoding_method=\"" << decoding_method << "\", ";
-  os << "blank_penalty=" << blank_penalty << ")";
+  os << "blank_penalty=" << blank_penalty << ", ";
+  os << "temperature_scale=" << temperature_scale << ")";
 
   return os.str();
 }
20 changes: 14 additions & 6 deletions sherpa-onnx/csrc/online-recognizer.h
@@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {
 
   float blank_penalty = 0.0;
 
+  float temperature_scale = 2.0;
+
   OnlineRecognizerConfig() = default;
 
   OnlineRecognizerConfig(
       const FeatureExtractorConfig &feat_config,
-      const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
+      const OnlineModelConfig &model_config,
+      const OnlineLMConfig &lm_config,
       const EndpointConfig &endpoint_config,
       const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
-      bool enable_endpoint, const std::string &decoding_method,
-      int32_t max_active_paths, const std::string &hotwords_file,
-      float hotwords_score, float blank_penalty)
+      bool enable_endpoint,
+      const std::string &decoding_method,
+      int32_t max_active_paths,
+      const std::string &hotwords_file,
+      float hotwords_score,
+      float blank_penalty,
+      float temperature_scale)
       : feat_config(feat_config),
         model_config(model_config),
         lm_config(lm_config),
@@ -114,9 +121,10 @@
         enable_endpoint(enable_endpoint),
         decoding_method(decoding_method),
         max_active_paths(max_active_paths),
-        hotwords_score(hotwords_score),
         hotwords_file(hotwords_file),
-        blank_penalty(blank_penalty) {}
+        hotwords_score(hotwords_score),
+        blank_penalty(blank_penalty),
+        temperature_scale(temperature_scale) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
4 changes: 1 addition & 3 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
@@ -144,11 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
 
     // export the per-token log scores
     if (y != 0 && y != unk_id_) {
-      // TODO(KarelVesely84): configure externally ?
       // apply temperature-scaling
-      float temperature_scale = 2.0;
       for (int32_t n = 0; n < vocab_size; ++n) {
-        p_logit[n] /= temperature_scale;
+        p_logit[n] /= temperature_scale_;
       }
       LogSoftmax(p_logit, vocab_size);  // renormalize probabilities,
                                         // save time by doing it only for
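In the greedy-search decoder, the exported per-token log score is the log-softmax of the temperature-scaled logits at the emitted token. A rough NumPy sketch of that scoring step (function name and values are illustrative, not the library API):

    import numpy as np

    def token_log_score(logits: np.ndarray, token: int, temperature_scale: float = 2.0) -> float:
        # Log-probability of the emitted token after temperature scaling.
        scaled = logits / temperature_scale                # flatten the distribution
        scaled = scaled - scaled.max()                     # numerical stability
        log_probs = scaled - np.log(np.exp(scaled).sum())  # log-softmax
        return float(log_probs[token])

    logits = np.array([4.0, 1.0, 0.5, -2.0])
    print(token_log_score(logits, token=0))                         # scaled: ~-0.37
    print(token_log_score(logits, token=0, temperature_scale=1.0))  # unscaled: ~-0.08, more confident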
10 changes: 8 additions & 2 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
@@ -15,8 +15,13 @@ namespace sherpa_onnx {
 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
  public:
   OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
-                                      int32_t unk_id, float blank_penalty)
-      : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
+                                      int32_t unk_id,
+                                      float blank_penalty,
+                                      float temperature_scale)
+      : model_(model),
+        unk_id_(unk_id),
+        blank_penalty_(blank_penalty),
+        temperature_scale_(temperature_scale) {}
 
   OnlineTransducerDecoderResult GetEmptyResult() const override;
 
@@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
   OnlineTransducerModel *model_;  // Not owned
   int32_t unk_id_;
   float blank_penalty_;
+  float temperature_scale_;
 };
 
 }  // namespace sherpa_onnx
8 changes: 4 additions & 4 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
@@ -130,17 +130,17 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
 
   float *p_logit = logit.GetTensorMutableData<float>();
 
-  // copy raw logits, apply temperature-scaling (for confidences)
+  // copy raw logits, apply temperature-scaling (for confidences)
+  // Note: temperature scaling is used only for the confidences,
+  // the decoding algorithm uses the original logits
   int32_t p_logit_items = vocab_size * num_hyps;
   std::vector<float> logit_with_temperature(p_logit_items);
   {
     std::copy(p_logit,
               p_logit + p_logit_items,
               logit_with_temperature.begin());
-    // TODO(KarelVesely84): configure externally ?
-    float temperature_scale = 2.0;
     for (float& elem : logit_with_temperature) {
-      elem /= temperature_scale;
+      elem /= temperature_scale_;
     }
     LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
   }
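The beam-search decoder scales only a copy of the logits: the temperature-scaled copy feeds the exported confidences, while hypothesis scoring keeps using the raw values. A small NumPy sketch of that split, with illustrative names:

    import numpy as np

    def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
        x = x - x.max(axis=axis, keepdims=True)
        return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

    def split_scores(logits: np.ndarray, temperature_scale: float = 2.0):
        # logits: (num_hyps, vocab_size) raw joiner outputs for the active hypotheses
        confidences = log_softmax(logits / temperature_scale)  # exported per-token scores
        decoding_scores = log_softmax(logits)                  # used to rank and extend hypotheses
        return decoding_scores, confidences

    scores, confidences = split_scores(np.random.randn(4, 500).astype(np.float32))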
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h
@@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
       OnlineLM *lm,
       int32_t max_active_paths,
       float lm_scale, int32_t unk_id,
-      float blank_penalty)
+      float blank_penalty,
+      float temperature_scale)
       : model_(model),
         lm_(lm),
         max_active_paths_(max_active_paths),
         lm_scale_(lm_scale),
         unk_id_(unk_id),
-        blank_penalty_(blank_penalty) {}
+        blank_penalty_(blank_penalty),
+        temperature_scale_(temperature_scale) {}
 
   OnlineTransducerDecoderResult GetEmptyResult() const override;
 
@@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
   float lm_scale_;  // used only when lm_ is not nullptr
   int32_t unk_id_;
   float blank_penalty_;
+  float temperature_scale_;
 };
 
 }  // namespace sherpa_onnx
30 changes: 22 additions & 8 deletions sherpa-onnx/python/csrc/online-recognizer.cc
@@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   using PyClass = OnlineRecognizerConfig;
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
       .def(
-          py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
-                   const OnlineLMConfig &, const EndpointConfig &,
-                   const OnlineCtcFstDecoderConfig &, bool, const std::string &,
-                   int32_t, const std::string &, float, float>(),
-          py::arg("feat_config"), py::arg("model_config"),
+          py::init<const FeatureExtractorConfig &,
+                   const OnlineModelConfig &,
+                   const OnlineLMConfig &,
+                   const EndpointConfig &,
+                   const OnlineCtcFstDecoderConfig &,
+                   bool,
+                   const std::string &,
+                   int32_t,
+                   const std::string &,
+                   float,
+                   float,
+                   float>(),
+          py::arg("feat_config"),
+          py::arg("model_config"),
           py::arg("lm_config") = OnlineLMConfig(),
           py::arg("endpoint_config") = EndpointConfig(),
           py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
-          py::arg("enable_endpoint"), py::arg("decoding_method"),
-          py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
-          py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
+          py::arg("enable_endpoint"),
+          py::arg("decoding_method"),
+          py::arg("max_active_paths") = 4,
+          py::arg("hotwords_file") = "",
+          py::arg("hotwords_score") = 0,
+          py::arg("blank_penalty") = 0.0,
+          py::arg("temperature_scale") = 2.0)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("lm_config", &PyClass::lm_config)
@@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
       .def_readwrite("hotwords_file", &PyClass::hotwords_file)
      .def_readwrite("hotwords_score", &PyClass::hotwords_score)
      .def_readwrite("blank_penalty", &PyClass::blank_penalty)
+      .def_readwrite("temperature_scale", &PyClass::temperature_scale)
       .def("__str__", &PyClass::ToString);
 }
 
6 changes: 6 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -58,6 +58,7 @@ def from_transducer(
         model_type: str = "",
         lm: str = "",
         lm_scale: float = 0.1,
+        temperature_scale: float = 2.0,
     ):
         """
         Please refer to
@@ -123,6 +124,10 @@ def from_transducer(
           hotwords_score:
             The hotword score of each token for biasing word/phrase. Used only if
             hotwords_file is given with modified_beam_search as decoding method.
+          temperature_scale:
+            Temperature scaling for output symbol confidence estimation.
+            It affects only the confidence values; the decoding uses the original
+            logits without temperature.
           provider:
             onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
           model_type:
@@ -193,6 +198,7 @@ def from_transducer(
             hotwords_score=hotwords_score,
             hotwords_file=hotwords_file,
             blank_penalty=blank_penalty,
+            temperature_scale=temperature_scale,
         )
 
         self.recognizer = _Recognizer(recognizer_config)
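With the binding and the Python wrapper updated, the value can be set from user code. A hedged usage sketch: the file paths are placeholders, and the tokens/encoder/decoder/joiner arguments are the usual from_transducer parameters that are not part of this diff:

    import sherpa_onnx

    # Placeholder paths for a streaming transducer model.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        decoding_method="greedy_search",
        temperature_scale=2.0,  # previously hard-coded inside the decoders
    )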

0 comments on commit ce6e5b5
