diff --git a/lib/src/file_util.rs b/lib/src/file_util.rs new file mode 100644 index 0000000000..c52385cc35 --- /dev/null +++ b/lib/src/file_util.rs @@ -0,0 +1,73 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fs::File; +use std::path::Path; + +use tempfile::{NamedTempFile, PersistError}; + +// Like NamedTempFile::persist(), but also succeeds if the target already +// exists. +pub fn persist_temp_file>( + temp_file: NamedTempFile, + new_path: P, +) -> Result { + match temp_file.persist(&new_path) { + Ok(file) => Ok(file), + Err(PersistError { error, file }) => { + if let Ok(existing_file) = File::open(new_path) { + Ok(existing_file) + } else { + Err(PersistError { error, file }) + } + } + } +} + +#[cfg(test)] +mod tests { + + use std::env::temp_dir; + use std::io::Write; + + use test_case::test_case; + + use super::*; + + #[test] + fn test_persist_no_existing_file() { + let temp_dir = temp_dir(); + let target = temp_dir.join("file"); + let mut temp_file = NamedTempFile::new_in(&temp_dir).unwrap(); + temp_file.write_all(b"contents").unwrap(); + assert!(persist_temp_file(temp_file, &target).is_ok()); + } + + #[test_case(false ; "existing file open")] + #[test_case(true ; "existing file closed")] + fn test_persist_target_exists(existing_file_closed: bool) { + let temp_dir = temp_dir(); + let target = temp_dir.join("file"); + let mut temp_file = NamedTempFile::new_in(&temp_dir).unwrap(); + temp_file.write_all(b"contents").unwrap(); + + let mut file = File::create(&target).unwrap(); + file.write_all(b"contents").unwrap(); + if existing_file_closed { + drop(file); + } + + assert!(persist_temp_file(temp_file, &target).is_ok()); + } +} diff --git a/lib/src/index.rs b/lib/src/index.rs index a7b46758ff..bb6d492d8e 100644 --- a/lib/src/index.rs +++ b/lib/src/index.rs @@ -31,6 +31,7 @@ use itertools::Itertools; use tempfile::NamedTempFile; use crate::commit::Commit; +use crate::file_util::persist_temp_file; use crate::store::{ChangeId, CommitId}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] @@ -617,7 +618,7 @@ impl MutableIndex { let mut temp_file = NamedTempFile::new_in(&dir)?; let file = temp_file.as_file_mut(); file.write_all(&buf).unwrap(); - temp_file.persist(&index_file_path)?; + persist_temp_file(temp_file, &index_file_path)?; let mut cursor = Cursor::new(&buf); ReadonlyIndex::load_from(&mut cursor, dir, index_file_id_hex, hash_length) diff --git a/lib/src/index_store.rs b/lib/src/index_store.rs index 11c9a9beee..f3f375eada 100644 --- a/lib/src/index_store.rs +++ b/lib/src/index_store.rs @@ -24,6 +24,7 @@ use tempfile::NamedTempFile; use crate::commit::Commit; use crate::dag_walk; +use crate::file_util::persist_temp_file; use crate::index::{MutableIndex, ReadonlyIndex}; use crate::op_store::OperationId; use crate::operation::Operation; @@ -151,7 +152,7 @@ impl IndexStore { let mut temp_file = NamedTempFile::new_in(&self.dir)?; let file = temp_file.as_file_mut(); file.write_all(&index.name().as_bytes()).unwrap(); - temp_file.persist(&self.dir.join("operations").join(op_id.hex()))?; + persist_temp_file(temp_file, &self.dir.join("operations").join(op_id.hex()))?; Ok(()) } } diff --git a/lib/src/lib.rs b/lib/src/lib.rs index e32f1d5844..dc588f1216 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -30,6 +30,7 @@ pub mod conflicts; pub mod dag_walk; pub mod diff; pub mod evolution; +pub mod file_util; pub mod files; pub mod git; pub mod git_store; diff --git a/lib/src/local_store.rs b/lib/src/local_store.rs index c95a9cda89..e05acf6ce5 100644 --- a/lib/src/local_store.rs +++ b/lib/src/local_store.rs @@ -22,6 +22,7 @@ use blake2::{Blake2b, Digest}; use protobuf::{Message, ProtobufError}; use tempfile::{NamedTempFile, PersistError}; +use crate::file_util::persist_temp_file; use crate::repo_path::{RepoPath, RepoPathComponent}; use crate::store::{ ChangeId, Commit, CommitId, Conflict, ConflictId, ConflictPart, FileId, MillisSinceEpoch, @@ -140,7 +141,7 @@ impl Store for LocalStore { encoder.finish()?; let id = FileId(hasher.finalize().to_vec()); - temp_file.persist(self.file_path(&id))?; + persist_temp_file(temp_file, self.file_path(&id))?; Ok(id) } @@ -159,7 +160,7 @@ impl Store for LocalStore { hasher.update(&target.as_bytes()); let id = SymlinkId(hasher.finalize().to_vec()); - temp_file.persist(self.symlink_path(&id))?; + persist_temp_file(temp_file, self.symlink_path(&id))?; Ok(id) } @@ -186,7 +187,7 @@ impl Store for LocalStore { let id = TreeId(Blake2b::digest(&proto_bytes).to_vec()); - temp_file.persist(self.tree_path(&id))?; + persist_temp_file(temp_file, self.tree_path(&id))?; Ok(id) } @@ -209,7 +210,7 @@ impl Store for LocalStore { let id = CommitId(Blake2b::digest(&proto_bytes).to_vec()); - temp_file.persist(self.commit_path(&id))?; + persist_temp_file(temp_file, self.commit_path(&id))?; Ok(id) } @@ -232,7 +233,7 @@ impl Store for LocalStore { let id = ConflictId(Blake2b::digest(&proto_bytes).to_vec()); - temp_file.persist(self.conflict_path(&id))?; + persist_temp_file(temp_file, self.conflict_path(&id))?; Ok(id) } } diff --git a/lib/src/simple_op_store.rs b/lib/src/simple_op_store.rs index 16418ca765..dda2f1c11f 100644 --- a/lib/src/simple_op_store.rs +++ b/lib/src/simple_op_store.rs @@ -22,6 +22,7 @@ use blake2::{Blake2b, Digest}; use protobuf::{Message, ProtobufError}; use tempfile::{NamedTempFile, PersistError}; +use crate::file_util::persist_temp_file; use crate::op_store::{ OpStore, OpStoreError, OpStoreResult, Operation, OperationId, OperationMetadata, View, ViewId, }; @@ -98,7 +99,7 @@ impl OpStore for SimpleOpStore { let id = ViewId(Blake2b::digest(&proto_bytes).to_vec()); - temp_file.persist(self.view_path(&id))?; + persist_temp_file(temp_file, self.view_path(&id))?; Ok(id) } @@ -121,7 +122,7 @@ impl OpStore for SimpleOpStore { let id = OperationId(Blake2b::digest(&proto_bytes).to_vec()); - temp_file.persist(self.operation_path(&id))?; + persist_temp_file(temp_file, self.operation_path(&id))?; Ok(id) } } diff --git a/lib/src/working_copy.rs b/lib/src/working_copy.rs index bce6686e4d..1f1c2dd9d8 100644 --- a/lib/src/working_copy.rs +++ b/lib/src/working_copy.rs @@ -33,6 +33,7 @@ use thiserror::Error; use crate::commit::Commit; use crate::commit_builder::CommitBuilder; +use crate::file_util::persist_temp_file; use crate::gitignore::GitIgnoreFile; use crate::lock::FileLock; use crate::matchers::EverythingMatcher; @@ -237,9 +238,7 @@ impl TreeState { // there is no unknown data in it self.update_read_time(); proto.write_to_writer(temp_file.as_file_mut()).unwrap(); - temp_file - .persist(self.state_path.join("tree_state")) - .unwrap(); + persist_temp_file(temp_file, self.state_path.join("tree_state")).unwrap(); } fn file_state(&self, path: &Path) -> Option { @@ -643,7 +642,7 @@ impl WorkingCopy { fn write_proto(&self, proto: crate::protos::working_copy::Checkout) { let mut temp_file = NamedTempFile::new_in(&self.state_path).unwrap(); proto.write_to_writer(temp_file.as_file_mut()).unwrap(); - temp_file.persist(self.state_path.join("checkout")).unwrap(); + persist_temp_file(temp_file, self.state_path.join("checkout")).unwrap(); } fn read_proto(&self) -> crate::protos::working_copy::Checkout {