Skip to content

Commit

Permalink
fix(playback): refactor cache code (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
snylonue authored Aug 23, 2024
1 parent 4e4bd6b commit 4de0f4e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 39 deletions.
27 changes: 17 additions & 10 deletions anni-playback/src/sources/cached_http/cache.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{
fs::{self, File}, io::{self, ErrorKind}, path::{Path, PathBuf}
fs::{self, File},
io::{self, ErrorKind},
path::{Path, PathBuf},
};

use crate::CODEC_REGISTRY;
Expand All @@ -26,7 +28,7 @@ impl CacheStore {
/// Returns the path to given `track`
pub fn loaction_of(&self, track: RawTrackIdentifier) -> PathBuf {
let mut tmp = self.base.clone();

tmp.extend([
track.album_id.as_ref(),
&format!(
Expand All @@ -40,12 +42,13 @@ impl CacheStore {

/// Attempts to open a cache file corresponding to `track` and validates it.
///
/// On success, returns a `Result<File, File>`.
/// On success, returns a `Result<File, (File, File)>`.
/// If the cache exists and is valid, opens it in read mode and returns an `Ok(_)`.
/// Otherwise, opens or creates a cache file in append mode and returns an `Err(_)`.
/// Otherwise, creates or truncates a cache file, opens it in read mode as a `reader`
/// and append mode as a `writer`, and returns an `Err((reader, writer))`
///
/// On error, an [`Error`](std::io::Error) is returned.
pub fn acquire(&self, track: RawTrackIdentifier) -> io::Result<Result<File, File>> {
pub fn acquire(&self, track: RawTrackIdentifier) -> io::Result<Result<File, (File, File)>> {
let path = self.loaction_of(track.copied());

if path.exists() {
Expand All @@ -58,12 +61,16 @@ impl CacheStore {

create_dir_all(path.parent().unwrap())?; // parent of `path` exists

File::options()
.read(true)
.append(true)
let _ = File::options()
.write(true)
.truncate(true)
.create(true)
.open(path)
.map(|f| Err(f))
.open(&path)?; // truncate the file first to clear incorrect data

let reader = File::options().read(true).open(&path)?;
let writer = File::options().append(true).open(path)?;

Ok(Err((reader, writer)))
}

pub fn add(&self, path: &Path, track: RawTrackIdentifier) -> io::Result<()> {
Expand Down
57 changes: 28 additions & 29 deletions anni-playback/src/sources/cached_http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod provider;

use std::{
fs::File,
hint::spin_loop,
io::{ErrorKind, Read, Seek, Write},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Expand All @@ -29,7 +30,7 @@ pub struct CachedHttpSource {
identifier: TrackIdentifier,
cache: File,
buf_len: Arc<AtomicUsize>,
pos: Arc<AtomicUsize>,
pos: usize,
is_buffering: Arc<AtomicBool>,
#[allow(unused)]
buffer_signal: Arc<AtomicBool>,
Expand All @@ -45,15 +46,15 @@ impl CachedHttpSource {
client: Client,
buffer_signal: Arc<AtomicBool>,
) -> Result<Self, OpenTrackError> {
let cache = match cache_store.acquire(identifier.inner.copied())? {
let (reader, writer) = match cache_store.acquire(identifier.inner.copied())? {
Ok(cache) => {
let buf_len = cache.metadata()?.len() as usize;

return Ok(Self {
identifier,
cache,
buf_len: Arc::new(AtomicUsize::new(buf_len)),
pos: Arc::new(AtomicUsize::new(0)),
pos: 0,
is_buffering: Arc::new(AtomicBool::new(false)),
buffer_signal,
duration: None,
Expand All @@ -64,16 +65,14 @@ impl CachedHttpSource {

let buf_len = Arc::new(AtomicUsize::new(0));
let is_buffering = Arc::new(AtomicBool::new(true));
let pos = Arc::new(AtomicUsize::new(0));

let (url, duration) = url().ok_or(OpenTrackError::NoAvailableAnnil)?;

log::debug!("got duration {duration:?}");

thread::spawn({
let mut cache = cache.try_clone()?;
let mut cache = writer;
let buf_len = Arc::clone(&buf_len);
let pos = Arc::clone(&pos);
let mut buf = [0; BUF_SIZE];
let is_buffering = Arc::clone(&is_buffering);
let identifier = identifier.clone();
Expand All @@ -83,45 +82,45 @@ impl CachedHttpSource {
Ok(r) => r,
Err(e) => {
log::error!("failed to send request: {e}");
is_buffering.store(false, Ordering::Release);
return;
}
};

loop {
match response.read(&mut buf) {
Ok(0) => {
is_buffering.store(false, Ordering::Release);
log::info!("{identifier} reached eof");
break;
}
Ok(n) => {
let pos = pos.load(Ordering::Acquire);
if let Err(e) = cache.write_all(&buf[..n]) {
log::error!("{e}")
log::error!("{e}");
break;
}

log::trace!("wrote {n} bytes to {identifier}");

let _ = cache.seek(std::io::SeekFrom::Start(pos as u64));
let _ = cache.flush();

buf_len.fetch_add(n, Ordering::AcqRel);

log::trace!("wrote {n} bytes to {identifier}");
}
Err(e) if e.kind() == ErrorKind::Interrupted => {}
Err(e) => {
log::error!("{e}");
is_buffering.store(false, Ordering::Release);
break;
}
}
}

is_buffering.store(false, Ordering::Release);
}
});

Ok(Self {
identifier,
cache,
cache: reader,
buf_len,
pos,
pos: 0,
is_buffering,
buffer_signal,
duration,
Expand All @@ -131,25 +130,25 @@ impl CachedHttpSource {

impl Read for CachedHttpSource {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// let n = self.cache.read(buf)?;
// self.pos.fetch_add(n, Ordering::AcqRel);
// log::trace!("read {n} bytes");
// Ok(n)

// A naive spin loop that waits until we have more data to read.
loop {
let has_buf = self.buf_len.load(Ordering::Acquire) > self.pos.load(Ordering::Acquire);
let is_buffering = self.is_buffering.load(Ordering::Acquire);
let buf_len = self.buf_len.load(Ordering::Acquire);
let has_buf = buf_len > self.pos;

if has_buf {
let n = self.cache.read(buf)?;
let n = <File as Read>::by_ref(&mut self.cache)
.take((buf_len - self.pos) as u64)
.read(buf)?; // ensure not exceeding the buffer

log::trace!("read {n} bytes from {}", self.identifier);
if n == 0 {
continue;
}
self.pos.fetch_add(n, Ordering::AcqRel);

self.pos += n;
break Ok(n);
} else if !is_buffering {
break Ok(0);
} else {
spin_loop();
}
}
}
Expand All @@ -158,14 +157,14 @@ impl Read for CachedHttpSource {
impl Seek for CachedHttpSource {
fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
let p = self.cache.seek(pos)?;
self.pos.store(p as usize, Ordering::Release);
self.pos += p as usize;
Ok(p)
}
}

impl MediaSource for CachedHttpSource {
fn is_seekable(&self) -> bool {
!self.is_buffering.load(Ordering::Relaxed)
!self.is_buffering.load(Ordering::Acquire)
}

fn byte_len(&self) -> Option<u64> {
Expand Down

0 comments on commit 4de0f4e

Please sign in to comment.