Skip to content

Commit

Permalink
Use FixedString in GetTTS
Browse files (browse the repository at this point in the history)
  • Loading branch information
GnomedDev committed Feb 28, 2024
1 parent f3f4d68 commit 6fae566
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 17 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jsonwebtoken = "9"
mp3-duration = "0.1"
itertools = "0.12"
aws-sdk-polly = "1.7.0"
+ small-fixed-array = { version = "0.4.0", features = ["serde"] }

[dependencies.fernet]
version = "0.2"
Expand Down
2 changes: 1 addition & 1 deletion src/gcloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ pub async fn get_tts(
text: &str,
lang: &str,
speaking_rate: f32,
- preferred_format: Option<String>,
+ preferred_format: Option<&str>,
) -> Result<(bytes::Bytes, Option<reqwest::header::HeaderValue>)> {
let jwt_token = refresh_jwt(state).await?;
let reqwest = state.read().await.reqwest.clone();
Expand Down
23 changes: 12 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use bytes::Bytes;
use deadpool_redis::redis::AsyncCommands;
use serde_json::to_value;
use sha2::Digest;
+ use small_fixed_array::FixedString;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod espeak;
Expand Down Expand Up @@ -65,15 +66,15 @@ async fn get_voices(

#[derive(serde::Deserialize)]
struct GetTTS {
- text: String,
+ text: FixedString,
mode: TTSMode,
#[serde(rename = "lang")]
- voice: String,
+ voice: FixedString<u8>,
#[serde(default)]
speaking_rate: Option<f32>,
max_length: Option<u64>,
#[serde(default)]
- preferred_format: Option<String>,
+ preferred_format: Option<FixedString<u8>>,
}

async fn get_tts(
Expand Down Expand Up @@ -150,7 +151,7 @@ async fn get_tts(
text,
&voice,
speaking_rate.map(|r| r as u8),
- preferred_format,
+ preferred_format.as_deref(),
)
.await?
}
Expand All @@ -160,7 +161,7 @@ async fn get_tts(
&text,
&voice,
speaking_rate.unwrap_or(0.0),
- preferred_format,
+ preferred_format.as_deref(),
)
.await?
}
Expand Down Expand Up @@ -210,11 +211,11 @@ impl TTSMode {
.map_err(Into::into)
}

- #[cfg_attr(
- not(feature = "polly"),
- allow(unused_variables, clippy::unnecessary_wraps)
- )]
- async fn check_voice(self, state: &State, voice: String) -> ResponseResult<String> {
+ async fn check_voice(
+ self,
+ state: &State,
+ voice: FixedString<u8>,
+ ) -> ResponseResult<FixedString<u8>> {
if match self {
Self::gTTS => gtts::check_voice(&voice),
Self::eSpeak => espeak::check_voice(&voice),
Expand Down Expand Up @@ -363,7 +364,7 @@ async fn main() -> Result<()> {
#[derive(Debug)]
enum Error {
Unauthorized,
- UnknownVoice(String),
+ UnknownVoice(FixedString<u8>),
AudioTooLong,
InvalidSpeakingRate(f32),

Expand Down
13 changes: 8 additions & 5 deletions src/polly.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use aws_sdk_polly::types::{Engine, Gender, LanguageCode, OutputFormat, TextType, VoiceId};
use serde::ser::SerializeStruct;
+ use small_fixed_array::FixedString;

use crate::Result;

Expand Down Expand Up @@ -60,14 +61,16 @@ impl serde::Serialize for VoiceLocal {

pub async fn get_tts(
state: &State,
- mut text: String,
+ text: FixedString,
voice: &str,
speaking_rate: Option<u8>,
- preferred_format: Option<String>,
+ preferred_format: Option<&str>,
) -> Result<(bytes::Bytes, Option<reqwest::header::HeaderValue>)> {
- if let Some(speaking_rate) = speaking_rate {
- text = format!("<speak><prosody rate=\"{speaking_rate}%\">{text}</prosody></speak>");
- }
+ let text = if let Some(speaking_rate) = speaking_rate {
+ format!("<speak><prosody rate=\"{speaking_rate}%\">{text}</prosody></speak>")
+ } else {
+ text.into_string()
+ };

let resp = state
.synthesize_speech()
Expand Down

0 comments on commit 6fae566

Please sign in to comment.