Skip to content

Commit

Permalink
[ENH] Make TestSysDb thread safe (#2010)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- This PR makes TestSysDb thread safe and improves unit tests to take
advantage of that.
 - New functionality
	 - ...

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Ishiihara authored Apr 12, 2024
1 parent 663a02d commit df29a41
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
7 changes: 0 additions & 7 deletions rust/worker/src/compactor/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ impl Scheduler {
}
}
}

filtered_collections
}

Expand Down Expand Up @@ -166,11 +165,6 @@ impl Scheduler {
pub(crate) fn set_memberlist(&mut self, memberlist: Memberlist) {
self.memberlist = Some(memberlist);
}

// For testing
pub(crate) fn set_sysdb(&mut self, sysdb: Box<dyn SysDb>) {
self.sysdb = sysdb;
}
}

#[cfg(test)]
Expand Down Expand Up @@ -303,7 +297,6 @@ mod tests {

let last_compaction_time_2 = 1;
sysdb.add_tenant_last_compaction_time(tenant_2, last_compaction_time_2);
scheduler.set_sysdb(sysdb.clone());
scheduler.schedule().await;
let jobs = scheduler.get_jobs();
let jobs = jobs.collect::<Vec<&CompactionJob>>();
Expand Down
21 changes: 20 additions & 1 deletion rust/worker/src/execution/operators/flush_sysdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ mod tests {
log_position,
collection_version,
segment_flush_info.into(),
sysdb,
sysdb.clone(),
);

let result = operator.run(&input).await;
Expand All @@ -196,5 +196,24 @@ mod tests {
let result = result.unwrap();
assert_eq!(result.result.collection_id, collection_uuid_1.to_string());
assert_eq!(result.result.collection_version, collection_version + 1);

let collections = sysdb
.get_collections(Some(collection_uuid_1), None, None, None)
.await;

assert!(collections.is_ok());
let collection = collections.unwrap();
assert_eq!(collection.len(), 1);
let collection = collection[0].clone();
assert_eq!(collection.log_position, log_position);

let segments = sysdb.get_segments(None, None, None, None).await;
assert!(segments.is_ok());
let segments = segments.unwrap();
assert_eq!(segments.len(), 2);
let segment_1 = segments.iter().find(|s| s.id == segment_id_1).unwrap();
assert_eq!(segment_1.file_path, file_path_3);
let segment_2 = segments.iter().find(|s| s.id == segment_id_2).unwrap();
assert_eq!(segment_2.file_path, file_path_4);
}
}
44 changes: 30 additions & 14 deletions rust/worker/src/sysdb/test_sysdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::types::SegmentScope;
use crate::types::SegmentType;
use crate::types::Tenant;
use async_trait::async_trait;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
Expand All @@ -18,6 +19,11 @@ use super::sysdb::GetLastCompactionTimeError;

#[derive(Clone, Debug)]
pub(crate) struct TestSysDb {
inner: Arc<Mutex<Inner>>,
}

#[derive(Debug)]
struct Inner {
collections: HashMap<Uuid, Collection>,
segments: HashMap<Uuid, Segment>,
tenant_last_compaction_time: HashMap<String, i64>,
Expand All @@ -26,26 +32,32 @@ pub(crate) struct TestSysDb {
impl TestSysDb {
pub(crate) fn new() -> Self {
TestSysDb {
collections: HashMap::new(),
segments: HashMap::new(),
tenant_last_compaction_time: HashMap::new(),
inner: Arc::new(Mutex::new(Inner {
collections: HashMap::new(),
segments: HashMap::new(),
tenant_last_compaction_time: HashMap::new(),
})),
}
}

pub(crate) fn add_collection(&mut self, collection: Collection) {
self.collections.insert(collection.id, collection);
let mut inner = self.inner.lock();
inner.collections.insert(collection.id, collection);
}

pub(crate) fn add_segment(&mut self, segment: Segment) {
self.segments.insert(segment.id, segment);
let mut inner = self.inner.lock();
inner.segments.insert(segment.id, segment);
}

pub(crate) fn add_tenant_last_compaction_time(
&mut self,
tenant: String,
last_compaction_time: i64,
) {
self.tenant_last_compaction_time
let mut inner = self.inner.lock();
inner
.tenant_last_compaction_time
.insert(tenant, last_compaction_time);
}

Expand Down Expand Up @@ -112,8 +124,9 @@ impl SysDb for TestSysDb {
tenant: Option<String>,
database: Option<String>,
) -> Result<Vec<Collection>, GetCollectionsError> {
let inner = self.inner.lock();
let mut collections = Vec::new();
for collection in self.collections.values() {
for collection in inner.collections.values() {
if !TestSysDb::filter_collections(
&collection,
collection_id,
Expand All @@ -135,8 +148,9 @@ impl SysDb for TestSysDb {
scope: Option<SegmentScope>,
collection: Option<Uuid>,
) -> Result<Vec<Segment>, GetSegmentsError> {
let inner = self.inner.lock();
let mut segments = Vec::new();
for segment in self.segments.values() {
for segment in inner.segments.values() {
if !TestSysDb::filter_segments(&segment, id, r#type.clone(), scope.clone(), collection)
{
continue;
Expand All @@ -150,9 +164,10 @@ impl SysDb for TestSysDb {
&mut self,
tenant_ids: Vec<String>,
) -> Result<Vec<Tenant>, GetLastCompactionTimeError> {
let inner = self.inner.lock();
let mut tenants = Vec::new();
for tenant_id in tenant_ids {
let last_compaction_time = match self.tenant_last_compaction_time.get(&tenant_id) {
let last_compaction_time = match inner.tenant_last_compaction_time.get(&tenant_id) {
Some(last_compaction_time) => *last_compaction_time,
None => {
// TODO: Log an error
Expand All @@ -175,7 +190,8 @@ impl SysDb for TestSysDb {
collection_version: i32,
segment_flush_info: Arc<[SegmentFlushInfo]>,
) -> Result<FlushCompactionResponse, FlushCompactionError> {
let collection = self
let mut inner = self.inner.lock();
let collection = inner
.collections
.get(&Uuid::parse_str(&collection_id).unwrap());
if collection.is_none() {
Expand All @@ -186,22 +202,22 @@ impl SysDb for TestSysDb {
collection.log_position = log_position;
let new_collection_version = collection_version + 1;
collection.version = new_collection_version;
self.collections.insert(collection.id, collection);
let mut last_compaction_time = match self.tenant_last_compaction_time.get(&tenant_id) {
inner.collections.insert(collection.id, collection);
let mut last_compaction_time = match inner.tenant_last_compaction_time.get(&tenant_id) {
Some(last_compaction_time) => *last_compaction_time,
None => 0,
};
last_compaction_time += 1;

// update segments
for segment_flush_info in segment_flush_info.iter() {
let segment = self.segments.get(&segment_flush_info.segment_id);
let segment = inner.segments.get(&segment_flush_info.segment_id);
if segment.is_none() {
return Err(FlushCompactionError::SegmentNotFound);
}
let mut segment = segment.unwrap().clone();
segment.file_path = segment_flush_info.file_paths.clone();
self.segments.insert(segment.id, segment);
inner.segments.insert(segment.id, segment);
}

Ok(FlushCompactionResponse::new(
Expand Down

0 comments on commit df29a41

Please sign in to comment.