From a1f0d24badfc95a3ba76233f87a18f870f5914d3 Mon Sep 17 00:00:00 2001
From: Dens Sumesh
Date: Tue, 10 Dec 2024 17:53:01 -0800
Subject: [PATCH 1/2] feature: incrementally add pages

---
 .../search/src/components/UploadFile.tsx      |   8 +-
 pdf2md/server/src/models.rs                   |   6 +-
 pdf2md/server/src/operators/clickhouse.rs     |   6 +-
 pdf2md/server/src/routes/jinja_templates.rs   |   2 +-
 server/src/bin/csv-jsonl-worker.rs            |   2 +-
 server/src/bin/file-worker.rs                 | 374 +++++++++++-------
 server/src/handlers/file_handler.rs           |  12 +-
 server/src/operators/file_operator.rs         |  35 +-
 8 files changed, 264 insertions(+), 181 deletions(-)

diff --git a/frontends/search/src/components/UploadFile.tsx b/frontends/search/src/components/UploadFile.tsx
index 1b54c012e..0126b867e 100644
--- a/frontends/search/src/components/UploadFile.tsx
+++ b/frontends/search/src/components/UploadFile.tsx
@@ -17,7 +17,11 @@ interface RequestBody {
   group_tracking_id?: string;
   metadata: any;
   time_stamp?: string;
-  use_pdf2md_ocr?: boolean;
+  pdf2md_options?: {
+    use_pdf2md_ocr: boolean;
+    system_prompt?: string;
+    split_headings?: boolean;
+  };
 }
 
 export const UploadFile = () => {
@@ -145,7 +149,7 @@ export const UploadFile = () => {
         split_delimiters: splitDelimiters(),
         target_splits_per_chunk: targetSplitsPerChunk(),
         rebalance_chunks: rebalanceChunks(),
-        use_pdf2md_ocr: useGptChunking(),
+        pdf2md_options: { use_pdf2md_ocr: useGptChunking() },
         group_tracking_id:
           groupTrackingId() === "" ? undefined : groupTrackingId(),
         // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
diff --git a/pdf2md/server/src/models.rs b/pdf2md/server/src/models.rs
index 29d2bee5f..d4ff05784 100644
--- a/pdf2md/server/src/models.rs
+++ b/pdf2md/server/src/models.rs
@@ -251,7 +251,7 @@
 #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
 pub struct GetTaskRequest {
-    pub pagination_token: Option<uuid::Uuid>,
+    pub pagination_token: Option<u32>,
     pub limit: Option<u32>,
 }
 
@@ -265,7 +265,7 @@ pub struct GetTaskResponse {
     pub status: String,
     pub created_at: String,
     pub pages: Option<Vec<Chunk>>,
-    pub pagination_token: Option<String>,
+    pub pagination_token: Option<u32>,
 }
 
 impl GetTaskResponse {
@@ -296,7 +296,7 @@ impl GetTaskResponse {
             pages_processed: task.pages_processed,
             status: task.status,
             created_at: task.created_at.to_string(),
-            pagination_token: pages.last().map(|c| c.id.clone()),
+            pagination_token: pages.last().map(|c| c.page),
             pages: Some(pages.into_iter().map(Chunk::from).collect()),
         }
     }
diff --git a/pdf2md/server/src/operators/clickhouse.rs b/pdf2md/server/src/operators/clickhouse.rs
index 22a86c6fa..9f307e5bd 100644
--- a/pdf2md/server/src/operators/clickhouse.rs
+++ b/pdf2md/server/src/operators/clickhouse.rs
@@ -170,7 +170,7 @@ pub async fn get_task(
 pub async fn get_task_pages(
     task: FileTaskClickhouse,
     limit: Option<u32>,
-    offset_id: Option<uuid::Uuid>,
+    offset_id: Option<u32>,
     clickhouse_client: &clickhouse::Client,
 ) -> Result<Vec<ChunkClickhouse>, ServiceError> {
     if FileTaskStatus::from(task.status.clone()) == FileTaskStatus::Completed || task.pages > 0 {
 
         let pages: Vec<ChunkClickhouse> = clickhouse_client
             .query(
-                "SELECT ?fields FROM file_chunks WHERE task_id = ? AND id > ? ORDER BY page LIMIT ?",
+                "SELECT ?fields FROM file_chunks WHERE task_id = ? AND page > ? ORDER BY page LIMIT ?",
             )
             .bind(task.id.clone())
-            .bind(offset_id.unwrap_or(uuid::Uuid::nil()))
+            .bind(offset_id.unwrap_or(0))
             .bind(limit)
             .fetch_all()
             .await
diff --git a/pdf2md/server/src/routes/jinja_templates.rs b/pdf2md/server/src/routes/jinja_templates.rs
index 0baad2041..17cfc4e86 100644
--- a/pdf2md/server/src/routes/jinja_templates.rs
+++ b/pdf2md/server/src/routes/jinja_templates.rs
@@ -3,7 +3,7 @@ use crate::{
     get_env, Templates,
 };
 use actix_web::{get, HttpResponse};
-use minijinja::{context, path_loader, Environment};
+use minijinja::context;
 
 #[utoipa::path(
     get,
diff --git a/server/src/bin/csv-jsonl-worker.rs b/server/src/bin/csv-jsonl-worker.rs
index 727685d54..9c763c414 100644
--- a/server/src/bin/csv-jsonl-worker.rs
+++ b/server/src/bin/csv-jsonl-worker.rs
@@ -505,7 +505,7 @@ async fn process_csv_jsonl_file(
                 rebalance_chunks: Some(false),
                 split_delimiters: None,
                 target_splits_per_chunk: None,
-                use_pdf2md_ocr: None,
+                pdf2md_options: None,
                 base64_file: "".to_string(),
             },
             csv_jsonl_worker_message.dataset_id,
diff --git a/server/src/bin/file-worker.rs b/server/src/bin/file-worker.rs
index b12d768b6..532620fc0 100644
--- a/server/src/bin/file-worker.rs
+++ b/server/src/bin/file-worker.rs
@@ -10,6 +10,7 @@ use trieve_server::{
     data::models::{self, FileWorkerMessage},
     errors::ServiceError,
     establish_connection, get_env,
+    handlers::chunk_handler::ChunkReqPayload,
     operators::{
         clickhouse_operator::{ClickHouseEvent, EventQueue},
         dataset_operator::get_dataset_and_organization_from_dataset_id_query,
@@ -257,7 +258,7 @@ pub struct PollTaskResponse {
     pub status: String,
     pub created_at: String,
     pub pages: Option<Vec<PdfToMdPage>>,
-    pub pagination_token: Option<uuid::Uuid>,
+    pub pagination_token: Option<u32>,
 }
 
 #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
@@ -298,153 +299,217 @@ async fn upload_file(
     )
     .await?;
 
-    if file_name.ends_with(".pdf") {
-        if let Some(true) = file_worker_message.upload_file_data.use_pdf2md_ocr {
-            log::info!("Using pdf2md for OCR for file");
-            let pdf2md_url = std::env::var("PDF2MD_URL")
-                .expect("PDF2MD_URL must be set")
-                .to_string();
-
-            let pdf2md_auth = std::env::var("PDF2MD_AUTH").unwrap_or("".to_string());
-
-            let pdf2md_client = reqwest::Client::new();
-            let encoded_file = base64::prelude::BASE64_STANDARD.encode(file_data.clone());
-
-            let json_value = serde_json::json!({
-                "file_name": file_name,
-                "base64_file": encoded_file.clone()
-            });
-
-            log::info!("Sending file to pdf2md");
-            let pdf2md_response = pdf2md_client
-                .post(format!("{}/api/task", pdf2md_url))
-                .header("Content-Type", "application/json")
-                .header("Authorization", &pdf2md_auth)
-                .json(&json_value)
-                .send()
-                .await
-                .map_err(|err| {
-                    log::error!("Could not send file to pdf2md {:?}", err);
-                    ServiceError::BadRequest("Could not send file to pdf2md".to_string())
-                })?;
-
-            let response = pdf2md_response.json::<CreateFileTaskResponse>().await;
-
-            let task_id = match response {
-                Ok(response) => response.id,
-                Err(err) => {
-                    log::error!("Could not parse id from pdf2md {:?}", err);
-                    return Err(ServiceError::BadRequest(format!(
-                        "Could not parse id from pdf2md {:?}",
-                        err
-                    )));
-                }
-            };
+    if file_name.ends_with(".pdf")
+        && file_worker_message
+            .upload_file_data
+            .pdf2md_options
+            .as_ref()
+            .is_some_and(|options| options.use_pdf2md_ocr)
+    {
+        log::info!("Using pdf2md for OCR for file");
+        let pdf2md_url = std::env::var("PDF2MD_URL")
+            .expect("PDF2MD_URL must be set")
+            .to_string();
 
-            log::info!("Waiting on Task {}", task_id);
-            #[allow(unused_assignments)]
-            let mut chunk_htmls: Vec<String> = vec![];
-            let mut pagination_token: Option<uuid::Uuid> = None;
-            let mut completed = false;
+        let pdf2md_auth = std::env::var("PDF2MD_AUTH").unwrap_or("".to_string());
 
-            loop {
-                if completed && pagination_token.is_none() {
-                    break;
-                }
-                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
-                let request = if let Some(pagination_token) = &pagination_token {
-                    log::info!(
-                        "Polling on task {} with pagination token {}",
-                        task_id,
-                        pagination_token
-                    );
-                    pdf2md_client
-                        .get(
-                            format!(
-                                "{}/api/task/{}?pagination_token={}",
-                                pdf2md_url, task_id, pagination_token
-                            )
+        let pdf2md_client = reqwest::Client::new();
+        let encoded_file = base64::prelude::BASE64_STANDARD.encode(file_data.clone());
+
+        let mut json_value = serde_json::json!({
+            "file_name": file_name,
+            "base64_file": encoded_file.clone()
+        });
+
+        if let Some(system_prompt) = &file_worker_message
+            .upload_file_data
+            .pdf2md_options
+            .as_ref()
+            .map(|options| options.system_prompt.clone())
+        {
+            json_value["system_prompt"] = serde_json::json!(system_prompt);
+        }
+
+        log::info!("Sending file to pdf2md");
+        let pdf2md_response = pdf2md_client
+            .post(format!("{}/api/task", pdf2md_url))
+            .header("Content-Type", "application/json")
+            .header("Authorization", &pdf2md_auth)
+            .json(&json_value)
+            .send()
+            .await
+            .map_err(|err| {
+                log::error!("Could not send file to pdf2md {:?}", err);
+                ServiceError::BadRequest("Could not send file to pdf2md".to_string())
+            })?;
+
+        let response = pdf2md_response.json::<CreateFileTaskResponse>().await;
+
+        let task_id = match response {
+            Ok(response) => response.id,
+            Err(err) => {
+                log::error!("Could not parse id from pdf2md {:?}", err);
+                return Err(ServiceError::BadRequest(format!(
+                    "Could not parse id from pdf2md {:?}",
+                    err
+                )));
+            }
+        };
+
+        let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
+        let created_file = create_file_query(
+            file_id,
+            file_size_mb,
+            file_worker_message.upload_file_data.clone(),
+            file_worker_message.dataset_id,
+            web_pool.clone(),
+        )
+        .await?;
+
+        log::info!("Waiting on Task {}", task_id);
+        let mut processed_pages = std::collections::HashSet::new();
+        let mut pagination_token: Option<u32> = None;
+        let mut completed = false;
+        const PAGE_SIZE: u32 = 20;
+
+        loop {
+            if completed {
+                break;
+            }
+
+            let request = if let Some(pagination_token) = &pagination_token {
+                log::info!(
+                    "Polling on task {} with pagination token {}",
+                    task_id,
+                    pagination_token
+                );
+                pdf2md_client
+                    .get(
+                        format!(
+                            "{}/api/task/{}?pagination_token={}",
+                            pdf2md_url, task_id, pagination_token
                         )
-                            .as_str(),
-                        )
-                        .header("Content-Type", "application/json")
-                        .header("Authorization", &pdf2md_auth)
-                        .send()
-                        .await
-                        .map_err(|err| {
-                            log::error!("Could not send poll request to pdf2md {:?}", err);
-                            ServiceError::BadRequest(format!(
-                                "Could not send request to pdf2md {:?}",
-                                err
-                            ))
-                        })?
-                } else {
-                    log::info!("Waiting on task {}", task_id);
-                    pdf2md_client
-                        .get(format!("{}/api/task/{}", pdf2md_url, task_id).as_str())
-                        .header("Content-Type", "application/json")
-                        .header("Authorization", &pdf2md_auth)
-                        .send()
-                        .await
-                        .map_err(|err| {
-                            log::error!("Could not send poll request to pdf2md {:?}", err);
-                            ServiceError::BadRequest(format!(
-                                "Could not send request to pdf2md {:?}",
-                                err
-                            ))
-                        })?
-                };
-
-                let task_response = request.json::<PollTaskResponse>().await.map_err(|err| {
-                    log::error!("Could not parse response from pdf2md {:?}", err);
-                    ServiceError::BadRequest(format!(
-                        "Could not parse response from pdf2md {:?}",
-                        err
-                    ))
-                })?;
-
-                if task_response.status == "Completed" && task_response.pages.is_some() {
-                    completed = true;
-                    pagination_token = task_response.pagination_token.clone();
-                    if let Some(pages) = task_response.pages {
-                        log::info!("Got {} pages from task {}", pages.len(), task_id);
-                        for page in pages {
-                            log::info!(".");
-                            chunk_htmls.push(page.content.clone());
-                        }
+                        .as_str(),
+                    )
+                    .header("Content-Type", "application/json")
+                    .header("Authorization", &pdf2md_auth)
+                    .send()
+                    .await
+                    .map_err(|err| {
+                        log::error!("Could not send poll request to pdf2md {:?}", err);
+                        ServiceError::BadRequest(format!(
+                            "Could not send request to pdf2md {:?}",
+                            err
+                        ))
+                    })?
+            } else {
+                log::info!("Waiting on task {}", task_id);
+                pdf2md_client
+                    .get(format!("{}/api/task/{}", pdf2md_url, task_id).as_str())
+                    .header("Content-Type", "application/json")
+                    .header("Authorization", &pdf2md_auth)
+                    .send()
+                    .await
+                    .map_err(|err| {
+                        log::error!("Could not send poll request to pdf2md {:?}", err);
+                        ServiceError::BadRequest(format!(
+                            "Could not send request to pdf2md {:?}",
+                            err
+                        ))
+                    })?
+            };
+
+            let task_response = request.json::<PollTaskResponse>().await.map_err(|err| {
+                log::error!("Could not parse response from pdf2md {:?}", err);
+                ServiceError::BadRequest(format!("Could not parse response from pdf2md {:?}", err))
+            })?;
+
+            let mut new_chunks = Vec::new();
+            if let Some(pages) = task_response.pages {
+                log::info!("Got {} pages from task {}", pages.len(), task_id);
+
+                for page in pages {
+                    let page_id = format!("{}", page.page_num);
+
+                    if !processed_pages.contains(&page_id) {
+                        processed_pages.insert(page_id);
+                        let metadata = file_worker_message
+                            .upload_file_data
+                            .metadata
+                            .clone()
+                            .map(|mut metadata| {
+                                metadata["page_num"] = serde_json::json!(page.page_num);
+                                metadata["file_name"] = serde_json::json!(task_response.file_name);
+                                metadata
+                            })
+                            .or(Some(serde_json::json!({
+                                "page_num": page.page_num,
+                                "file_name": task_response.file_name
+                            })));
+
+                        let create_chunk_data = ChunkReqPayload {
+                            chunk_html: Some(page.content.clone()),
+                            semantic_content: None,
+                            link: file_worker_message.upload_file_data.link.clone(),
+                            tag_set: file_worker_message.upload_file_data.tag_set.clone(),
+                            metadata,
+                            group_ids: None,
+                            group_tracking_ids: None,
+                            location: None,
+                            tracking_id: file_worker_message
+                                .upload_file_data
+                                .group_tracking_id
+                                .clone()
+                                .map(|tracking_id| format!("{}|{}", tracking_id, page.page_num)),
+                            upsert_by_tracking_id: None,
+                            time_stamp: file_worker_message.upload_file_data.time_stamp.clone(),
+                            weight: None,
+                            split_avg: None,
+                            convert_html_to_text: None,
+                            image_urls: None,
+                            num_value: None,
+                            fulltext_boost: None,
+                            semantic_boost: None,
+                        };
+                        new_chunks.push(create_chunk_data);
                     }
+                }
+
+                if !new_chunks.is_empty() {
+                    create_file_chunks(
+                        created_file.id,
+                        file_worker_message.upload_file_data.clone(),
+                        new_chunks.clone(),
+                        dataset_org_plan_sub.clone(),
+                        web_pool.clone(),
+                        event_queue.clone(),
+                        redis_conn.clone(),
+                    )
+                    .await?;
+                }
+            }
+
+            completed = task_response.status == "Completed";
+
+            let page_start = pagination_token.unwrap_or(0);
+
+            let has_complete_range = (page_start..page_start + PAGE_SIZE)
+                .all(|page_num| processed_pages.contains(&(page_num + 1).to_string()));
 
-                    continue;
-                } else {
-                    log::info!("Task {} not ready", task_id);
-                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
-                    continue;
+            if let Some(token) = task_response.pagination_token {
+                if has_complete_range || completed {
+                    pagination_token = Some(token);
                 }
             }
-            // Poll Chunks from pdf chunks from service
-            let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
-            let created_file = create_file_query(
-                file_id,
-                file_size_mb,
-                file_worker_message.upload_file_data.clone(),
-                file_worker_message.dataset_id,
-                web_pool.clone(),
-            )
-            .await?;
-
-            create_file_chunks(
-                created_file.id,
-                file_worker_message.upload_file_data,
-                chunk_htmls,
-                dataset_org_plan_sub,
-                web_pool.clone(),
-                event_queue.clone(),
-                redis_conn,
-            )
-            .await?;
-
-            return Ok(Some(file_id));
+            if new_chunks.is_empty() {
+                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+            } else if !has_complete_range && !completed {
+                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+            }
         }
+
+        return Ok(Some(file_id));
     }
 
     let tika_url = std::env::var("TIKA_URL")
@@ -512,10 +577,39 @@ async fn upload_file(
         return Err(ServiceError::BadRequest("Could not parse file".to_string()));
     };
 
+    let chunks = chunk_htmls
+        .into_iter()
+        .enumerate()
+        .map(|(i, chunk_html)| ChunkReqPayload {
+            chunk_html: Some(chunk_html),
+            semantic_content: None,
+            link: file_worker_message.upload_file_data.link.clone(),
+            tag_set: file_worker_message.upload_file_data.tag_set.clone(),
+            metadata: file_worker_message.upload_file_data.metadata.clone(),
+            group_ids: None,
+            group_tracking_ids: None,
+            location: None,
+            tracking_id: file_worker_message
+                .upload_file_data
+                .group_tracking_id
+                .clone()
+                .map(|tracking_id| format!("{}|{}", tracking_id, i)),
+            upsert_by_tracking_id: None,
+            time_stamp: file_worker_message.upload_file_data.time_stamp.clone(),
+            weight: None,
+            split_avg: None,
+            convert_html_to_text: None,
+            image_urls: None,
+            num_value: None,
+            fulltext_boost: None,
+            semantic_boost: None,
+        })
+        .collect::<Vec<ChunkReqPayload>>();
+
     create_file_chunks(
         created_file.id,
         file_worker_message.upload_file_data,
-        chunk_htmls,
+        chunks,
         dataset_org_plan_sub,
         web_pool.clone(),
         event_queue.clone(),
diff --git a/server/src/handlers/file_handler.rs b/server/src/handlers/file_handler.rs
index 353a78ff8..2d17431a6 100644
--- a/server/src/handlers/file_handler.rs
+++ b/server/src/handlers/file_handler.rs
@@ -83,7 +83,17 @@ pub struct UploadFileReqPayload {
     /// Group tracking id is an optional field which allows you to specify the tracking id of the group that is created from the file. Chunks created will be created with the tracking id of `group_tracking_id|<index of chunk>`
     pub group_tracking_id: Option<String>,
     /// Parameter to use pdf2md_ocr. If true, the file will be converted to markdown using gpt-4o. Default is false.
-    pub use_pdf2md_ocr: Option<bool>,
+    pub pdf2md_options: Option<Pdf2MdOptions>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
+pub struct Pdf2MdOptions {
+    /// Parameter to use pdf2md_ocr. If true, the file will be converted to markdown using gpt-4o. Default is false.
+    pub use_pdf2md_ocr: bool,
+    /// Prompt to use for the gpt-4o model. Default is None.
+    pub system_prompt: Option<String>,
+    /// Split headings is an optional field which allows you to specify whether or not to split headings into separate chunks. Default is false.
+    pub split_headings: Option<bool>,
 }
 
 #[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
diff --git a/server/src/operators/file_operator.rs b/server/src/operators/file_operator.rs
index 601448f3c..d99d1ff30 100644
--- a/server/src/operators/file_operator.rs
+++ b/server/src/operators/file_operator.rs
@@ -126,14 +126,12 @@ pub fn preprocess_file_to_chunks(
 pub async fn create_file_chunks(
     created_file_id: uuid::Uuid,
     upload_file_data: UploadFileReqPayload,
-    chunk_htmls: Vec<String>,
+    mut chunks: Vec<ChunkReqPayload>,
     dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan,
     pool: web::Data<models::Pool>,
     event_queue: web::Data<EventQueue>,
     mut redis_conn: MultiplexedConnection,
 ) -> Result<(), ServiceError> {
-    let mut chunks: Vec<ChunkReqPayload> = [].to_vec();
-
     let name = upload_file_data.file_name.clone();
 
     let chunk_group = ChunkGroup::from_details(
@@ -167,6 +165,10 @@ pub async fn create_file_chunks(
 
     let group_id = chunk_group.id;
 
+    chunks.iter_mut().for_each(|chunk| {
+        chunk.group_ids = Some(vec![group_id]);
+    });
+
     create_group_from_file_query(group_id, created_file_id, pool.clone())
         .await
         .map_err(|e| {
             log::error!("Could not create group from file {:?}", e);
             e
         })?;
 
-    for (i, chunk_html) in chunk_htmls.iter().enumerate() {
-        let create_chunk_data = ChunkReqPayload {
-            chunk_html: Some(chunk_html.clone()),
-            semantic_content: None,
-            link: upload_file_data.link.clone(),
-            tag_set: upload_file_data.tag_set.clone(),
-            metadata: upload_file_data.metadata.clone(),
-            group_ids: Some(vec![group_id]),
-            group_tracking_ids: None,
-            location: None,
-            tracking_id: upload_file_data
-                .group_tracking_id
-                .clone()
-                .map(|tracking_id| format!("{}|{}", tracking_id, i)),
-            upsert_by_tracking_id: None,
-            time_stamp: upload_file_data.time_stamp.clone(),
-            weight: None,
-            split_avg: None,
-            convert_html_to_text: None,
-            image_urls: None,
-            num_value: None,
-            fulltext_boost: None,
-            semantic_boost: None,
-        };
-        chunks.push(create_chunk_data);
-    }
-
     let chunk_count = get_row_count_for_organization_id_query(
         dataset_org_plan_sub.organization.organization.id,
         pool.clone(),

From 5da1466ffb8e372a7f9d943080dc019763b9e4b5 Mon Sep 17 00:00:00 2001
From: Dens Sumesh
Date: Fri, 13 Dec 2024 17:07:46 -0800
Subject: [PATCH 2/2] feature: heading based chunking

---
 .../search/src/components/UploadFile.tsx      |  39 ++++-
 pdf2md/server/src/operators/pdf_chunk.rs      |  16 +-
 server/src/bin/file-worker.rs                 | 116 +++++++++---
 server/src/lib.rs                             |   1 +
 server/src/operators/file_operator.rs         | 165 +++++++++++++-----
 5 files changed, 271 insertions(+), 66 deletions(-)

diff --git a/frontends/search/src/components/UploadFile.tsx b/frontends/search/src/components/UploadFile.tsx
index 0126b867e..5d6f5570a 100644
--- a/frontends/search/src/components/UploadFile.tsx
+++ b/frontends/search/src/components/UploadFile.tsx
@@ -44,7 +44,10 @@ export const UploadFile = () => {
   const [targetSplitsPerChunk, setTargetSplitsPerChunk] = createSignal(20);
   const [rebalanceChunks, setRebalanceChunks] = createSignal(false);
   const [useGptChunking, setUseGptChunking] = createSignal(false);
+  const [useHeadingBasedChunking, setUseHeadingBasedChunking] =
+    createSignal(false);
   const [groupTrackingId, setGroupTrackingId] = createSignal("");
+  const [systemPrompt, setSystemPrompt] = createSignal("");
 
   const [showFileInput, setShowFileInput] = createSignal(true);
   const [showFolderInput, setShowFolderInput] = createSignal(false);
@@ -149,7 +152,11 @@ export const UploadFile = () => {
         split_delimiters: splitDelimiters(),
         target_splits_per_chunk: targetSplitsPerChunk(),
         rebalance_chunks: rebalanceChunks(),
-        pdf2md_options: { use_pdf2md_ocr: useGptChunking() },
+        pdf2md_options: {
+          use_pdf2md_ocr: useGptChunking(),
+          split_headings: useHeadingBasedChunking(),
+          system_prompt: systemPrompt(),
+        },
         group_tracking_id:
           groupTrackingId() === "" ? undefined : groupTrackingId(),
         // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
@@ -343,6 +350,36 @@ export const UploadFile = () => {
             onInput={(e) => setUseGptChunking(e.currentTarget.checked)}
             class="h-4 w-4 rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
           />
+          <div class="flex items-center space-x-2">
+            <div class="flex items-center space-x-1">
+              <div>Heading Based Chunking</div>
+              <Tooltip
+                body={
+                  <BsInfoCircle class="h-5 w-5 text-gray-500" />
+                }
+                tooltipText="If set to true, Trieve will use the headings in the document to chunk the text."
+              />
+            </div>
+            <input
+              type="checkbox"
+              checked={useHeadingBasedChunking()}
+              onInput={(e) =>
+                setUseHeadingBasedChunking(e.currentTarget.checked)
+              }
+              class="h-4 w-4 rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
+            />
+          </div>
+          <div class="flex items-center space-x-2">
+            <div class="flex items-center space-x-1">
+              <div>System Prompt</div>
+              <Tooltip
+                body={
+                  <BsInfoCircle class="h-5 w-5 text-gray-500" />
+                }
+                tooltipText="System prompt to use when chunking. This is an optional field which allows you to specify the system prompt to use when chunking the text. If not specified, the default system prompt is used. However, you may want to use a different system prompt."
+              />
+            </div>
+            <input
+              type="text"
+              value={systemPrompt()}
+              onInput={(e) => setSystemPrompt(e.currentTarget.value)}
+              class="rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
+            />
+          </div>
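
A quick usage note (not part of the patches themselves): after these two commits, clients opt into OCR through the nested `pdf2md_options` object rather than the old top-level `use_pdf2md_ocr` flag. Below is a minimal sketch of an upload request body using the field names from `UploadFileReqPayload` and the `RequestBody` interface above; the file name, prompt text, and base64 payload are illustrative placeholders, not values from the diff.

```typescript
// Hypothetical upload payload exercising the new pdf2md options.
// Field names mirror frontends/search/src/components/UploadFile.tsx.
const requestBody = {
  base64_file: "JVBERi0xLjQK...", // placeholder base64-encoded PDF bytes
  file_name: "report.pdf",
  group_tracking_id: "report-group", // chunks get tracking id "report-group|<page_num>"
  metadata: { source: "example" },
  pdf2md_options: {
    use_pdf2md_ocr: true, // route the PDF through pdf2md instead of Tika
    split_headings: true, // PATCH 2/2: chunk on document headings
    system_prompt: "Describe any tables as markdown.", // optional model prompt
  },
};
```

With `use_pdf2md_ocr` set, the file worker now creates the file record up front and calls `create_file_chunks` incrementally as pages of the pdf2md task complete, using the numeric `pagination_token` (last page seen) to resume polling, instead of buffering every page's HTML before creating any chunks.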