diff --git a/Cargo.lock b/Cargo.lock index c4f2d35da12..24595a5f711 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3826,6 +3826,7 @@ dependencies = [ "arrow-arith", "arrow-array", "arrow-buffer", + "arrow-ipc", "arrow-ord", "arrow-row", "arrow-schema", @@ -3849,6 +3850,7 @@ dependencies = [ "datafusion-physical-expr", "deepsize", "dirs", + "either", "env_logger", "futures", "half", @@ -3976,9 +3978,11 @@ dependencies = [ "lance-datagen", "lazy_static", "log", + "pin-project", "prost 0.13.5", "snafu", "substrait-expr", + "tempfile", "tokio", "tracing", ] diff --git a/python/Cargo.lock b/python/Cargo.lock index 22a171d477f..743d2cd898a 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -3192,6 +3192,7 @@ dependencies = [ "arrow-arith", "arrow-array", "arrow-buffer", + "arrow-ipc", "arrow-ord", "arrow-row", "arrow-schema", @@ -3210,6 +3211,7 @@ dependencies = [ "datafusion-functions", "datafusion-physical-expr", "deepsize", + "either", "futures", "half", "humantime", @@ -3321,8 +3323,10 @@ dependencies = [ "lance-datagen", "lazy_static", "log", + "pin-project", "prost 0.13.5", "snafu", + "tempfile", "tokio", "tracing", ] diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs index 28b7979371a..6c055f9126a 100644 --- a/rust/lance-arrow/src/lib.rs +++ b/rust/lance-arrow/src/lib.rs @@ -30,6 +30,7 @@ pub mod floats; pub use floats::*; pub mod cast; pub mod list; +pub mod memory; type Result = std::result::Result; diff --git a/rust/lance-arrow/src/memory.rs b/rust/lance-arrow/src/memory.rs new file mode 100644 index 00000000000..6b8db9da769 --- /dev/null +++ b/rust/lance-arrow/src/memory.rs @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::collections::HashSet; + +use arrow_array::{Array, RecordBatch}; +use arrow_data::ArrayData; + +/// Counts memory used by buffers of Arrow arrays and RecordBatches. 
+/// +/// This is meant to capture how much memory is being used by the Arrow data +/// structures as they are. It does not represent the memory used if the data +/// were to be serialized and then deserialized. In particular: +/// +/// * This does not double count memory used by buffers shared by multiple +/// arrays or batches. Round-tripped data may use more memory because of this. +/// * This counts the **total** size of the buffers, even if the array is a slice. +/// Round-tripped data may use less memory because of this. +#[derive(Default)] +pub struct MemoryAccumulator { + seen: HashSet, + total: usize, +} + +impl MemoryAccumulator { + pub fn record_array(&mut self, array: &dyn Array) { + let data = array.to_data(); + self.record_array_data(&data); + } + + fn record_array_data(&mut self, data: &ArrayData) { + for buffer in data.buffers() { + let ptr = buffer.as_ptr(); + if self.seen.insert(ptr as usize) { + self.total += buffer.capacity(); + } + } + + if let Some(nulls) = data.nulls() { + let null_buf = nulls.inner().inner(); + let ptr = null_buf.as_ptr(); + if self.seen.insert(ptr as usize) { + self.total += null_buf.capacity(); + } + } + + for child in data.child_data() { + self.record_array_data(child); + } + } + + pub fn record_batch(&mut self, batch: &RecordBatch) { + for array in batch.columns() { + self.record_array(array); + } + } + + pub fn total(&self) -> usize { + self.total + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field, Schema}; + + use super::*; + + #[test] + fn test_memory_accumulator() { + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let slice = batch.slice(1, 2); + + let mut acc = MemoryAccumulator::default(); + + // Should record whole buffer, not just slice + acc.record_batch(&slice); + assert_eq!(acc.total(), 3 * 
std::mem::size_of::()); + + // Should not double count + acc.record_batch(&slice); + assert_eq!(acc.total(), 3 * std::mem::size_of::()); + } +} diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 7387234e918..6eed53b010e 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -51,6 +51,14 @@ pub enum Error { source: BoxedError, location: Location, }, + #[snafu(display("Retryable commit conflict for version {version}: {source}, {location}"))] + RetryableCommitConflict { + version: u64, + source: BoxedError, + location: Location, + }, + #[snafu(display("Too many concurrent writers. {message}, {location}"))] + TooMuchWriteContention { message: String, location: Location }, #[snafu(display("Encountered internal error. Please file a bug report at https://github.com/lancedb/lance/issues. {message}, {location}"))] Internal { message: String, location: Location }, #[snafu(display("A prerequisite task failed: {message}, {location}"))] diff --git a/rust/lance-core/src/utils.rs b/rust/lance-core/src/utils.rs index a67cfad693d..f04ca305f93 100644 --- a/rust/lance-core/src/utils.rs +++ b/rust/lance-core/src/utils.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors pub mod address; +pub mod backoff; pub mod bit; pub mod cpu; pub mod deletion; diff --git a/rust/lance-core/src/utils/backoff.rs b/rust/lance-core/src/utils/backoff.rs new file mode 100644 index 00000000000..d2093cb5b6d --- /dev/null +++ b/rust/lance-core/src/utils/backoff.rs @@ -0,0 +1,92 @@ +use rand::Rng; +use std::time::Duration; + +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +/// Computes backoff as +/// +/// ```text +/// backoff = base^attempt * unit + jitter +/// ``` +/// +/// The defaults are base=2, unit=50ms, jitter=50ms, min=0ms, max=5s. This gives +/// a backoff of 50ms, 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, 5s, (not including jitter). 
+/// +/// You can have non-exponential backoff by setting base=1. +pub struct Backoff { + base: u32, + unit: u32, + jitter: i32, + min: u32, + max: u32, + attempt: u32, +} + +impl Default for Backoff { + fn default() -> Self { + Self { + base: 2, + unit: 50, + jitter: 50, + min: 0, + max: 5000, + attempt: 0, + } + } +} + +impl Backoff { + pub fn with_base(self, base: u32) -> Self { + Self { base, ..self } + } + + pub fn with_jitter(self, jitter: i32) -> Self { + Self { jitter, ..self } + } + + pub fn with_min(self, min: u32) -> Self { + Self { min, ..self } + } + + pub fn with_max(self, max: u32) -> Self { + Self { max, ..self } + } + + pub fn next_backoff(&mut self) -> Duration { + let backoff = self + .base + .saturating_pow(self.attempt) + .saturating_mul(self.unit); + let jitter = rand::thread_rng().gen_range(-self.jitter..=self.jitter); + let backoff = (backoff.saturating_add_signed(jitter)).clamp(self.min, self.max); + self.attempt += 1; + Duration::from_millis(backoff as u64) + } + + pub fn attempt(&self) -> u32 { + self.attempt + } + + pub fn reset(&mut self) { + self.attempt = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_backoff() { + let mut backoff = Backoff::default().with_jitter(0); + assert_eq!(backoff.next_backoff().as_millis(), 50); + assert_eq!(backoff.attempt(), 1); + assert_eq!(backoff.next_backoff().as_millis(), 100); + assert_eq!(backoff.attempt(), 2); + assert_eq!(backoff.next_backoff().as_millis(), 200); + assert_eq!(backoff.attempt(), 3); + assert_eq!(backoff.next_backoff().as_millis(), 400); + assert_eq!(backoff.attempt(), 4); + } +} diff --git a/rust/lance-datafusion/Cargo.toml b/rust/lance-datafusion/Cargo.toml index b99a4a52f87..3d60207422f 100644 --- a/rust/lance-datafusion/Cargo.toml +++ b/rust/lance-datafusion/Cargo.toml @@ -28,8 +28,10 @@ lance-core = { workspace = true, features = ["datafusion"] } lance-datagen.workspace = true lazy_static.workspace = true log.workspace = true +pin-project.workspace = 
true prost.workspace = true snafu.workspace = true +tempfile.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/rust/lance-datafusion/src/lib.rs b/rust/lance-datafusion/src/lib.rs index 002c087754d..a99afbbbe08 100644 --- a/rust/lance-datafusion/src/lib.rs +++ b/rust/lance-datafusion/src/lib.rs @@ -9,6 +9,7 @@ pub mod expr; pub mod logical_expr; pub mod planner; pub mod projection; +pub mod spill; pub mod sql; #[cfg(feature = "substrait")] pub mod substrait; diff --git a/rust/lance-datafusion/src/spill.rs b/rust/lance-datafusion/src/spill.rs new file mode 100644 index 00000000000..cb60669a5af --- /dev/null +++ b/rust/lance-datafusion/src/spill.rs @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ + io::{BufReader, BufWriter}, + path::PathBuf, + sync::{Arc, Mutex}, +}; + +use arrow::ipc::{reader::StreamReader, writer::StreamWriter}; +use arrow_array::RecordBatch; +use arrow_schema::{ArrowError, Schema}; +use datafusion::{ + execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter, +}; +use datafusion_common::DataFusionError; +use lance_arrow::memory::MemoryAccumulator; +use lance_core::error::LanceOptionExt; + +/// Start a spill of Arrow data to a file that can be read later multiple times. +/// +/// Up to `memory_limit` bytes of data can be buffered in memory before a spill +/// is created. If the memory limit is never reached before [`SpillSender::finish()`] +/// is called, then the data will simply be kept in memory and no spill will be +/// created. +/// +/// `path` is the path to the file that may be created. It should not already +/// exist. It is the responsibility of the caller to delete the file after it is +/// no longer needed. +/// +/// The [`SpillSender`] allows you to write batches to the spill. +/// +/// The [`SpillReceiver`] can open a [`SendableRecordBatchStream`] that reads +/// batches from the spill. 
This can be opened before, during, or after batches +/// have been written to the spill. +/// +/// Once [`SpillSender`] is dropped, the temporary file is deleted. This will +/// cause the [`SpillReceiver`] to return an error if it is still open. +pub fn create_replay_spill( + path: std::path::PathBuf, + schema: Arc, + memory_limit: usize, +) -> (SpillSender, SpillReceiver) { + let initial_status = WriteStatus::default(); + let (status_sender, status_receiver) = tokio::sync::watch::channel(initial_status); + let sender = SpillSender { + memory_limit, + path: path.clone(), + schema: schema.clone(), + state: SpillState::default(), + status_sender, + }; + + let receiver = SpillReceiver { + status_receiver, + path, + schema, + }; + + (sender, receiver) +} + +#[derive(Clone)] +pub struct SpillReceiver { + status_receiver: tokio::sync::watch::Receiver, + path: PathBuf, + schema: Arc, +} + +impl SpillReceiver { + /// Returns a stream of batches from the spill. The stream will emit + /// batches as they are written to the spill. If the spill has already + /// been finished, the stream will emit all batches in the spill. + /// + /// The stream will not complete until [`Self::finish()`] is called. + /// + /// If the spill has been dropped, an error will be returned. 
+ pub fn read(&self) -> SendableRecordBatchStream { + let rx = self.status_receiver.clone(); + let reader = SpillReader::new(rx, self.path.clone()); + + let stream = futures::stream::try_unfold(reader, move |mut reader| async move { + match reader.read().await { + Ok(None) => Ok(None), + Ok(Some(batch)) => Ok(Some((batch, reader))), + Err(err) => Err(err), + } + }); + + Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream)) + } +} + +struct SpillReader { + pub batches_read: usize, + receiver: tokio::sync::watch::Receiver, + state: SpillReaderState, +} + +enum SpillReaderState { + Buffered { spill_path: PathBuf }, + Reader { reader: AsyncStreamReader }, +} + +impl SpillReader { + fn new(receiver: tokio::sync::watch::Receiver, spill_path: PathBuf) -> Self { + Self { + batches_read: 0, + receiver, + state: SpillReaderState::Buffered { spill_path }, + } + } + + async fn wait_for_more_data(&mut self) -> Result>, DataFusionError> { + let status = self + .receiver + .wait_for(|status| { + status.error.is_some() + || status.finished + || status.batches_written() > self.batches_read + }) + .await + .map_err(|_| { + DataFusionError::Execution( + "Spill has been dropped before reader has finish.".into(), + ) + })?; + + if let Some(error) = &status.error { + let mut guard = error.lock().ok().expect_ok()?; + return Err(DataFusionError::from(&mut (*guard))); + } + + if let DataLocation::Buffered { batches } = &status.data_location { + Ok(Some(batches.clone())) + } else { + Ok(None) + } + } + + async fn get_reader(&mut self) -> Result<&AsyncStreamReader, ArrowError> { + if let SpillReaderState::Buffered { spill_path } = &self.state { + let reader = AsyncStreamReader::open(spill_path.clone()).await?; + // Skip batches we've already read before the writer started spilling. + // The read batches were spilled to the file for the benefit of + // future readers, as the spill is replay-able. 
+ for _ in 0..self.batches_read { + reader.read().await?; + } + self.state = SpillReaderState::Reader { reader }; + } + + if let SpillReaderState::Reader { reader } = &mut self.state { + Ok(reader) + } else { + unreachable!() + } + } + + async fn read(&mut self) -> Result, DataFusionError> { + let maybe_data = self.wait_for_more_data().await?; + + if let Some(batches) = maybe_data { + if self.batches_read < batches.len() { + let batch = batches[self.batches_read].clone(); + self.batches_read += 1; + Ok(Some(batch)) + } else { + Ok(None) + } + } else { + let reader = self.get_reader().await?; + let batch = reader.read().await?; + if batch.is_some() { + self.batches_read += 1; + } + Ok(batch) + } + } +} + +/// The sender side of the spill. This is used to write batches to the spill. +/// +/// Note: this must be kept alive until after the readers are done reading the +/// spill. Otherwise, they will return an error. +pub struct SpillSender { + memory_limit: usize, + schema: Arc, + path: PathBuf, + state: SpillState, + status_sender: tokio::sync::watch::Sender, +} + +enum SpillState { + Buffering { + batches: Vec, + memory_accumulator: MemoryAccumulator, + }, + Spilling { + writer: AsyncStreamWriter, + batches_written: usize, + }, + Finished { + batches: Option>, + batches_written: usize, + }, + Errored { + error: Arc>, + }, +} + +impl Default for SpillState { + fn default() -> Self { + Self::Buffering { + batches: Vec::new(), + memory_accumulator: MemoryAccumulator::default(), + } + } +} + +#[derive(Clone, Debug, Default)] +struct WriteStatus { + error: Option>>, + finished: bool, + data_location: DataLocation, +} + +impl WriteStatus { + fn batches_written(&self) -> usize { + match &self.data_location { + DataLocation::Buffered { batches } => batches.len(), + DataLocation::Spilled { + batches_written, .. 
+ } => *batches_written, + } + } +} + +#[derive(Clone, Debug)] +enum DataLocation { + Buffered { batches: Arc<[RecordBatch]> }, + Spilled { batches_written: usize }, +} + +impl Default for DataLocation { + fn default() -> Self { + Self::Buffered { + batches: Arc::new([]), + } + } +} + +/// A DataFusion error that be be emitted multiple times. We provide the +/// Original error first, and subsequent conversions provide a copy with a +/// string representation of the original error. +#[derive(Debug)] +enum SpillError { + Original(DataFusionError), + Copy(DataFusionError), +} + +impl From for SpillError { + fn from(err: DataFusionError) -> Self { + Self::Original(err) + } +} + +impl From<&mut SpillError> for DataFusionError { + fn from(err: &mut SpillError) -> Self { + match err { + SpillError::Original(inner) => { + let copy = Self::Execution(inner.to_string()); + let original = std::mem::replace(err, SpillError::Copy(copy)); + if let SpillError::Original(inner) = original { + inner + } else { + unreachable!() + } + } + SpillError::Copy(Self::Execution(message)) => Self::Execution(message.clone()), + _ => unreachable!(), + } + } +} + +impl From<&SpillState> for WriteStatus { + fn from(state: &SpillState) -> Self { + match state { + SpillState::Buffering { batches, .. } => Self { + finished: false, + data_location: DataLocation::Buffered { + batches: batches.clone().into(), + }, + error: None, + }, + SpillState::Spilling { + batches_written, .. 
+ } => Self { + finished: false, + data_location: DataLocation::Spilled { + batches_written: *batches_written, + }, + error: None, + }, + SpillState::Finished { + batches_written, + batches, + } => { + let data_location = if let Some(batches) = batches { + DataLocation::Buffered { + batches: batches.clone(), + } + } else { + DataLocation::Spilled { + batches_written: *batches_written, + } + }; + Self { + finished: true, + data_location, + error: None, + } + } + SpillState::Errored { error } => Self { + finished: true, + data_location: DataLocation::default(), // Doesn't matter. + error: Some(error.clone()), + }, + } + } +} + +impl SpillSender { + /// Write a batch to the spill. + /// + /// If there is room in the `memory_limit` then the batch is queued. + /// If `memory_limit` is first encountered then all queued batches, and this one, + /// will be written to disk as part of this call. + /// If we are already spilling then the batch will be written to disk as part of this + /// call. + pub async fn write(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> { + if let SpillState::Finished { .. } = self.state { + return Err(DataFusionError::Execution( + "Spill has already been finished".to_string(), + )); + } + + if let SpillState::Errored { .. } = &self.state { + return Err(DataFusionError::Execution( + "Spill has sent an error".to_string(), + )); + } + + let (writer, batches_written) = match &mut self.state { + SpillState::Buffering { + batches, + ref mut memory_accumulator, + } => { + memory_accumulator.record_batch(&batch); + + if memory_accumulator.total() > self.memory_limit { + let writer = + AsyncStreamWriter::open(self.path.clone(), self.schema.clone()).await?; + let batches_written = batches.len(); + for batch in batches.drain(..) 
{ + writer.write(batch).await?; + } + self.state = SpillState::Spilling { + writer, + batches_written, + }; + if let SpillState::Spilling { + writer, + batches_written, + } = &mut self.state + { + (writer, batches_written) + } else { + unreachable!() + } + } else { + batches.push(batch); + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + return Ok(()); + } + } + SpillState::Spilling { + writer, + batches_written, + } => (writer, batches_written), + _ => unreachable!(), + }; + + writer.write(batch).await?; + *batches_written += 1; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + + Ok(()) + } + + /// Send an error to the spill. This will be sent to all readers of the + /// spill. + pub fn send_error(&mut self, err: DataFusionError) { + let error = Arc::new(Mutex::new(err.into())); + self.state = SpillState::Errored { error }; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + } + + /// Complete the spill write. This will finalize the Arrow IPC stream file. + /// The file will remain available for reading until [`Self::shutdown()`] + /// or until the spill is dropped. + pub async fn finish(&mut self) -> Result<(), DataFusionError> { + // We create a temporary state to get an owned copy of current state. + // Since we hold an exclusive reference to `self`, no one should be + // able to see this temporary state. + let tmp_state = SpillState::Finished { + batches_written: 0, + batches: None, + }; + match std::mem::replace(&mut self.state, tmp_state) { + SpillState::Buffering { batches, .. 
} => { + let batches_written = batches.len(); + self.state = SpillState::Finished { + batches_written, + batches: Some(batches.into()), + }; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + } + SpillState::Spilling { + writer, + batches_written, + } => { + writer.finish().await?; + self.state = SpillState::Finished { + batches_written, + batches: None, + }; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + } + SpillState::Finished { .. } => { + return Err(DataFusionError::Execution( + "Spill has already been finished".to_string(), + )); + } + SpillState::Errored { .. } => { + return Err(DataFusionError::Execution( + "Spill has sent an error".to_string(), + )); + } + }; + + Ok(()) + } +} + +/// An async wrapper around [`StreamWriter`]. Each call uses [`tokio::task::spawn_blocking`] +/// to spawn a blocking task to write the batch. +struct AsyncStreamWriter { + writer: Arc>>>, +} + +impl AsyncStreamWriter { + pub async fn open(path: PathBuf, schema: Arc) -> Result { + let writer = tokio::task::spawn_blocking(move || { + let file = std::fs::File::create(&path).map_err(ArrowError::from)?; + let writer = BufWriter::new(file); + StreamWriter::try_new(writer, &schema) + }) + .await + .unwrap()?; + let writer = Arc::new(Mutex::new(writer)); + Ok(Self { writer }) + } + + pub async fn write(&self, batch: RecordBatch) -> Result<(), ArrowError> { + let writer = self.writer.clone(); + tokio::task::spawn_blocking(move || { + let mut writer = writer.lock().unwrap(); + writer.write(&batch)?; + writer.flush() + }) + .await + .unwrap() + } + + pub async fn finish(self) -> Result<(), ArrowError> { + let writer = self.writer.clone(); + tokio::task::spawn_blocking(move || { + let mut writer = writer.lock().unwrap(); + writer.finish() + }) + .await + .unwrap() + } +} + +struct AsyncStreamReader { + reader: Arc>>>, +} + +impl AsyncStreamReader { + pub async fn open(path: PathBuf) -> Result { + let reader = tokio::task::spawn_blocking(move || { 
+ let file = std::fs::File::open(&path).map_err(ArrowError::from)?; + let reader = BufReader::new(file); + StreamReader::try_new(reader, None) + }) + .await + .unwrap()?; + let reader = Arc::new(Mutex::new(reader)); + Ok(Self { reader }) + } + + pub async fn read(&self) -> Result, ArrowError> { + let reader = self.reader.clone(); + tokio::task::spawn_blocking(move || { + let mut reader = reader.lock().unwrap(); + reader.next() + }) + .await + .unwrap() + .transpose() + } +} + +#[cfg(test)] +mod tests { + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field}; + use futures::{poll, StreamExt, TryStreamExt}; + + use super::*; + + #[tokio::test] + async fn test_spill() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batches = [ + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![4, 5, 6]))], + ) + .unwrap(), + ]; + + // Create a stream + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), 0); + + // We can open a reader prior to writing any data. No batches will be ready. + let mut stream_before = receiver.read(); + let mut stream_before_next = stream_before.next(); + let poll_res = poll!(&mut stream_before_next); + assert!(poll_res.is_pending()); + + // If we write a batch, the existing reader can now receive it. + spill.write(batches[0].clone()).await.unwrap(); + let stream_before_batch1 = stream_before_next + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_before_batch1, &batches[0]); + let mut stream_before_next = stream_before.next(); + let poll_res = poll!(&mut stream_before_next); + assert!(poll_res.is_pending()); + + // We can also open a ready while the spill is being written to. 
We can + // retrieve batches written so far immediately. + let mut stream_during = receiver.read(); + let stream_during_batch1 = stream_during + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_during_batch1, &batches[0]); + let mut stream_during_next = stream_during.next(); + let poll_res = poll!(&mut stream_during_next); + assert!(poll_res.is_pending()); + + // Once we finish the spill, readers can get remaining batches and will + // reach the end of the stream. + spill.write(batches[1].clone()).await.unwrap(); + spill.finish().await.unwrap(); + + let stream_before_batch2 = stream_before_next + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_before_batch2, &batches[1]); + assert!(stream_before.next().await.is_none()); + + let stream_during_batch2 = stream_during_next + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_during_batch2, &batches[1]); + assert!(stream_during.next().await.is_none()); + + // Can also start a reader after finishing. 
+ let stream_after = receiver.read(); + let stream_after_batches = stream_after.try_collect::>().await.unwrap(); + assert_eq!(&stream_after_batches, &batches); + + std::fs::remove_file(path).unwrap(); + } + + #[tokio::test] + async fn test_spill_error() { + // Create a spill + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), 0); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + spill.write(batch.clone()).await.unwrap(); + + let mut stream = receiver.read(); + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + + spill.send_error(DataFusionError::ResourcesExhausted("🥱".into())); + let stream_error = stream + .next() + .await + .expect("Expected an error") + .expect_err("Expected an error"); + assert!(matches!( + stream_error, + DataFusionError::ResourcesExhausted(message) if message == "🥱" + )); + + // If we try to write after sending an error, it should return an error. + let err = spill.write(batch).await; + assert!(matches!( + err, + Err(DataFusionError::Execution(message)) if message == "Spill has sent an error" + )); + + // If we try to finish after sending an error, it should return an error. + let err = spill.finish().await; + assert!(matches!( + err, + Err(DataFusionError::Execution(message)) if message == "Spill has sent an error" + )); + + // If we try to read after sending an error, it should return an error. 
+ let mut stream = receiver.read(); + let stream_error = stream + .next() + .await + .expect("Expected an error") + .expect_err("Expected an error"); + assert!(matches!( + stream_error, + DataFusionError::Execution(message) if message.contains("🥱") + )); + + std::fs::remove_file(path).unwrap(); + } + + #[tokio::test] + async fn test_spill_buffered() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let memory_limit = 1024 * 1024; // 1 MiB + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit); + + // 0.5 MB batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))], + ) + .unwrap(); + spill.write(batch.clone()).await.unwrap(); + assert!(!std::fs::exists(&path).unwrap()); + + spill.finish().await.unwrap(); + assert!(!std::fs::exists(&path).unwrap()); + + let mut stream = receiver.read(); + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + + assert!(!std::fs::exists(&path).unwrap()); + } + + #[tokio::test] + async fn test_spill_buffered_transition() { + // Starts as buffered, then spills, then finished. 
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let memory_limit = 1024 * 1024; // 1 MiB + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit); + + // 0.7 MB batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1; (768 * 1024) / 4]))], + ) + .unwrap(); + spill.write(batch.clone()).await.unwrap(); + assert!(!std::fs::exists(&path).unwrap()); + + let mut stream = receiver.read(); + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + assert!(!std::fs::exists(&path).unwrap()); + + // 0.5 MB batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))], + ) + .unwrap(); + spill.write(batch.clone()).await.unwrap(); + assert!(std::fs::exists(&path).unwrap()); + + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + assert!(std::fs::exists(&path).unwrap()); + + spill.finish().await.unwrap(); + + assert!(stream.next().await.is_none()); + + std::fs::remove_file(path).unwrap(); + } +} diff --git a/rust/lance-datafusion/src/utils.rs b/rust/lance-datafusion/src/utils.rs index bd07ed0f70a..8d33c68c6ed 100644 --- a/rust/lance-datafusion/src/utils.rs +++ b/rust/lance-datafusion/src/utils.rs @@ -7,6 +7,7 @@ use arrow::ffi_stream::ArrowArrayStreamReader; use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow_schema::{ArrowError, SchemaRef}; use async_trait::async_trait; +use background_iterator::BackgroundIterator; use datafusion::{ execution::RecordBatchStream, physical_plan::{ @@ -16,21 +17,12 @@ use datafusion::{ }, }; use datafusion_common::DataFusionError; -use futures::{stream, Stream, 
StreamExt, TryFutureExt, TryStreamExt}; +use futures::{stream, StreamExt, TryStreamExt}; use lance_core::datatypes::Schema; use lance_core::Result; -use tokio::task::{spawn, spawn_blocking}; +use tokio::task::spawn; -fn background_iterator(iter: I) -> impl Stream -where - I::Item: Send, -{ - stream::unfold(iter, |mut iter| { - spawn_blocking(|| iter.next().map(|val| (val, iter))) - .unwrap_or_else(|err| panic!("{}", err)) - }) - .fuse() -} +pub mod background_iterator; /// A trait for [BatchRecord] iterators, readers and streams /// that can be converted to a concrete stream type [SendableRecordBatchStream]. @@ -151,7 +143,9 @@ pub fn reader_to_stream(batches: Box) -> SendableR let arrow_schema = batches.arrow_schema(); let stream = RecordBatchStreamAdapter::new( arrow_schema, - background_iterator(batches).map_err(DataFusionError::from), + BackgroundIterator::new(batches) + .fuse() + .map_err(DataFusionError::from), ); Box::pin(stream) } diff --git a/rust/lance-datafusion/src/utils/background_iterator.rs b/rust/lance-datafusion/src/utils/background_iterator.rs new file mode 100644 index 00000000000..d9f0458718e --- /dev/null +++ b/rust/lance-datafusion/src/utils/background_iterator.rs @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use futures::ready; +use futures::Stream; +use std::{ + future::Future, + panic, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::task::JoinHandle; + +/// Wrap an iterator as a stream that executes the iterator in a background +/// blocking thread. +/// +/// The size hint is preserved, but the stream is not fused. 
+#[pin_project::pin_project] +pub struct BackgroundIterator { + #[pin] + state: BackgroundIterState, +} + +impl BackgroundIterator { + pub fn new(iter: I) -> Self { + Self { + state: BackgroundIterState::Current { iter }, + } + } +} + +impl Stream for BackgroundIterator +where + I::Item: Send + 'static, +{ + type Item = I::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if let Some(mut iter) = this.state.as_mut().take_iter() { + this.state.set(BackgroundIterState::Running { + size_hint: iter.size_hint(), + task: tokio::task::spawn_blocking(move || { + let next = iter.next(); + next.map(|next| (iter, next)) + }), + }); + } + + let step = match this.state.as_mut().project_future() { + Some(task) => ready!(task.poll(cx)), + None => panic!( + "BackgroundIterator must not be polled after it returned `Poll::Ready(None)`" + ), + }; + + match step { + Ok(Some((iter, next))) => { + this.state.set(BackgroundIterState::Current { iter }); + Poll::Ready(Some(next)) + } + Ok(None) => { + this.state.set(BackgroundIterState::Empty); + Poll::Ready(None) + } + Err(err) => { + if err.is_panic() { + // Resume the panic on the main task + panic::resume_unwind(err.into_panic()); + } else { + panic!("Background task failed: {:?}", err); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.state { + BackgroundIterState::Current { iter } => iter.size_hint(), + BackgroundIterState::Running { size_hint, .. 
} => *size_hint, + BackgroundIterState::Empty => (0, Some(0)), + } + } +} + +// Inspired by Unfold implementation: https://github.com/rust-lang/futures-rs/blob/master/futures-util/src/unfold_state.rs#L22 +#[pin_project::pin_project(project = StateProj, project_replace = StateReplace)] +enum BackgroundIterState { + Current { + iter: I, + }, + Running { + size_hint: (usize, Option), + #[pin] + task: NextHandle, + }, + Empty, +} + +type NextHandle = JoinHandle>; + +impl BackgroundIterState { + fn project_future(self: Pin<&mut Self>) -> Option>> { + match self.project() { + StateProj::Running { task, .. } => Some(task), + _ => None, + } + } + + fn take_iter(self: Pin<&mut Self>) -> Option { + match &*self { + Self::Current { .. } => match self.project_replace(Self::Empty) { + StateReplace::Current { iter } => Some(iter), + _ => None, + }, + _ => None, + } + } +} diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 6504ae607c5..924a673032f 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -28,6 +28,7 @@ lance-table = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ipc = { workspace = true } arrow-ord = { workspace = true } arrow-row = { workspace = true } arrow-schema = { workspace = true } @@ -61,6 +62,7 @@ datafusion.workspace = true datafusion-functions.workspace = true datafusion-physical-expr.workspace = true datafusion-expr.workspace = true +either.workspace = true lapack = { version = "0.19.0", optional = true } snafu = { workspace = true } log = { workspace = true } diff --git a/rust/lance/src/dataset/transaction.rs b/rust/lance/src/dataset/transaction.rs index 948e2157a6d..7aab0cec806 100644 --- a/rust/lance/src/dataset/transaction.rs +++ b/rust/lance/src/dataset/transaction.rs @@ -199,6 +199,25 @@ pub enum Operation { }, } +impl std::fmt::Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + 
Self::Append { .. } => write!(f, "Append"), + Self::Delete { .. } => write!(f, "Delete"), + Self::Overwrite { .. } => write!(f, "Overwrite"), + Self::CreateIndex { .. } => write!(f, "CreateIndex"), + Self::Rewrite { .. } => write!(f, "Rewrite"), + Self::Merge { .. } => write!(f, "Merge"), + Self::Restore { .. } => write!(f, "Restore"), + Self::ReserveFragments { .. } => write!(f, "ReserveFragments"), + Self::Update { .. } => write!(f, "Update"), + Self::Project { .. } => write!(f, "Project"), + Self::UpdateConfig { .. } => write!(f, "UpdateConfig"), + Self::DataReplacement { .. } => write!(f, "DataReplacement"), + } + } +} + #[derive(Debug, Clone)] pub struct RewrittenIndex { pub old_id: Uuid, @@ -366,6 +385,24 @@ impl Operation { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ConflictResult { + /// The operation is compatible with the other operation + /// + /// For example, two operations that modify different fragments are compatible. + Compatible, + /// The operation is not compatible with the other operation + /// + /// For example, an Overwrite with a change in schema and an Append are + /// not compatible. + NotCompatible, + /// The operation is not compatible, but the current operation can be + /// retried on top of the others changes. + /// + /// For example, two operations that modify the same fragment. + Retryable, +} + impl Transaction { pub fn new( read_version: u64, @@ -385,91 +422,144 @@ impl Transaction { /// Returns true if the transaction cannot be committed if the other /// transaction is committed first. - pub fn conflicts_with(&self, other: &Self) -> bool { + pub fn conflicts_with(&self, other: &Self) -> ConflictResult { + use ConflictResult::*; // This assumes IsolationLevel is Snapshot Isolation, which is more // permissive than Serializable. In particular, it allows a Delete // transaction to succeed after a concurrent Append, even if the Append // added rows that would be deleted. match &self.operation { Operation::Append { .. 
} => match &other.operation { - // Append is compatible with anything that doesn't change the schema - Operation::Append { .. } => false, - Operation::Rewrite { .. } => false, - Operation::CreateIndex { .. } => false, - Operation::Delete { .. } | Operation::Update { .. } => false, - Operation::ReserveFragments { .. } => false, - Operation::Project { .. } => false, - Operation::UpdateConfig { .. } => false, - Operation::DataReplacement { .. } => false, - _ => true, + Operation::Append { .. } + | Operation::Rewrite { .. } + | Operation::CreateIndex { .. } + | Operation::Delete { .. } + | Operation::Update { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } + | Operation::Merge { .. } + | Operation::UpdateConfig { .. } + | Operation::DataReplacement { .. } => Compatible, + // Append is not compatible with any operation that completely + // overwrites the schema. + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, Operation::Rewrite { .. } => match &other.operation { // Rewrite is only compatible with operations that don't touch - // existing fragments. - // TODO: it could also be compatible with operations that update - // fragments we don't touch. - Operation::Append { .. } => false, - Operation::ReserveFragments { .. } => false, + // existing fragments or update fragments we don't touch. + Operation::Append { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } + | Operation::UpdateConfig { .. } => Compatible, Operation::Delete { .. } | Operation::Rewrite { .. } | Operation::Update { .. } => { // As long as they rewrite disjoint fragments they shouldn't conflict. - self.operation.modifies_same_ids(&other.operation) + if self.operation.modifies_same_ids(&other.operation) { + Retryable + } else { + Compatible + } } - Operation::Project { .. } => false, - Operation::UpdateConfig { .. } => false, - Operation::DataReplacement { .. } => { + Operation::DataReplacement { .. } | Operation::Merge { .. 
} => { // TODO(rmeng): check that the fragments being replaced are not part of the groups - true + Retryable } - _ => true, + Operation::CreateIndex { new_indices, .. } => { + let mut affected_ids = HashSet::new(); + for index in new_indices { + if let Some(frag_bitmap) = &index.fragment_bitmap { + affected_ids.extend(frag_bitmap.iter()); + } else { + return Retryable; + } + } + if self + .operation + .modified_fragment_ids() + .any(|id| affected_ids.contains(&(id as u32))) + { + Retryable + } else { + Compatible + } + } + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, // Restore always succeeds - Operation::Restore { .. } => false, + Operation::Restore { .. } => Compatible, // ReserveFragments is compatible with anything that doesn't reset the // max fragment id. - Operation::ReserveFragments { .. } => matches!( - &other.operation, - Operation::Overwrite { .. } | Operation::Restore { .. } - ), - Operation::CreateIndex { .. } => match &other.operation { - Operation::Append { .. } => false, + Operation::ReserveFragments { .. } => match &other.operation { + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, + _ => Compatible, + }, + Operation::CreateIndex { new_indices, .. } => match &other.operation { + Operation::Append { .. } => Compatible, // Indices are identified by UUIDs, so they shouldn't conflict. - Operation::CreateIndex { .. } => false, + Operation::CreateIndex { .. } => Compatible, // Although some of the rows we indexed may have been deleted / moved, // row ids are still valid, so we allow this optimistically. - Operation::Delete { .. } | Operation::Update { .. } => false, - // Merge & reserve don't change row ids, so this should be fine. - Operation::Merge { .. } => false, - Operation::ReserveFragments { .. } => false, - // Rewrite likely changed many of the row ids, so our index is - // likely useless. It should be rebuilt. 
- // TODO: we could be smarter here and only invalidate the index - // if the rewrite changed more than X% of row ids. - Operation::Rewrite { .. } => true, - Operation::UpdateConfig { .. } => false, + Operation::Delete { .. } | Operation::Update { .. } => Compatible, + // Merge, reserve, and project don't change row ids, so this should be fine. + Operation::Merge { .. } => Compatible, + Operation::ReserveFragments { .. } => Compatible, + Operation::Project { .. } => Compatible, + // Should be compatible with rewrite if it didn't move the rows + // we indexed. If it did, we could retry. + // TODO: this will change with stable row ids. + Operation::Rewrite { .. } => { + let mut affected_ids = HashSet::new(); + for index in new_indices { + if let Some(frag_bitmap) = &index.fragment_bitmap { + affected_ids.extend(frag_bitmap.iter()); + } else { + return Retryable; + } + } + if other + .operation + .modified_fragment_ids() + .any(|id| affected_ids.contains(&(id as u32))) + { + Retryable + } else { + Compatible + } + } + Operation::UpdateConfig { .. } => Compatible, Operation::DataReplacement { .. } => { // TODO(rmeng): check that the new indices isn't on the column being replaced - true + Retryable } - _ => true, + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, Operation::Delete { .. } | Operation::Update { .. } => match &other.operation { - Operation::CreateIndex { .. } => false, - Operation::ReserveFragments { .. } => false, - Operation::Delete { .. } | Operation::Rewrite { .. } | Operation::Update { .. } => { + Operation::CreateIndex { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } + | Operation::Append { .. } + | Operation::UpdateConfig { .. } => Compatible, + Operation::Delete { .. } + | Operation::Rewrite { .. } + | Operation::Update { .. } + | Operation::DataReplacement { .. } => { // If we update the same fragments, we conflict. 
- self.operation.modifies_same_ids(&other.operation) + if self.operation.modifies_same_ids(&other.operation) { + Retryable + } else { + Compatible + } } - Operation::Project { .. } => false, - Operation::Append { .. } => false, - Operation::UpdateConfig { .. } => false, - _ => true, + Operation::Merge { .. } => Retryable, + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, Operation::Overwrite { .. } => match &other.operation { // Overwrite only conflicts with another operation modifying the same update config - Operation::Overwrite { .. } | Operation::UpdateConfig { .. } => { - self.operation.upsert_key_conflict(&other.operation) + Operation::Overwrite { .. } | Operation::UpdateConfig { .. } + if self.operation.upsert_key_conflict(&other.operation) => + { + NotCompatible } - _ => false, + _ => Compatible, }, Operation::UpdateConfig { schema_metadata, @@ -479,52 +569,79 @@ impl Transaction { Operation::Overwrite { .. } => { // Updates to schema metadata or field metadata conflict with any kind // of overwrite. - if schema_metadata.is_some() || field_metadata.is_some() { - true + if schema_metadata.is_some() + || field_metadata.is_some() + || self.operation.upsert_key_conflict(&other.operation) + { + NotCompatible } else { - self.operation.upsert_key_conflict(&other.operation) + Compatible } } Operation::UpdateConfig { .. } => { - self.operation.upsert_key_conflict(&other.operation) + if self.operation.upsert_key_conflict(&other.operation) | self.operation.modifies_same_metadata(&other.operation) + { + NotCompatible + } else { + Compatible + } } - _ => false, + _ => Compatible, }, // Merge changes the schema, but preserves row ids, so the only operations // it's compatible with is CreateIndex, ReserveFragments, SetMetadata and DeleteMetadata. - Operation::Merge { .. } => !matches!( - &other.operation, + Operation::Merge { .. } => match &other.operation { Operation::CreateIndex { .. } - | Operation::ReserveFragments { .. 
} - | Operation::UpdateConfig { .. } - ), + | Operation::ReserveFragments { .. } + | Operation::UpdateConfig { .. } => Compatible, + Operation::Update { .. } + | Operation::Append { .. } + | Operation::Delete { .. } + | Operation::Rewrite { .. } + | Operation::Merge { .. } + | Operation::DataReplacement { .. } => Retryable, + Operation::Overwrite { .. } + | Operation::Restore { .. } + | Operation::Project { .. } => NotCompatible, + }, Operation::Project { .. } => match &other.operation { // Project is compatible with anything that doesn't change the schema - Operation::CreateIndex { .. } => false, - Operation::Overwrite { .. } => false, - Operation::UpdateConfig { .. } => false, - _ => true, + Operation::Append { .. } + | Operation::Update { .. } + | Operation::Delete { .. } + | Operation::UpdateConfig { .. } + | Operation::CreateIndex { .. } + | Operation::DataReplacement { .. } + | Operation::Rewrite { .. } + | Operation::ReserveFragments { .. } => Compatible, + Operation::Merge { .. } | Operation::Project { .. } => { + // Need to recompute the schema + Retryable + } + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, Operation::DataReplacement { .. } => match &other.operation { Operation::Append { .. } | Operation::Delete { .. } | Operation::Update { .. } | Operation::Merge { .. } - | Operation::UpdateConfig { .. } => false, + | Operation::UpdateConfig { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } => Compatible, Operation::CreateIndex { .. } => { // TODO(rmeng): check that the new indices isn't on the column being replaced - true + NotCompatible } Operation::Rewrite { .. } => { // TODO(rmeng): check that the fragments being replaced are not part of the groups - true + NotCompatible } Operation::DataReplacement { .. } => { // TODO(rmeng): check cell conflicts - true + NotCompatible } - _ => true, + Operation::Overwrite { .. } | Operation::Restore { .. 
} => NotCompatible, }, } } @@ -1735,6 +1852,8 @@ mod tests { #[test] fn test_conflicts() { + use ConflictResult::*; + let index0 = Index { uuid: uuid::Uuid::new_v4(), name: "test".to_string(), @@ -1813,7 +1932,17 @@ mod tests { Operation::Append { fragments: vec![fragment0.clone()], }, - [false, false, false, true, true, false, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( Operation::Delete { @@ -1822,7 +1951,17 @@ mod tests { deleted_fragment_ids: vec![], predicate: "x > 2".to_string(), }, - [false, false, false, true, true, false, false, true, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Retryable, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::Delete { @@ -1831,7 +1970,17 @@ mod tests { deleted_fragment_ids: vec![], predicate: "x > 2".to_string(), }, - [false, false, true, true, true, true, false, true, false], + [ + Compatible, // append + Compatible, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::Overwrite { @@ -1841,9 +1990,7 @@ mod tests { }, // No conflicts: overwrite can always happen since it doesn't // depend on previous state of the table. - [ - false, false, false, false, false, false, false, false, false, - ], + [Compatible; 9], ), ( Operation::CreateIndex { @@ -1851,7 +1998,17 @@ mod tests { removed_indices: vec![index0], }, // Will only conflict with operations that modify row ids. 
- [false, false, false, false, true, true, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( // Rewrite that affects different fragments @@ -1862,7 +2019,17 @@ mod tests { }], rewritten_indices: Vec::new(), }, - [false, true, false, true, true, false, false, true, false], + [ + Compatible, // append + Retryable, // create index + Compatible, // delete + Retryable, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( // Rewrite that affects the same fragments @@ -1873,7 +2040,17 @@ mod tests { }], rewritten_indices: Vec::new(), }, - [false, true, true, true, true, true, false, true, false], + [ + Compatible, // append + Retryable, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::Merge { @@ -1881,12 +2058,32 @@ mod tests { schema: Schema::default(), }, // Merge conflicts with everything except CreateIndex and ReserveFragments. - [true, false, true, true, true, true, false, true, false], + [ + Retryable, // append + Compatible, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::ReserveFragments { num_fragments: 2 }, // ReserveFragments only conflicts with Overwrite and Restore. 
- [false, false, false, false, true, false, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( Operation::Update { @@ -1895,7 +2092,17 @@ mod tests { removed_fragment_ids: vec![], new_fragments: vec![fragment2], }, - [false, false, true, true, true, true, false, true, false], + [ + Compatible, // append + Compatible, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( // Update config that should not conflict with anything @@ -1908,9 +2115,7 @@ mod tests { schema_metadata: None, field_metadata: None, }, - [ - false, false, false, false, false, false, false, false, false, - ], + [Compatible; 9], ), ( // Update config that conflicts with key being upserted by other UpdateConfig operation @@ -1923,7 +2128,17 @@ mod tests { schema_metadata: None, field_metadata: None, }, - [false, false, false, false, false, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Compatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Update config that conflicts with key being deleted by other UpdateConfig operation @@ -1936,7 +2151,17 @@ mod tests { schema_metadata: None, field_metadata: None, }, - [false, false, false, false, false, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Compatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Delete config keys currently being deleted 
by other UpdateConfig operation @@ -1946,9 +2171,7 @@ mod tests { schema_metadata: None, field_metadata: None, }, - [ - false, false, false, false, false, false, false, false, false, - ], + [Compatible; 9], ), ( // Delete config keys currently being upserted by other UpdateConfig operation @@ -1958,7 +2181,17 @@ mod tests { schema_metadata: None, field_metadata: None, }, - [false, false, false, false, false, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Compatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Changing schema metadata conflicts with another update changing schema @@ -1972,7 +2205,17 @@ mod tests { )])), field_metadata: None, }, - [false, false, false, false, true, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Changing field metadata conflicts with another update changing same field @@ -1989,7 +2232,17 @@ mod tests { )]), )])), }, - [false, false, false, false, true, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Updates to different field metadata are allowed @@ -2005,7 +2258,17 @@ mod tests { )]), )])), }, - [false, false, false, false, true, false, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + 
], ), ]; @@ -2015,13 +2278,9 @@ mod tests { assert_eq!( transaction.conflicts_with(other), *expected_conflict, - "Transaction {:?} should {} with {:?}", + "Transaction {:?} should {:?} with {:?}", transaction, - if *expected_conflict { - "conflict" - } else { - "not conflict" - }, + expected_conflict, other ); } diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index 57e85f7635f..5f90691c898 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -5,14 +5,17 @@ use std::sync::Arc; use arrow_array::RecordBatch; use chrono::TimeDelta; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::SendableRecordBatchStream; -use futures::{StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use lance_core::datatypes::{ NullabilityComparison, OnMissing, OnTypeMismatch, SchemaCompareOptions, StorageClass, }; +use lance_core::error::LanceOptionExt; use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_DATA, TRACE_FILE_AUDIT}; use lance_core::{datatypes::Schema, Error, Result}; use lance_datafusion::chunker::{break_stream, chunk_stream}; +use lance_datafusion::spill::{create_replay_spill, SpillReceiver, SpillSender}; use lance_datafusion::utils::StreamingWriteSource; use lance_file::v2; use lance_file::v2::writer::FileWriterOptions; @@ -644,6 +647,112 @@ async fn resolve_commit_handler( } } +/// Create an iterator of record batch streams from the given source. +/// +/// If `enable_retries` is true, then the source will be saved either in memory +/// or spilled to disk to allow replaying the source in case of a failure. The +/// source will be kept in memory if either (1) the size hint shows that +/// there is only one batch or (2) the stream contains less than 100MB of +/// data. Otherwise, the source will be spilled to a temporary file on disk. +/// +/// This is used to support retries on write operations. 
+async fn new_source_iter( + source: SendableRecordBatchStream, + enable_retries: bool, +) -> Result + Send + 'static>> { + if enable_retries { + let schema = source.schema(); + + // If size hint shows there is only one batch, spilling has no benefit, just keep that + // in memory. (This is a pretty common case.) + let size_hint = source.size_hint(); + if size_hint.0 == 1 && size_hint.1 == Some(1) { + let batches: Vec = source.try_collect().await?; + Ok(Box::new(std::iter::repeat_with(move || { + Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::iter(batches.clone().into_iter().map(Ok)), + )) as SendableRecordBatchStream + }))) + } else { + // Allow buffering up to 100MB in memory before spilling to disk. + Ok(Box::new( + SpillStreamIter::try_new(source, 100 * 1024 * 1024).await?, + )) + } + } else { + Ok(Box::new(std::iter::once(source))) + } +} + +struct SpillStreamIter { + receiver: SpillReceiver, + #[allow(dead_code)] // Exists to keep the SpillSender alive + sender_handle: tokio::task::JoinHandle, + // This temp dir is used to store the spilled data. It is kept alive by + // this struct. When this struct is dropped, the Drop implementation of + // tempfile::TempDir will delete the temp dir. 
+ #[allow(dead_code)] // Exists to keep the temp dir alive + tmp_dir: tempfile::TempDir, +} + +impl SpillStreamIter { + pub async fn try_new( + mut source: SendableRecordBatchStream, + memory_limit: usize, + ) -> Result { + let tmp_dir = tokio::task::spawn_blocking(|| { + tempfile::tempdir().map_err(|e| Error::InvalidInput { + source: format!("Failed to create temp dir: {}", e).into(), + location: location!(), + }) + }) + .await + .ok() + .expect_ok()??; + + let tmp_path = tmp_dir.path().join("spill.arrows"); + let (mut sender, receiver) = create_replay_spill(tmp_path, source.schema(), memory_limit); + + let sender_handle = tokio::task::spawn(async move { + while let Some(res) = source.next().await { + match res { + Ok(batch) => match sender.write(batch).await { + Ok(_) => {} + Err(e) => { + sender.send_error(e); + break; + } + }, + Err(e) => { + sender.send_error(e); + break; + } + } + } + + if let Err(err) = sender.finish().await { + sender.send_error(err); + } + sender + }); + + Ok(Self { + receiver, + tmp_dir, + sender_handle, + }) + } +} + +impl Iterator for SpillStreamIter { + type Item = SendableRecordBatchStream; + + fn next(&mut self) -> Option { + Some(self.receiver.read()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index 387e90e5fef..57ac706794f 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -58,7 +58,7 @@ use futures::{ use lance_core::{ datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions}, error::{box_error, InvalidInputSnafu}, - utils::{futures::Capacity, tokio::get_num_compute_intensive_cpus}, + utils::{backoff::Backoff, futures::Capacity, tokio::get_num_compute_intensive_cpus}, Error, Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD, }; use lance_datafusion::{ @@ -224,10 +224,12 @@ struct MergeInsertParams { insert_not_matched: bool, // Controls whether data that 
is not matched by the source is deleted or not delete_not_matched_by_source: WhenNotMatchedBySource, + conflict_retries: u32, } /// A MergeInsertJob inserts new rows, deletes old rows, and updates existing rows all as /// part of a single transaction. +#[derive(Clone)] pub struct MergeInsertJob { // The column to merge the new data into dataset: Arc, @@ -299,6 +301,7 @@ impl MergeInsertBuilder { when_matched: WhenMatched::DoNothing, insert_not_matched: true, delete_not_matched_by_source: WhenNotMatchedBySource::Keep, + conflict_retries: 10, }, }) } @@ -328,6 +331,18 @@ impl MergeInsertBuilder { self } + /// Set number of times to retry the operation if there is contention. + /// + /// If this is set > 0, then the operation will keep a copy of the input data + /// either in memory or on disk (depending on the size of the data) and will + /// retry the operation if there is contention. + /// + /// Default is 10. + pub fn conflict_retries(&mut self, retries: u32) -> &mut Self { + self.params.conflict_retries = retries; + self + } + /// Crate a merge insert job pub fn try_build(&mut self) -> Result { if !self.params.insert_not_matched @@ -993,13 +1008,38 @@ impl MergeInsertJob { /// This will take in the source, merge it with the existing target data, and insert new /// rows, update existing rows, and delete existing rows pub async fn execute( - self, + mut self, source: SendableRecordBatchStream, ) -> Result<(Arc, MergeStats)> { - let ds = self.dataset.clone(); - let (transaction, stats) = self.execute_uncommitted_impl(source).await?; - let dataset = CommitBuilder::new(ds).execute(transaction).await?; - Ok((Arc::new(dataset), stats)) + let mut source_iter = + super::new_source_iter(source, self.params.conflict_retries > 0).await?; + + let mut dataset_ref = self.dataset.clone(); + let max_retries = self.params.conflict_retries; + let mut backoff = Backoff::default(); + while backoff.attempt() <= max_retries { + let ds = dataset_ref.clone(); + let (transaction, stats) = 
self + .clone() + .execute_uncommitted_impl(source_iter.next().unwrap()) + .await?; + match CommitBuilder::new(ds).execute(transaction).await { + Ok(ds) => return Ok((Arc::new(ds), stats)), + Err(Error::RetryableCommitConflict { .. }) => { + tokio::time::sleep(backoff.next_backoff()).await; + let mut ds = dataset_ref.as_ref().clone(); + ds.checkout_latest().await?; + dataset_ref = Arc::new(ds); + self.dataset = dataset_ref.clone(); + continue; + } + Err(e) => return Err(e), + }; + } + Err(Error::TooMuchWriteContention { + message: format!("Attempted {} retries.", max_retries), + location: location!(), + }) } /// Execute the merge insert job without committing the changes. @@ -1442,12 +1482,14 @@ mod tests { }; use arrow_select::concat::concat_batches; use datafusion::common::Column; + use futures::future::try_join_all; use lance_datafusion::utils::reader_to_stream; use lance_datagen::{array, BatchCount, RowCount, Seed}; use lance_index::{scalar::ScalarIndexParams, IndexType}; use tempfile::tempdir; + use tokio::sync::{Barrier, Notify}; - use crate::dataset::{WriteMode, WriteParams}; + use crate::dataset::{builder::DatasetBuilder, InsertBuilder, WriteMode, WriteParams}; use super::*; @@ -2087,4 +2129,183 @@ mod tests { } } } + + #[tokio::test] + async fn test_merge_insert_concurrency() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("value", DataType::UInt32, false), + ])); + let num_rows = 10; + let initial_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..num_rows)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 0, + num_rows as usize, + ))), + ], + ) + .unwrap(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let mut dataset = InsertBuilder::new(test_uri) + .execute(vec![initial_data]) + .await + .unwrap(); + + // do 10 merge inserts in parallel. 
Each will open the dataset, signal + // they have opened, and then wait for a signal to proceed. Once the signal + // is received, they will do a merge insert and close the dataset. + + let barrier = Arc::new(Barrier::new(10)); + let mut handles = Vec::new(); + for i in 0..10 { + let uri_ref = test_uri.to_string(); + let schema_ref = schema.clone(); + let barrier_ref = barrier.clone(); + let handle = tokio::task::spawn(async move { + let dataset = DatasetBuilder::from_uri(&uri_ref).load().await.unwrap(); + let dataset = Arc::new(dataset); + + let new_data = RecordBatch::try_new( + schema_ref.clone(), + vec![ + Arc::new(UInt32Array::from(vec![i])), + Arc::new(UInt32Array::from(vec![1])), + ], + ) + .unwrap(); + let source = Box::new(RecordBatchIterator::new([Ok(new_data)], schema_ref.clone())); + + let job = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap(); + barrier_ref.wait().await; + + job.execute_reader(source).await.unwrap(); + }); + handles.push(handle); + } + + try_join_all(handles).await.unwrap(); + + dataset.checkout_latest().await.unwrap(); + let batches = dataset.scan().try_into_batch().await.unwrap(); + + let values = batches["value"].as_primitive::(); + assert!( + values.values().iter().all(|&v| v == 1), + "All values should be 1 after merge insert. 
Got: {:?}", + values + ); + } + + #[tokio::test] + async fn test_merge_insert_large_concurrent() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("value", DataType::UInt32, false), + ])); + let num_rows = 10; + let initial_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..num_rows)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 0, + num_rows as usize, + ))), + ], + ) + .unwrap(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = InsertBuilder::new(test_uri) + .execute(vec![initial_data]) + .await + .unwrap(); + let dataset = Arc::new(dataset); + + // Start one merge insert, but don't commit it yet. + let new_data1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1])), + Arc::new(UInt32Array::from(vec![1])), + ], + ) + .unwrap(); + let (transaction1, _stats) = + MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_uncommitted(RecordBatchIterator::new( + vec![Ok(new_data1)], + schema.clone(), + )) + .await + .unwrap(); + + // Setup a "large" merge insert, with many batches + let new_data2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..1000)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(2, 1000))), + ], + ) + .unwrap(); + let notify = Arc::new(Notify::new()); + let source = RecordBatchIterator::new( + (0..10) + .map(|i| { + let batch = new_data2.slice(i * 100, 100); + if i == 9 { + notify.notify_one(); + } + Ok(batch) + }) + .collect::>(), + schema.clone(), + ); + let dataset2 = DatasetBuilder::from_uri(test_uri).load().await.unwrap(); + let job = MergeInsertBuilder::try_new(Arc::new(dataset2), vec!["id".to_string()]) + .unwrap() + 
.when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_reader(source); + let task = tokio::task::spawn(job); + + // Right as the large merge insert has finished reading the last batch, + // we will commit the first merge insert. This should trigger a conflict, + // but we should resolve it automatically. + notify.notified().await; + let mut dataset = CommitBuilder::new(dataset) + .execute(transaction1) + .await + .unwrap(); + + task.await.unwrap().unwrap(); + dataset.checkout_latest().await.unwrap(); + + let batches = dataset.scan().try_into_batch().await.unwrap(); + let values = batches["value"].as_primitive::(); + assert!( + values.values().iter().all(|&v| v == 2), + "All values should be 2 after merge insert. Got: {:?}", + values + ); + } } diff --git a/rust/lance/src/io/commit.rs b/rust/lance/src/io/commit.rs index 649ae795f29..d96283ee80b 100644 --- a/rust/lance/src/io/commit.rs +++ b/rust/lance/src/io/commit.rs @@ -22,6 +22,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use lance_core::utils::backoff::Backoff; use lance_file::version::LanceFileVersion; use lance_index::metrics::NoOpMetricsCollector; use lance_table::format::{ @@ -44,7 +45,7 @@ use prost::Message; use super::ObjectStore; use crate::dataset::cleanup::auto_cleanup_hook; use crate::dataset::fragment::FileFragment; -use crate::dataset::transaction::{Operation, Transaction}; +use crate::dataset::transaction::{ConflictResult, Operation, Transaction}; use crate::dataset::{write_manifest_file, ManifestWriteConfig, BLOB_DIR}; use crate::index::DatasetIndexInternalExt; use crate::session::Session; @@ -127,7 +128,7 @@ fn check_transaction( other_version: u64, other_transaction: Option<&Transaction>, ) -> Result<()> { - if other_transaction.is_none() { + let Some(other_transaction) = other_transaction else { return Err(crate::Error::Internal { message: format!( "There was a conflicting transaction at version {}, \ @@ 
-136,23 +137,28 @@ fn check_transaction( ), location: location!(), }); - } + }; - if transaction.conflicts_with(other_transaction.as_ref().unwrap()) { - return Err(crate::Error::CommitConflict { - version: other_version, - source: format!( - "There was a concurrent commit that conflicts with this one and it \ - cannot be automatically resolved. Please rerun the operation off the latest version \ - of the table.\n Transaction: {:?}\n Conflicting Transaction: {:?}", - transaction, other_transaction - ) - .into(), - location: location!(), - }); + match transaction.conflicts_with(other_transaction) { + ConflictResult::Compatible => Ok(()), + ConflictResult::NotCompatible => { + Err(crate::Error::CommitConflict { + version: other_version, + source: format!( + "This {} transaction is incompatible with concurrent transaction {} at version {}.", + transaction.operation, other_transaction.operation, other_version).into(), + location: location!(), + }) + }, + ConflictResult::Retryable => { + Err(crate::Error::RetryableCommitConflict { + version: other_version, + source: format!( + "This {} transaction was preempted by concurrent transaction {} at version {}. 
Please retry.", + transaction.operation, other_transaction.operation, other_version).into(), + location: location!() }) + } } - - Ok(()) } #[allow(clippy::too_many_arguments)] @@ -579,7 +585,8 @@ pub(crate) async fn do_commit_detached_transaction( let transaction_file = write_transaction_file(object_store, &dataset.base, transaction).await?; // We still do a loop since we may have conflicts in the random version we pick - for attempt_i in 0..commit_config.num_retries { + let mut backoff = Backoff::default(); + while backoff.attempt() < commit_config.num_retries { // Pick a random u64 with the highest bit set to indicate it is detached let random_version = thread_rng().gen::() | DETACHED_VERSION_MASK; @@ -637,9 +644,7 @@ pub(crate) async fn do_commit_detached_transaction( Err(CommitError::CommitConflict) => { // We pick a random u64 for the version, so it's possible (though extremely unlikely) // that we have a conflict. In that case, we just try again. - - let backoff_time = backoff_time(attempt_i); - tokio::time::sleep(backoff_time).await; + tokio::time::sleep(backoff.next_backoff()).await; } Err(CommitError::OtherError(err)) => { // If other error, return @@ -746,7 +751,8 @@ pub(crate) async fn commit_transaction( dataset.checkout_latest().await?; } let num_attempts = std::cmp::max(commit_config.num_retries, 1); - for attempt_i in 0..num_attempts { + let mut backoff = Backoff::default(); + while backoff.attempt() < num_attempts { // See if we can retry the commit. Try to account for all // transactions that have been committed since the read_version. // Use small amount of backoff to handle transactions that all @@ -759,7 +765,7 @@ pub(crate) async fn commit_transaction( .buffer_unordered(dataset.object_store().io_parallelism()) .take_while(|res| { futures::future::ready( - attempt_i > 0 + backoff.attempt() > 0 || !matches!( res, Err(crate::Error::NotFound { .. 
}) @@ -848,9 +854,9 @@ pub(crate) async fn commit_transaction( return Ok((manifest, manifest_location.path, manifest_location.e_tag)); } Err(CommitError::CommitConflict) => { - let next_attempt_i = attempt_i + 1; + let next_attempt_i = backoff.attempt() + 1; if next_attempt_i < num_attempts { - tokio::time::sleep(backoff_time(next_attempt_i)).await; + tokio::time::sleep(backoff.next_backoff()).await; dataset.checkout_latest().await?; } } @@ -872,18 +878,6 @@ pub(crate) async fn commit_transaction( }) } -fn backoff_time(attempt_i: u32) -> std::time::Duration { - // Exponential base: - // 100ms, 200ms, 400ms, 800ms, 1600ms, 3200ms, 6400ms - let backoff = 2_i32.pow(attempt_i) * 100; - // With +-100ms jitter - let jitter = rand::thread_rng().gen_range(-100..100); - let backoff = backoff + jitter; - // No more than 5 seconds and less than 10ms. - let backoff = backoff.clamp(10, 5_000) as u64; - std::time::Duration::from_millis(backoff) -} - #[cfg(test)] mod tests { use std::sync::Mutex;