Skip to content

Commit

Permalink
feat(cli): Reuse command implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoh committed Apr 30, 2024
1 parent 004cd9a commit 7f9e8ca
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 191 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ cargo install --git https://github.com/emmyoh/zebra --features="cli"

You should specify the features relevant to your use case. For example, if you're interested in using the Zebra CLI on an Apple silicon device, run:
```bash
cargo install --git https://github.com/emmyoh/zebra --features="cli,metal"
cargo install --git https://github.com/emmyoh/zebra --features="cli,accelerate,metal"
```

## Features
Expand Down
323 changes: 133 additions & 190 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
use clap::{command, Parser, Subcommand};
use fastembed::Embedding;
use fastembed::TextEmbedding;
use indicatif::HumanCount;
use indicatif::ProgressStyle;
use indicatif::{ProgressBar, ProgressDrawTarget};
use pretty_duration::pretty_duration;
use rodio::{Decoder, OutputStream, Sink};
use space::Metric;
use std::error::Error;
use std::fs::File;
use std::io::BufReader;
use std::io::Write;
use std::io::{stdout, BufWriter};
use std::path::PathBuf;
use ticky::Stopwatch;
use zebra::db::Database;
use zebra::db::DocumentType;
use zebra::distance::DistanceUnit;
use zebra::model::{AudioEmbeddingModel, DatabaseEmbeddingModel, ImageEmbeddingModel};

const INSERT_BATCH_SIZE: usize = 100;
Expand Down Expand Up @@ -112,154 +117,71 @@ enum AudioCommands {

fn main() -> Result<(), Box<dyn Error>> {
let cli = Cli::parse();
let mut sw = Stopwatch::start_new();
match cli.commands {
Commands::Text(text) => {
match text.text_commands {
TextCommands::Insert { mut texts } => {
let mut db = zebra::text::create_or_load_database()?;
let mut buffer = BufWriter::new(stdout().lock());
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Inserting {} text(s).", texts.len())?;
let insertion_results = db.insert_documents(&model, &mut texts)?;
sw.stop();
writeln!(
buffer,
"{} embeddings of {} dimensions inserted into the database in {}.",
insertion_results.0,
insertion_results.1,
pretty_duration(&sw.elapsed(), None)
)?;
}
TextCommands::InsertFromFiles { file_paths } => {
let mut db = zebra::text::create_or_load_database()?;
let num_texts = file_paths.len();
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
println!("Inserting texts from {} file(s).", num_texts);
let progress_bar = ProgressBar::with_draw_target(
Some(num_texts.try_into()?),
ProgressDrawTarget::hidden(),
);
// .with_message(format!("Inserting texts from {} file(s).", num_texts));
progress_bar.set_style(progress_bar_style()?);
let mut i = 0;
let mut texts = Vec::new();
// Insert texts in batches of INSERT_BATCH_SIZE.
for file_path in file_paths {
let text = std::fs::read_to_string(file_path)?;
texts.push(text);
if i == INSERT_BATCH_SIZE - 1 {
let insertion_results = db.insert_documents(&model, &mut texts)?;
progress_bar.println(format!(
"{} embeddings of {} dimensions inserted into the database.",
insertion_results.0, insertion_results.1
));
texts.clear();
i = 0;
} else {
i += 1;
}
progress_bar.inc(1);
if progress_bar.is_hidden() {
progress_bar.set_draw_target(ProgressDrawTarget::stderr_with_hz(100));
}
}
// Insert the remaining texts, if any.
if !texts.is_empty() {
let insertion_results = db.insert_documents(&model, &mut texts)?;
progress_bar.println(format!(
"{} embeddings of {} dimensions inserted into the database.",
insertion_results.0, insertion_results.1
));
}
sw.stop();
progress_bar.println(format!(
"Inserted {} text(s) in {}.",
num_texts,
pretty_duration(&sw.elapsed(), None)
));
}
TextCommands::Query {
texts,
number_of_results,
} => {
let mut db = zebra::text::create_or_load_database()?;
let mut buffer = BufWriter::new(stdout().lock());
let num_texts = texts.len();
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Querying {} text(s).", num_texts)?;
let query_results = db.query_documents(&model, texts, number_of_results)?;
let result_texts: Vec<String> = query_results
.iter()
.map(|x| String::from_utf8(x.to_vec()).unwrap())
.collect();
sw.stop();
writeln!(
buffer,
"Queried {} text(s) in {}.",
num_texts,
pretty_duration(&sw.elapsed(), None)
)?;
writeln!(buffer, "Results:")?;
for result in result_texts {
writeln!(buffer, "- \t{}", result)?;
}
}
TextCommands::Clear => {
let text_type = DocumentType::Text;
let mut buffer = BufWriter::new(stdout().lock());
writeln!(buffer, "Clearing database.")?;
std::fs::remove_file(text_type.database_name()).unwrap_or(());
std::fs::remove_dir_all(text_type.subdirectory_name()).unwrap_or(());
std::fs::remove_dir_all(".fastembed_cache").unwrap_or(());
sw.stop();
writeln!(
buffer,
"Database cleared in {}.",
pretty_duration(&sw.elapsed(), None)
)?;
Commands::Text(text) => match text.text_commands {
TextCommands::Insert { mut texts } => {
let mut sw = Stopwatch::start_new();
let mut db = zebra::text::create_or_load_database()?;
let mut buffer = BufWriter::new(stdout().lock());
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Inserting {} text(s).", texts.len())?;
let insertion_results = db.insert_documents(&model, &mut texts)?;
sw.stop();
writeln!(
buffer,
"{} embeddings of {} dimensions inserted into the database in {}.",
HumanCount(insertion_results.0.try_into()?).to_string(),
HumanCount(insertion_results.1.try_into()?).to_string(),
pretty_duration(&sw.elapsed(), None)
)?;
}
TextCommands::InsertFromFiles { file_paths } => {
let mut db = zebra::text::create_or_load_database()?;
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
insert_from_files(&mut db, model, file_paths)?;
}
TextCommands::Query {
texts,
number_of_results,
} => {
let mut sw = Stopwatch::start_new();
let mut db = zebra::text::create_or_load_database()?;
let mut buffer = BufWriter::new(stdout().lock());
let num_texts = texts.len();
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Querying {} text(s).", num_texts)?;
let query_results = db.query_documents(&model, texts, number_of_results)?;
let result_texts: Vec<String> = query_results
.iter()
.map(|x| String::from_utf8(x.to_vec()).unwrap())
.collect();
sw.stop();
writeln!(
buffer,
"Queried {} text(s) in {}.",
num_texts,
pretty_duration(&sw.elapsed(), None)
)?;
writeln!(buffer, "Results:")?;
for result in result_texts {
writeln!(buffer, "- \t{}", result)?;
}
}
}
TextCommands::Clear => {
clear_database(DocumentType::Text)?;
}
},
Commands::Image(image) => match image.image_commands {
ImageCommands::Insert { file_paths } => {
let mut db = zebra::image::create_or_load_database()?;
let num_images = file_paths.len();
let model: ImageEmbeddingModel = DatabaseEmbeddingModel::new()?;
println!("Inserting images from {} file(s).", num_images);
let progress_bar = ProgressBar::with_draw_target(
Some(num_images.try_into()?),
ProgressDrawTarget::hidden(),
);
// .with_message(format!("Inserting images from {} file(s).", num_images));
progress_bar.set_style(progress_bar_style()?);
let images: Vec<String> = file_paths
.into_iter()
.map(|x| x.to_str().unwrap().to_string())
.collect();
// Insert images in batches of INSERT_BATCH_SIZE.
for image_batch in images.chunks(INSERT_BATCH_SIZE) {
let insertion_results = db.insert_documents(&model, image_batch)?;
progress_bar.println(format!(
"{} embeddings of {} dimensions inserted into the database.",
insertion_results.0, insertion_results.1
));
progress_bar.inc(INSERT_BATCH_SIZE.try_into()?);
if progress_bar.is_hidden() {
progress_bar.set_draw_target(ProgressDrawTarget::stderr_with_hz(100));
}
}
sw.stop();
progress_bar.println(format!(
"Inserted {} image(s) in {}.",
num_images,
pretty_duration(&sw.elapsed(), None)
));
insert_from_files(&mut db, model, file_paths)?;
}
ImageCommands::Query {
image_path,
number_of_results,
} => {
let mut sw = Stopwatch::start_new();
let mut db = zebra::image::create_or_load_database()?;
let mut buffer = BufWriter::new(stdout().lock());
let image_print_config = viuer::Config {
Expand Down Expand Up @@ -296,59 +218,20 @@ fn main() -> Result<(), Box<dyn Error>> {
}
}
ImageCommands::Clear => {
let image_type = DocumentType::Image;
let mut buffer = BufWriter::new(stdout().lock());
writeln!(buffer, "Clearing database.")?;
std::fs::remove_file(image_type.database_name()).unwrap_or(());
std::fs::remove_dir_all(image_type.subdirectory_name()).unwrap_or(());
// std::fs::remove_dir_all(".fastembed_cache").unwrap_or(());
sw.stop();
writeln!(
buffer,
"Database cleared in {}.",
pretty_duration(&sw.elapsed(), None)
)?;
clear_database(DocumentType::Image)?;
}
},
Commands::Audio(audio) => match audio.audio_commands {
AudioCommands::Insert { file_paths } => {
let mut db = zebra::audio::create_or_load_database()?;
let num_sounds = file_paths.len();
let model: AudioEmbeddingModel = DatabaseEmbeddingModel::new()?;
println!("Inserting sounds from {} file(s).", num_sounds);
let progress_bar = ProgressBar::with_draw_target(
Some(num_sounds.try_into()?),
ProgressDrawTarget::hidden(),
);
// .with_message(format!("Inserting sounds from {} file(s).", num_sounds));
progress_bar.set_style(progress_bar_style()?);
let sounds: Vec<String> = file_paths
.into_iter()
.map(|x| x.to_str().unwrap().to_string())
.collect();
// Insert sounds in batches of INSERT_BATCH_SIZE.
for image_batch in sounds.chunks(INSERT_BATCH_SIZE) {
let insertion_results = db.insert_documents(&model, image_batch)?;
progress_bar.println(format!(
"{} embeddings of {} dimensions inserted into the database.",
insertion_results.0, insertion_results.1
));
progress_bar.inc(INSERT_BATCH_SIZE.try_into()?);
if progress_bar.is_hidden() {
progress_bar.set_draw_target(ProgressDrawTarget::stderr_with_hz(100));
}
}
sw.stop();
progress_bar.println(format!(
"Inserted {} sound(s) in {}.",
num_sounds,
pretty_duration(&sw.elapsed(), None)
));
insert_from_files(&mut db, model, file_paths)?;
}
AudioCommands::Query {
audio_path,
number_of_results,
} => {
let mut sw = Stopwatch::start_new();
let mut db = zebra::audio::create_or_load_database()?;
let (_stream, stream_handle) = OutputStream::try_default()?;
let sink = Sink::try_new(&stream_handle)?;
Expand Down Expand Up @@ -377,18 +260,7 @@ fn main() -> Result<(), Box<dyn Error>> {
}
}
AudioCommands::Clear => {
let audio_type = DocumentType::Audio;
let mut buffer = BufWriter::new(stdout().lock());
writeln!(buffer, "Clearing database.")?;
std::fs::remove_file(audio_type.database_name()).unwrap_or(());
std::fs::remove_dir_all(audio_type.subdirectory_name()).unwrap_or(());
// std::fs::remove_dir_all(".fastembed_cache").unwrap_or(());
sw.stop();
writeln!(
buffer,
"Database cleared in {}.",
pretty_duration(&sw.elapsed(), None)
)?;
clear_database(DocumentType::Audio)?;
}
},
// _ => unreachable!(),
Expand All @@ -399,3 +271,74 @@ fn main() -> Result<(), Box<dyn Error>> {
/// Builds the [`ProgressStyle`] applied to the insertion progress bar.
///
/// The template shows elapsed, remaining and total time, a cyan/blue wide bar,
/// the human-readable position and length, the percentage, and the message.
///
/// # Errors
///
/// Returns an error if `ProgressStyle::with_template` rejects the template.
fn progress_bar_style() -> Result<ProgressStyle, Box<dyn Error>> {
    let template = "[{elapsed} elapsed, {eta} remaining ({duration} total)] {wide_bar:.cyan/blue} {human_pos} of {human_len} ({percent}%) {msg}";
    let style = ProgressStyle::with_template(template)?;
    Ok(style)
}

/// Removes the on-disk database for `document_type` and reports how long it took.
///
/// Deletion is best-effort: the database file, the document subdirectory and,
/// for [`DocumentType::Text`] only, the `.fastembed_cache` directory are removed
/// if possible, and any error from a removal (for example, the path not
/// existing) is deliberately ignored.
///
/// # Errors
///
/// Returns an error only if writing to or flushing standard output fails.
fn clear_database(document_type: DocumentType) -> Result<(), Box<dyn Error>> {
    let mut sw = Stopwatch::start_new();
    let mut buffer = BufWriter::new(stdout().lock());
    writeln!(buffer, "Clearing database.")?;
    // Flush now so the message is visible before the (possibly slow) deletion starts.
    buffer.flush()?;
    // Best-effort removal: a missing file or directory is not a failure here.
    let _ = std::fs::remove_file(document_type.database_name());
    let _ = std::fs::remove_dir_all(document_type.subdirectory_name());
    if document_type == DocumentType::Text {
        let _ = std::fs::remove_dir_all(".fastembed_cache");
    }
    sw.stop();
    writeln!(
        buffer,
        "Database cleared in {}.",
        pretty_duration(&sw.elapsed(), None)
    )?;
    // Dropping a `BufWriter` swallows flush errors, so flush explicitly.
    buffer.flush()?;
    Ok(())
}

/// Inserts the documents referenced by `file_paths` into `db`, in batches of
/// `INSERT_BATCH_SIZE`, while reporting progress.
///
/// Each path is converted to a `String` and passed to `db.insert_documents`
/// together with `model`. After every batch, the number of embeddings, their
/// dimensionality and the batch duration are printed above the progress bar;
/// a summary with the total duration is printed at the end.
///
/// # Errors
///
/// Returns an error if a path is not valid UTF-8, if a count does not fit the
/// integer type expected by the progress bar or `HumanCount`, if the progress
/// bar template is invalid, or if `db.insert_documents` fails.
fn insert_from_files<
    Met: Metric<Embedding, Unit = DistanceUnit> + serde::ser::Serialize,
    const EF_CONSTRUCTION: usize,
    const M: usize,
    const M0: usize,
>(
    db: &mut Database<Met, EF_CONSTRUCTION, M, M0>,
    model: impl DatabaseEmbeddingModel,
    file_paths: Vec<PathBuf>,
) -> Result<(), Box<dyn Error>>
where
    for<'de> Met: serde::Deserialize<'de>,
{
    let mut sw = Stopwatch::start_new();
    let num_documents = file_paths.len();
    println!(
        "Inserting documents from {} file(s).",
        HumanCount(num_documents.try_into()?)
    );
    // The bar starts hidden and is only attached to stderr after the first batch,
    // so nothing is drawn while the first batch is still being processed.
    let progress_bar = ProgressBar::with_draw_target(
        Some(num_documents.try_into()?),
        ProgressDrawTarget::hidden(),
    );
    progress_bar.set_style(progress_bar_style()?);
    // NOTE(review): the documents handed to the database are the path strings, not
    // the file contents; presumably the model implementation loads each file —
    // confirm this holds for every `DatabaseEmbeddingModel` used with this function.
    let documents = file_paths
        .into_iter()
        .map(|path| {
            // Report a non-UTF-8 path as an error instead of panicking on user input.
            path.into_os_string().into_string().map_err(|raw| {
                format!("File path is not valid UTF-8: {}", raw.to_string_lossy())
            })
        })
        .collect::<Result<Vec<String>, String>>()?;
    // Insert documents in batches of INSERT_BATCH_SIZE.
    for document_batch in documents.chunks(INSERT_BATCH_SIZE) {
        let mut batch_sw = Stopwatch::start_new();
        let insertion_results = db.insert_documents(&model, document_batch)?;
        batch_sw.stop();
        progress_bar.println(format!(
            "{} embeddings of {} dimensions inserted into the database in {}.",
            HumanCount(insertion_results.0.try_into()?),
            HumanCount(insertion_results.1.try_into()?),
            pretty_duration(&batch_sw.elapsed(), None)
        ));
        // Advance by the real batch length: the final chunk may be shorter than
        // INSERT_BATCH_SIZE, and over-counting would push the bar past its length.
        progress_bar.inc(document_batch.len().try_into()?);
        if progress_bar.is_hidden() {
            progress_bar.set_draw_target(ProgressDrawTarget::stderr_with_hz(100));
        }
    }
    sw.stop();
    progress_bar.println(format!(
        "Inserted {} document(s) in {}.",
        num_documents,
        pretty_duration(&sw.elapsed(), None)
    ));
    Ok(())
}

0 comments on commit 7f9e8ca

Please sign in to comment.