diff --git a/livekit/tests/README.md b/livekit/tests/README.md
index a31f2d297..91ad3d7f2 100644
--- a/livekit/tests/README.md
+++ b/livekit/tests/README.md
@@ -6,5 +6,11 @@ E2E test feature:
 
 ```sh
 livekit-server --dev
-cargo test --features __lk-e2e-test
+cargo test --features default,__lk-e2e-test -- --nocapture
+```
+
+Tip: If you are using Rust Analyzer in Visual Studio Code, you can enable this feature to get code completion for these tests. Add the following setting to *.vscode/settings.json*:
+
+```json
+"rust-analyzer.cargo.features": ["default", "__lk-e2e-test"]
 ```
diff --git a/livekit/tests/audio_test.rs b/livekit/tests/audio_test.rs
new file mode 100644
index 000000000..197cf7fce
--- /dev/null
+++ b/livekit/tests/audio_test.rs
@@ -0,0 +1,129 @@
+#[cfg(feature = "__lk-e2e-test")]
+use {
+    anyhow::{anyhow, Ok, Result},
+    common::{
+        audio::{ChannelIterExt, FreqAnalyzer, SineParameters, SineTrack},
+        test_rooms,
+    },
+    futures_util::StreamExt,
+    libwebrtc::audio_stream::native::NativeAudioStream,
+    livekit::prelude::*,
+    std::{sync::Arc, time::Duration},
+    tokio::time::timeout,
+};
+
+mod common;
+
+struct TestParams {
+    pub_rate_hz: u32,
+    pub_channels: u32,
+    sub_rate_hz: u32,
+    sub_channels: u32,
+}
+
+#[cfg(feature = "__lk-e2e-test")]
+#[test_log::test(tokio::test)]
+async fn test_audio() -> Result<()> {
+    let test_params = [
+        TestParams { pub_rate_hz: 48_000, pub_channels: 1, sub_rate_hz: 48_000, sub_channels: 1 },
+        TestParams { pub_rate_hz: 48_000, pub_channels: 2, sub_rate_hz: 48_000, sub_channels: 2 },
+        TestParams { pub_rate_hz: 48_000, pub_channels: 2, sub_rate_hz: 24_000, sub_channels: 2 },
+        TestParams { pub_rate_hz: 24_000, pub_channels: 2, sub_rate_hz: 24_000, sub_channels: 1 },
+    ];
+    for params in test_params {
+        log::info!("Testing with {}", params);
+        test_audio_with(params).await?;
+    }
+    Ok(())
+}
+
+/// Tests audio transfer between two participants.
+///
+/// Verifies that audio can be published and received correctly
+/// between two participants by detecting the frequency of the sine wave on the subscriber end.
+///
+#[cfg(feature = "__lk-e2e-test")]
+async fn test_audio_with(params: TestParams) -> Result<()> {
+    let mut rooms = test_rooms(2).await?;
+    let (pub_room, _) = rooms.pop().unwrap();
+    let (_, mut sub_room_events) = rooms.pop().unwrap();
+
+    const SINE_FREQ: f64 = 60.0;
+    const SINE_AMPLITUDE: f64 = 1.0;
+    const FRAMES_TO_ANALYZE: usize = 100;
+
+    let sine_params = SineParameters {
+        freq: SINE_FREQ,
+        amplitude: SINE_AMPLITUDE,
+        sample_rate: params.pub_rate_hz,
+        num_channels: params.pub_channels,
+    };
+    let mut sine_track = SineTrack::new(Arc::new(pub_room), sine_params);
+    sine_track.publish().await?;
+
+    let analyze_frames = async move {
+        let track: RemoteTrack = loop {
+            let Some(event) = sub_room_events.recv().await else {
+                Err(anyhow!("Never received track"))?
+            };
+            let RoomEvent::TrackSubscribed { track, publication: _, participant: _ } = event else {
+                continue;
+            };
+            break track.into();
+        };
+        let RemoteTrack::Audio(track) = track else { Err(anyhow!("Expected audio track"))? };
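+        // Open the stream with the subscriber-side rate and channel count; the
+        // assertions below expect every received frame in exactly that format.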
+        let mut stream = NativeAudioStream::new(
+            track.rtc_track(),
+            params.sub_rate_hz as i32,
+            params.sub_channels as i32,
+        );
+
+        tokio::spawn(async move {
+            let mut frames_analyzed = 0;
+            let mut analyzers = vec![FreqAnalyzer::new(); params.sub_channels as usize];
+
+            while let Some(frame) = stream.next().await {
+                assert!(frame.data.len() > 0);
+                assert_eq!(frame.num_channels, params.sub_channels);
+                assert_eq!(frame.sample_rate, params.sub_rate_hz);
+                assert_eq!(frame.samples_per_channel, frame.data.len() as u32 / frame.num_channels);
+
+                for channel_idx in 0..params.sub_channels as usize {
+                    analyzers[channel_idx].analyze(frame.channel_iter(channel_idx));
+                }
+                frames_analyzed += 1;
+                if frames_analyzed >= FRAMES_TO_ANALYZE {
+                    break;
+                }
+            }
+            assert_eq!(frames_analyzed, FRAMES_TO_ANALYZE);
+
+            for (channel_idx, detected_freq) in analyzers
+                .into_iter()
+                .map(|analyzer| analyzer.estimated_freq(params.sub_rate_hz))
+                .enumerate()
+            {
+                assert!(
+                    (detected_freq - SINE_FREQ).abs() < 20.0, // Expect within 20Hz
+                    "Detected sine frequency not within range for channel {}: {}Hz",
+                    channel_idx,
+                    detected_freq
+                );
+            }
+        })
+        .await?;
+        Ok(())
+    };
+    timeout(Duration::from_secs(15), analyze_frames).await??;
+    Ok(())
+}
+
+impl std::fmt::Display for TestParams {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}Hz, {}ch. -> {}Hz, {}ch.",
+            self.pub_rate_hz, self.pub_channels, self.sub_rate_hz, self.sub_channels
+        )
+    }
+}
diff --git a/livekit/tests/common/e2e/audio.rs b/livekit/tests/common/e2e/audio.rs
new file mode 100644
index 000000000..51c44314d
--- /dev/null
+++ b/livekit/tests/common/e2e/audio.rs
@@ -0,0 +1,196 @@
+use libwebrtc::{
+    audio_source::native::NativeAudioSource,
+    prelude::{AudioFrame, AudioSourceOptions, RtcAudioSource},
+};
+use livekit::{
+    options::TrackPublishOptions,
+    track::{LocalAudioTrack, LocalTrack},
+    Room, RoomResult,
+};
+use std::sync::Arc;
+use tokio::{sync::oneshot, task::JoinHandle};
+
+/// Parameters for the sine wave generated with [`SineTrack`].
+#[derive(Clone, Debug)]
+pub struct SineParameters {
+    pub sample_rate: u32,
+    pub freq: f64,
+    pub amplitude: f64,
+    pub num_channels: u32,
+}
+
+/// Audio track which generates and publishes a sine wave.
+///
+/// This implementation was taken from the *wgpu_room* example.
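+/// A background task feeds 10 ms frames of 16-bit PCM into a [`NativeAudioSource`]
+/// until the track is unpublished.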
+///
+pub struct SineTrack {
+    rtc_source: NativeAudioSource,
+    params: SineParameters,
+    room: Arc<Room>,
+    handle: Option<TrackHandle>,
+}
+
+struct TrackHandle {
+    close_tx: oneshot::Sender<()>,
+    track: LocalAudioTrack,
+    task: JoinHandle<()>,
+}
+
+impl SineTrack {
+    pub fn new(room: Arc<Room>, params: SineParameters) -> Self {
+        Self {
+            rtc_source: NativeAudioSource::new(
+                AudioSourceOptions::default(),
+                params.sample_rate,
+                params.num_channels,
+                1000,
+            ),
+            params,
+            room,
+            handle: None,
+        }
+    }
+
+    pub async fn publish(&mut self) -> RoomResult<()> {
+        let (close_tx, close_rx) = oneshot::channel();
+        let track = LocalAudioTrack::create_audio_track(
+            "sine-track",
+            RtcAudioSource::Native(self.rtc_source.clone()),
+        );
+        let task =
+            tokio::spawn(Self::track_task(close_rx, self.rtc_source.clone(), self.params.clone()));
+        self.room
+            .local_participant()
+            .publish_track(LocalTrack::Audio(track.clone()), TrackPublishOptions::default())
+            .await?;
+        let handle = TrackHandle { close_tx, track, task };
+        self.handle = Some(handle);
+        Ok(())
+    }
+
+    pub async fn unpublish(&mut self) -> RoomResult<()> {
+        if let Some(handle) = self.handle.take() {
+            handle.close_tx.send(()).ok();
+            handle.task.await.ok();
+            self.room.local_participant().unpublish_track(&handle.track.sid()).await?;
+        }
+        Ok(())
+    }
+
+    async fn track_task(
+        mut close_rx: oneshot::Receiver<()>,
+        rtc_source: NativeAudioSource,
+        params: SineParameters,
+    ) {
+        let num_channels = params.num_channels as usize;
+        // 10 ms of interleaved samples per frame (sample_rate / 100 samples per channel).
+        let samples_count = (params.sample_rate / 100) as usize * num_channels;
+        let mut samples_10ms = vec![0; samples_count];
+        let mut phase = 0;
+        loop {
+            if close_rx.try_recv().is_ok() {
+                break;
+            }
+            for i in (0..samples_count).step_by(num_channels) {
+                let val = params.amplitude
+                    * f64::sin(
+                        std::f64::consts::PI
+                            * 2.0
+                            * params.freq
+                            * (phase as f64 / params.sample_rate as f64),
+                    );
+                phase += 1;
+                for c in 0..num_channels {
+                    // WebRTC uses 16-bit signed PCM
+                    samples_10ms[i + c] = (val * 32768.0) as i16;
+                }
+            }
+            let frame = AudioFrame {
+                data: samples_10ms.as_slice().into(),
+                sample_rate: params.sample_rate,
+                num_channels: params.num_channels,
+                samples_per_channel: samples_count as u32 / params.num_channels,
+            };
+            rtc_source.capture_frame(&frame).await.unwrap();
+        }
+    }
+}
+
+/// Analyzes samples to estimate the frequency of the signal using the zero crossing method.
+#[derive(Clone)]
+pub struct FreqAnalyzer {
+    zero_crossings: usize,
+    samples_analyzed: usize,
+}
+
+impl FreqAnalyzer {
+    pub fn new() -> Self {
+        Self { zero_crossings: 0, samples_analyzed: 0 }
+    }
+
+    pub fn analyze(&mut self, samples: impl IntoIterator<Item = i16>) {
+        let mut iter = samples.into_iter();
+        let mut prev = match iter.next() {
+            Some(v) => v,
+            None => return,
+        };
+        let mut count = 0;
+        for curr in iter {
+            if (prev >= 0 && curr < 0) || (prev < 0 && curr >= 0) {
+                self.zero_crossings += 1;
+            }
+            prev = curr;
+            count += 1;
+        }
+        self.samples_analyzed += count + 1;
+    }
+
+    pub fn estimated_freq(&self, sample_rate: u32) -> f64 {
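+        // A full sine cycle crosses zero twice, so the observed cycle count is
+        // roughly half the number of zero crossings over the analyzed duration.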
+ /// + /// # Panics + /// Panics if `channel_index` is greater than or equal to `num_channels`. + /// + fn channel_iter(&'a self, channel_index: usize) -> ChannelIter<'a>; +} + +impl<'a> ChannelIterExt<'a> for AudioFrame<'a> { + fn channel_iter(&'a self, channel_index: usize) -> ChannelIter<'a> { + assert!(channel_index < self.num_channels as usize); + ChannelIter { frame: self, channel_index, index: 0 } + } +} + +/// Iterator over an individual channel in an interleaved [`AudioFrame`]. +pub struct ChannelIter<'a> { + frame: &'a AudioFrame<'a>, + channel_index: usize, + index: usize, +} + +impl<'a> Iterator for ChannelIter<'a> { + type Item = i16; + + fn next(&mut self) -> Option { + let inner_index = + self.index * (self.frame.num_channels as usize) + (self.channel_index as usize); + if inner_index >= self.frame.data.len() { + return None; + } + let sample = self.frame.data[inner_index]; + self.index += 1; + Some(sample) + } +} diff --git a/livekit/tests/common/e2e.rs b/livekit/tests/common/e2e/mod.rs similarity index 99% rename from livekit/tests/common/e2e.rs rename to livekit/tests/common/e2e/mod.rs index baeb566b9..b64fcf42b 100644 --- a/livekit/tests/common/e2e.rs +++ b/livekit/tests/common/e2e/mod.rs @@ -9,6 +9,8 @@ use tokio::{ time::{self, timeout}, }; +pub mod audio; + struct TestEnvironment { api_key: String, api_secret: String,