265 changes: 209 additions & 56 deletions core/src/raw/oio/write/block_write.rs
@@ -25,6 +25,7 @@ use async_trait::async_trait;
use futures::Future;
use futures::FutureExt;
use futures::StreamExt;
use uuid::Uuid;

use crate::raw::*;
use crate::*;
@@ -77,17 +78,22 @@ pub trait BlockWrite: Send + Sync + Unpin + 'static {
/// order.
///
/// - block_id is the id of the block.
async fn write_block(&self, size: u64, block_id: String, body: AsyncBody) -> Result<()>;
async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()>;

/// complete_block will complete the block upload to build the final
/// file.
async fn complete_block(&self, block_ids: Vec<String>) -> Result<()>;
async fn complete_block(&self, block_ids: Vec<Uuid>) -> Result<()>;

/// abort_block will cancel the block upload and purge all data.
async fn abort_block(&self, block_ids: Vec<String>) -> Result<()>;
async fn abort_block(&self, block_ids: Vec<Uuid>) -> Result<()>;
}
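
A minimal sketch of the call sequence the trait implies, assuming a hypothetical `upload_two_blocks` helper over any backend `W: BlockWrite`; blocks may be uploaded in any order, and `complete_block` stitches them together in the order of the ids it receives:

async fn upload_two_blocks<W: BlockWrite>(
    w: &W,
    first: oio::ChunkedBytes,
    second: oio::ChunkedBytes,
) -> Result<()> {
    let (first_id, second_id) = (Uuid::new_v4(), Uuid::new_v4());

    // Upload order does not matter; the second block may land before the first.
    w.write_block(second_id, second.len() as u64, AsyncBody::ChunkedBytes(second))
        .await?;
    w.write_block(first_id, first.len() as u64, AsyncBody::ChunkedBytes(first))
        .await?;

    // The final file is `first` followed by `second`, decided by this ordering.
    w.complete_block(vec![first_id, second_id]).await
}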

struct WriteBlockFuture(BoxedFuture<Result<()>>);
/// WriteBlockResult is the result returned by [`WriteBlockFuture`].
///
/// The error part carries the input `(block_id, bytes, err)` so the caller can retry them.
type WriteBlockResult = std::result::Result<Uuid, (Uuid, oio::ChunkedBytes, Error)>;

struct WriteBlockFuture(BoxedFuture<WriteBlockResult>);

/// # Safety
///
@@ -100,19 +106,38 @@ unsafe impl Send for WriteBlockFuture {}
unsafe impl Sync for WriteBlockFuture {}

impl Future for WriteBlockFuture {
type Output = Result<()>;
type Output = WriteBlockResult;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_mut().0.poll_unpin(cx)
}
}

impl WriteBlockFuture {
pub fn new<W: BlockWrite>(w: Arc<W>, block_id: Uuid, bytes: oio::ChunkedBytes) -> Self {
let fut = async move {
w.write_block(
block_id,
bytes.len() as u64,
AsyncBody::ChunkedBytes(bytes.clone()),
)
.await
// Return the bytes along with the error so the caller can retry them.
.map_err(|err| (block_id, bytes, err))
// Return the successful block id.
.map(|_| block_id)
};

WriteBlockFuture(Box::pin(fut))
}
}
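
As a hedged illustration (not part of the diff), the retry contract works like this: on failure the future hands back the same `block_id` and bytes, so an identical future can be rebuilt and polled again, as in this hypothetical helper:

async fn write_block_with_retry<W: BlockWrite>(
    w: Arc<W>,
    block_id: Uuid,
    bytes: oio::ChunkedBytes,
) -> Result<Uuid> {
    let mut fut = WriteBlockFuture::new(w.clone(), block_id, bytes);
    loop {
        match fut.await {
            Ok(id) => return Ok(id),
            // The error carries the input back, so nothing has to be re-buffered.
            Err((id, bytes, _err)) => fut = WriteBlockFuture::new(w.clone(), id, bytes),
        }
    }
}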

/// BlockWriter implements [`Write`] based on block
/// uploads.
pub struct BlockWriter<W: BlockWrite> {
state: State,
w: Arc<W>,

block_ids: Vec<String>,
block_ids: Vec<Uuid>,
cache: Option<oio::ChunkedBytes>,
futures: ConcurrentFutures<WriteBlockFuture>,
}
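
End to end, the writer is driven like the test at the bottom of this file. A sketch, assuming a backend `B: BlockWrite` and that the second argument of `BlockWriter::new` is the number of concurrently in-flight block uploads:

async fn copy_chunks<B: BlockWrite>(backend: B, chunks: Vec<Vec<u8>>) -> Result<()> {
    use crate::raw::oio::WriteExt;

    let mut w = BlockWriter::new(backend, 8);
    for chunk in chunks {
        // On error the failed block stays queued inside the writer, so the
        // same chunk can simply be written again without re-buffering it.
        while w.write(&chunk.as_slice()).await.is_err() {}
    }
    while w.close().await.is_err() {}
    Ok(())
}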
@@ -168,20 +193,30 @@ where
let size = self.fill_cache(bs);
return Poll::Ready(Ok(size));
}

let cache = self.cache.take().expect("pending write must exist");
let block_id = uuid::Uuid::new_v4().to_string();
self.block_ids.push(block_id.clone());
let w = self.w.clone();
let size = cache.len();
self.futures
.push_back(WriteBlockFuture(Box::pin(async move {
w.write_block(size as u64, block_id, AsyncBody::ChunkedBytes(cache))
.await
})));
self.futures.push_back(WriteBlockFuture::new(
self.w.clone(),
Uuid::new_v4(),
cache,
));

let size = self.fill_cache(bs);
return Poll::Ready(Ok(size));
} else if let Some(res) = ready!(self.futures.poll_next_unpin(cx)) {
res?;
match res {
Ok(block_id) => {
self.block_ids.push(block_id);
}
Err((block_id, bytes, err)) => {
self.futures.push_front(WriteBlockFuture::new(
self.w.clone(),
block_id,
bytes,
));
return Poll::Ready(Err(err));
}
}
}
}
State::Close(_) => {
@@ -198,53 +233,55 @@ where
loop {
match &mut self.state {
State::Idle => {
let w = self.w.clone();
let block_ids = self.block_ids.clone();
// No block has been uploaded yet.
if self.futures.is_empty() && self.block_ids.is_empty() {
let w = self.w.clone();
let (size, body) = match self.cache.clone() {
Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)),
None => (0, AsyncBody::Empty),
};
// Call write_once if there is no data in the cache and no block has been uploaded.
self.state =
State::Close(Box::pin(
async move { w.write_once(size as u64, body).await },
));
continue;
}

if self.block_ids.is_empty() {
match &self.cache {
Some(cache) => {
let w = self.w.clone();
let bs = cache.clone();
self.state = State::Close(Box::pin(async move {
let size = bs.len();
w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await
}));
}
None => {
let w = self.w.clone();
// Call write_once if there is no data in cache.
self.state = State::Close(Box::pin(async move {
w.write_once(0, AsyncBody::Empty).await
}));
if self.futures.has_remaining() {
if let Some(cache) = self.cache.take() {
self.futures.push_back(WriteBlockFuture::new(
self.w.clone(),
Uuid::new_v4(),
cache,
));
}
}

if !self.futures.is_empty() {
while let Some(result) = ready!(self.futures.poll_next_unpin(cx)) {
match result {
Ok(block_id) => {
self.block_ids.push(block_id);
}
Err((block_id, bytes, err)) => {
self.futures.push_front(WriteBlockFuture::new(
self.w.clone(),
block_id,
bytes,
));
return Poll::Ready(Err(err));
}
}
}
} else if self.futures.is_empty() && self.cache.is_none() {
} else {
let w = self.w.clone();
let block_ids = self.block_ids.clone();
self.state =
State::Close(Box::pin(
async move { w.complete_block(block_ids).await },
));
} else {
if self.futures.has_remaining() {
if let Some(cache) = self.cache.take() {
let block_id = uuid::Uuid::new_v4().to_string();
self.block_ids.push(block_id.clone());
let size = cache.len();
let w = self.w.clone();
self.futures
.push_back(WriteBlockFuture(Box::pin(async move {
w.write_block(
size as u64,
block_id,
AsyncBody::ChunkedBytes(cache),
)
.await
})));
}
}
while let Some(res) = ready!(self.futures.poll_next_unpin(cx)) {
res?;
}
continue;
}
}
State::Close(fut) => {
@@ -270,6 +307,7 @@ where
let w = self.w.clone();
let block_ids = self.block_ids.clone();
self.futures.clear();
self.cache = None;
self.state =
State::Abort(Box::pin(async move { w.abort_block(block_ids).await }));
}
@@ -285,3 +323,118 @@
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::raw::oio::{StreamExt, WriteBuf, WriteExt};
use bytes::Bytes;
use pretty_assertions::assert_eq;
use rand::{thread_rng, Rng, RngCore};
use std::collections::HashMap;
use std::sync::Mutex;

struct TestWrite {
length: u64,
bytes: HashMap<Uuid, Bytes>,
content: Option<Bytes>,
}

impl TestWrite {
pub fn new() -> Arc<Mutex<Self>> {
let v = Self {
length: 0,
bytes: HashMap::new(),
content: None,
};

Arc::new(Mutex::new(v))
}
}

#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl BlockWrite for Arc<Mutex<TestWrite>> {
async fn write_once(&self, _: u64, _: AsyncBody) -> Result<()> {
Ok(())
}

async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()> {
// Each write_block call has a 50% chance to fail.
if thread_rng().gen_bool(5.0 / 10.0) {
return Err(Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!"));
}

let bs = match body {
AsyncBody::Empty => Bytes::new(),
AsyncBody::Bytes(bs) => bs,
AsyncBody::ChunkedBytes(cb) => cb.bytes(cb.remaining()),
AsyncBody::Stream(s) => s.collect().await.unwrap(),
};

let mut this = self.lock().unwrap();
this.length += size;
this.bytes.insert(block_id, bs);

Ok(())
}

async fn complete_block(&self, block_ids: Vec<Uuid>) -> Result<()> {
let mut this = self.lock().unwrap();
let mut bs = Vec::new();
for id in block_ids {
bs.extend_from_slice(&this.bytes[&id]);
}
this.content = Some(bs.into());

Ok(())
}

async fn abort_block(&self, _: Vec<Uuid>) -> Result<()> {
Ok(())
}
}

#[tokio::test]
async fn test_block_writer_with_concurrent_errors() {
let mut rng = thread_rng();

let mut w = BlockWriter::new(TestWrite::new(), 8);
let mut total_size = 0u64;
let mut expected_content = Vec::new();

for _ in 0..1000 {
let size = rng.gen_range(1..1024);
total_size += size as u64;

let mut bs = vec![0; size];
rng.fill_bytes(&mut bs);

expected_content.extend_from_slice(&bs);

loop {
match w.write(&bs.as_slice()).await {
Ok(_) => break,
Err(_) => continue,
}
}
}

loop {
match w.close().await {
Ok(_) => break,
Err(_) => continue,
}
}

let inner = w.w.lock().unwrap();

assert_eq!(total_size, inner.length, "length must be the same");
assert!(inner.content.is_some());
assert_eq!(
expected_content,
inner.content.clone().unwrap(),
"content must be the same"
);
}
}
26 changes: 12 additions & 14 deletions core/src/raw/oio/write/range_write.rs
@@ -291,20 +291,18 @@ impl<W: RangeWrite> oio::Write for RangeWriter<W> {
}
}
}
None => match self.buffer.clone() {
Some(bs) => {
self.state = State::Complete(Box::pin(async move {
let size = bs.len();
w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await
}));
}
None => {
// Call write_once if there is no data in buffer and no location.
self.state = State::Complete(Box::pin(async move {
w.write_once(0, AsyncBody::Empty).await
}));
}
},
None => {
let w = self.w.clone();
let (size, body) = match self.buffer.clone() {
Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)),
None => (0, AsyncBody::Empty),
};
// Call write_once if there is no data in buffer and no location.
self.state = State::Complete(Box::pin(async move {
w.write_once(size as u64, body).await
}));
}
}
}
State::Init(_) => {