diff --git a/mistralrs-core/src/pipeline/sampling.rs b/mistralrs-core/src/pipeline/sampling.rs
index 0879273f37..22faf9d17c 100644
--- a/mistralrs-core/src/pipeline/sampling.rs
+++ b/mistralrs-core/src/pipeline/sampling.rs
@@ -12,6 +12,18 @@ use crate::{
 
 use super::Pipeline;
 
+macro_rules! fixup_sentencepiece {
+    ($txt:expr) => {
+        $txt.to_string().replace("▁", " ")
+    };
+    (Option $txt:expr) => {
+        match &$txt {
+            Some(txt) => Some(fixup_sentencepiece!(txt)),
+            None => None,
+        }
+    };
+}
+
 pub(crate) async fn finish_or_add_toks_to_seq(
     this: &dyn Pipeline,
     prefix_cacher: &mut PrefixCacheManagerV2,
@@ -56,7 +68,9 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         };
         seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
             delta: crate::Delta {
-                content: text_new.map(ToString::to_string),
+                content: fixup_sentencepiece!(
+                    Option text_new.map(ToString::to_string)
+                ),
                 role: "assistant".to_string(),
                 tool_calls: Some(tool_calls),
             },
@@ -76,7 +90,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         } else {
             seq.add_streaming_completion_chunk_choice_to_group(
                 crate::CompletionChunkChoice {
-                    text: delta.clone(),
+                    text: fixup_sentencepiece!(delta),
                     index: seq.get_response_index(),
                     finish_reason: is_done.map(|x| x.to_string()),
                     logprobs: if seq.return_logprobs() {
@@ -190,7 +204,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         let (text_new, tool_calls) = parse_text_tools(text.as_str(), seq.tools.clone())
             .map_err(candle_core::Error::msg)?;
         let choice = crate::Choice {
-            finish_reason: reason.to_string(),
+            finish_reason: fixup_sentencepiece!(reason),
             index: seq.get_response_index(),
             message: crate::ResponseMessage {
                 content: text_new.map(ToString::to_string),
@@ -202,7 +216,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
         seq.add_choice_to_group(choice);
     } else {
         let choice = crate::CompletionChoice {
-            finish_reason: reason.to_string(),
+            finish_reason: fixup_sentencepiece!(reason),
             index: seq.get_response_index(),
             text,
             logprobs: None,
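
For reference, here is a small standalone sketch of how the `fixup_sentencepiece!` macro introduced in this diff behaves; the `main` function and the sample strings are illustrative only and are not part of the patch. The plain arm rewrites the SentencePiece word-boundary marker `▁` (U+2581) to an ordinary space, while the `Option`-prefixed arm applies the same rewrite inside `Some` and passes `None` through unchanged, which is the form used for the streaming `Delta` content.

```rust
// Copy of the macro from the diff, kept local so this sketch compiles on its own.
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

fn main() {
    // Plain expression form: the "▁" marker is replaced with a space.
    assert_eq!(fixup_sentencepiece!("▁Hello▁world"), " Hello world");

    // `Option` form: the rewrite is applied inside `Some`; `None` passes through.
    let some_text: Option<&str> = Some("▁chunk");
    let no_text: Option<&str> = None;
    assert_eq!(
        fixup_sentencepiece!(Option some_text),
        Some(" chunk".to_string())
    );
    assert_eq!(fixup_sentencepiece!(Option no_text), None);
}
```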