+ }
+ tooltipText="System prompt to use when chunking. This is an optional field which allows you to specify the system prompt to use when chunking the text. If not specified, the default system prompt is used. However, you may want to use a different system prompt."
+ />
+
+
diff --git a/pdf2md/server/src/models.rs b/pdf2md/server/src/models.rs
index 29d2bee5f..d4ff05784 100644
--- a/pdf2md/server/src/models.rs
+++ b/pdf2md/server/src/models.rs
@@ -251,7 +251,7 @@ impl From for Vec {
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
pub struct GetTaskRequest {
- pub pagination_token: Option<uuid::Uuid>,
+ pub pagination_token: Option<u32>,
pub limit: Option,
}
@@ -265,7 +265,7 @@ pub struct GetTaskResponse {
pub status: String,
pub created_at: String,
pub pages: Option<Vec<Chunk>>,
- pub pagination_token: Option,
+ pub pagination_token: Option<u32>,
}
impl GetTaskResponse {
@@ -296,7 +296,7 @@ impl GetTaskResponse {
pages_processed: task.pages_processed,
status: task.status,
created_at: task.created_at.to_string(),
- pagination_token: pages.last().map(|c| c.id.clone()),
+ pagination_token: pages.last().map(|c| c.page),
pages: Some(pages.into_iter().map(Chunk::from).collect()),
}
}
diff --git a/pdf2md/server/src/operators/clickhouse.rs b/pdf2md/server/src/operators/clickhouse.rs
index 22a86c6fa..9f307e5bd 100644
--- a/pdf2md/server/src/operators/clickhouse.rs
+++ b/pdf2md/server/src/operators/clickhouse.rs
@@ -170,7 +170,7 @@ pub async fn get_task(
pub async fn get_task_pages(
task: FileTaskClickhouse,
limit: Option,
- offset_id: Option<uuid::Uuid>,
+ offset_id: Option<u32>,
clickhouse_client: &clickhouse::Client,
) -> Result, ServiceError> {
if FileTaskStatus::from(task.status.clone()) == FileTaskStatus::Completed || task.pages > 0 {
@@ -178,10 +178,10 @@ pub async fn get_task_pages(
let pages: Vec = clickhouse_client
.query(
- "SELECT ?fields FROM file_chunks WHERE task_id = ? AND id > ? ORDER BY page LIMIT ?",
+ "SELECT ?fields FROM file_chunks WHERE task_id = ? AND page > ? ORDER BY page LIMIT ?",
)
.bind(task.id.clone())
- .bind(offset_id.unwrap_or(uuid::Uuid::nil()))
+ .bind(offset_id.unwrap_or(0))
.bind(limit)
.fetch_all()
.await
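The change above replaces the opaque chunk-id cursor with the page number itself, so polling becomes keyset pagination over `page`. Below is a minimal client-side sketch of walking a task's pages with that cursor; `PageChunk`, `TaskResponse`, and `fetch_all_pages` are illustrative names, and only the fields visible in this diff are modeled.

// Poll GET {base_url}/api/task/{task_id}, passing the last page number seen
// as pagination_token so the server only returns rows with page > token.
#[derive(serde::Deserialize)]
struct PageChunk {
    page_num: u32,
    content: String,
}

#[derive(serde::Deserialize)]
struct TaskResponse {
    status: String,
    pages: Option<Vec<PageChunk>>,
    pagination_token: Option<u32>,
}

async fn fetch_all_pages(
    base_url: &str,
    task_id: &str,
    auth: &str,
) -> Result<Vec<PageChunk>, reqwest::Error> {
    let client = reqwest::Client::new();
    let mut all_pages = Vec::new();
    let mut cursor: Option<u32> = None;

    loop {
        let url = match cursor {
            Some(token) => format!("{}/api/task/{}?pagination_token={}", base_url, task_id, token),
            None => format!("{}/api/task/{}", base_url, task_id),
        };
        let resp: TaskResponse = client
            .get(&url)
            .header("Authorization", auth)
            .send()
            .await?
            .json()
            .await?;

        if let Some(pages) = resp.pages {
            all_pages.extend(pages);
        }
        // Advance the cursor only when the server handed back a new last page number.
        cursor = resp.pagination_token.or(cursor);

        // Done once the task reports Completed and no further pages remain.
        if resp.status == "Completed" && resp.pagination_token.is_none() {
            break;
        }
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }
    Ok(all_pages)
}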
diff --git a/pdf2md/server/src/operators/pdf_chunk.rs b/pdf2md/server/src/operators/pdf_chunk.rs
index c6423a6cd..72c182508 100644
--- a/pdf2md/server/src/operators/pdf_chunk.rs
+++ b/pdf2md/server/src/operators/pdf_chunk.rs
@@ -20,9 +20,17 @@ use regex::Regex;
use s3::creds::time::OffsetDateTime;
const CHUNK_SYSTEM_PROMPT: &str = "
- Convert the following PDF page to markdown.
- Return only the markdown with no explanation text.
- Do not exclude any content from the page.";
+Convert this PDF page to markdown formatting, following these requirements:
+
+1. Break the content into logical sections with clear markdown headings (# for main sections, ## for subsections, etc.)
+2. Create section headers that accurately reflect the content and hierarchy of each part
+3. Include all body content from the page
+4. Exclude any PDF headers and footers
+5. Return only the formatted markdown without any explanatory text
+6. Match the original document's content organization but with explicit markdown structure
+
+Please provide the markdown version using this structured approach.
+";
fn get_data_url_from_image(img: DynamicImage) -> Result<String, ServiceError> {
let mut encoded = Vec::new();
@@ -108,7 +116,7 @@ async fn get_markdown_from_image(
if let Some(prev_md_doc) = prev_md_doc {
let prev_md_doc_message = ChatMessage::System {
content: ChatMessageContent::Text(format!(
- "Markdown must maintain consistent formatting with the following page: \n\n {}",
+ "Markdown must maintain consistent formatting with the following page, DO NOT INCLUDE CONTENT FROM THIS PAGE IN YOUR RESPONSE: \n\n {}",
prev_md_doc
)),
name: None,
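For context, the worker later in this diff can pass a per-task `system_prompt` to pdf2md. A minimal sketch of how such an override could fold into the message construction here, assuming the same `ChatMessage`/`ChatMessageContent` types already imported in this module; `build_system_message` and `task_system_prompt` are hypothetical names, not part of this diff.

// Fall back to the default chunking prompt when the task did not supply one.
fn build_system_message(task_system_prompt: Option<&str>) -> ChatMessage {
    let prompt = task_system_prompt.unwrap_or(CHUNK_SYSTEM_PROMPT);
    ChatMessage::System {
        content: ChatMessageContent::Text(prompt.to_string()),
        name: None,
    }
}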
diff --git a/pdf2md/server/src/routes/jinja_templates.rs b/pdf2md/server/src/routes/jinja_templates.rs
index 0baad2041..17cfc4e86 100644
--- a/pdf2md/server/src/routes/jinja_templates.rs
+++ b/pdf2md/server/src/routes/jinja_templates.rs
@@ -3,7 +3,7 @@ use crate::{
get_env, Templates,
};
use actix_web::{get, HttpResponse};
-use minijinja::{context, path_loader, Environment};
+use minijinja::context;
#[utoipa::path(
get,
diff --git a/server/src/bin/csv-jsonl-worker.rs b/server/src/bin/csv-jsonl-worker.rs
index 727685d54..9c763c414 100644
--- a/server/src/bin/csv-jsonl-worker.rs
+++ b/server/src/bin/csv-jsonl-worker.rs
@@ -505,7 +505,7 @@ async fn process_csv_jsonl_file(
rebalance_chunks: Some(false),
split_delimiters: None,
target_splits_per_chunk: None,
- use_pdf2md_ocr: None,
+ pdf2md_options: None,
base64_file: "".to_string(),
},
csv_jsonl_worker_message.dataset_id,
diff --git a/server/src/bin/file-worker.rs b/server/src/bin/file-worker.rs
index b12d768b6..37557438b 100644
--- a/server/src/bin/file-worker.rs
+++ b/server/src/bin/file-worker.rs
@@ -7,18 +7,33 @@ use std::sync::{
Arc,
};
use trieve_server::{
- data::models::{self, FileWorkerMessage},
+ data::models::{self, ChunkGroup, FileWorkerMessage},
errors::ServiceError,
establish_connection, get_env,
+ handlers::chunk_handler::ChunkReqPayload,
operators::{
clickhouse_operator::{ClickHouseEvent, EventQueue},
dataset_operator::get_dataset_and_organization_from_dataset_id_query,
file_operator::{
create_file_chunks, create_file_query, get_aws_bucket, preprocess_file_to_chunks,
},
+ group_operator::{create_group_from_file_query, create_groups_query},
},
};
+const HEADING_CHUNKING_SYSTEM_PROMPT: &str = "
+Analyze this PDF page and restructure it into clear markdown sections based on the content topics. For each distinct topic or theme discussed:
+
+1. Create a meaningful section heading using markdown (# for main topics, ## for subtopics)
+2. Group related content under each heading
+3. Break up dense paragraphs into more readable chunks where appropriate
+4. Maintain the key information but organize it by subject matter
+5. Skip headers, footers, and page numbers
+6. Focus on semantic organization rather than matching the original layout
+
+Please provide just the reorganized markdown without any explanatory text.
+";
+
fn main() {
dotenvy::dotenv().ok();
env_logger::builder()
@@ -257,7 +272,7 @@ pub struct PollTaskResponse {
pub status: String,
pub created_at: String,
pub pages: Option>,
- pub pagination_token: Option,
+ pub pagination_token: Option<u32>,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
@@ -298,153 +313,285 @@ async fn upload_file(
)
.await?;
- if file_name.ends_with(".pdf") {
- if let Some(true) = file_worker_message.upload_file_data.use_pdf2md_ocr {
- log::info!("Using pdf2md for OCR for file");
- let pdf2md_url = std::env::var("PDF2MD_URL")
- .expect("PDF2MD_URL must be set")
- .to_string();
-
- let pdf2md_auth = std::env::var("PDF2MD_AUTH").unwrap_or("".to_string());
-
- let pdf2md_client = reqwest::Client::new();
- let encoded_file = base64::prelude::BASE64_STANDARD.encode(file_data.clone());
-
- let json_value = serde_json::json!({
- "file_name": file_name,
- "base64_file": encoded_file.clone()
- });
-
- log::info!("Sending file to pdf2md");
- let pdf2md_response = pdf2md_client
- .post(format!("{}/api/task", pdf2md_url))
- .header("Content-Type", "application/json")
- .header("Authorization", &pdf2md_auth)
- .json(&json_value)
- .send()
- .await
- .map_err(|err| {
- log::error!("Could not send file to pdf2md {:?}", err);
- ServiceError::BadRequest("Could not send file to pdf2md".to_string())
- })?;
-
- let response = pdf2md_response.json::().await;
-
- let task_id = match response {
- Ok(response) => response.id,
- Err(err) => {
- log::error!("Could not parse id from pdf2md {:?}", err);
- return Err(ServiceError::BadRequest(format!(
- "Could not parse id from pdf2md {:?}",
- err
- )));
- }
- };
+ let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
- log::info!("Waiting on Task {}", task_id);
- #[allow(unused_assignments)]
- let mut chunk_htmls: Vec<String> = vec![];
- let mut pagination_token: Option = None;
- let mut completed = false;
+ let created_file = create_file_query(
+ file_id,
+ file_size_mb,
+ file_worker_message.upload_file_data.clone(),
+ file_worker_message.dataset_id,
+ web_pool.clone(),
+ )
+ .await?;
- loop {
- if completed && pagination_token.is_none() {
- break;
- }
- tokio::time::sleep(std::time::Duration::from_secs(5)).await;
- let request = if let Some(pagination_token) = &pagination_token {
- log::info!(
- "Polling on task {} with pagination token {}",
- task_id,
- pagination_token
- );
- pdf2md_client
- .get(
- format!(
- "{}/api/task/{}?pagination_token={}",
- pdf2md_url, task_id, pagination_token
- )
- .as_str(),
+ let group_id = if !file_worker_message
+ .upload_file_data
+ .pdf2md_options
+ .as_ref()
+ .is_some_and(|options| options.split_headings.unwrap_or(false))
+ {
+ let chunk_group = ChunkGroup::from_details(
+ Some(file_worker_message.upload_file_data.file_name.clone()),
+ file_worker_message.upload_file_data.description.clone(),
+ dataset_org_plan_sub.dataset.id,
+ file_worker_message
+ .upload_file_data
+ .group_tracking_id
+ .clone(),
+ None,
+ file_worker_message
+ .upload_file_data
+ .tag_set
+ .clone()
+ .map(|tag_set| tag_set.into_iter().map(Some).collect()),
+ );
+
+ let chunk_group_option = create_groups_query(vec![chunk_group], true, web_pool.clone())
+ .await
+ .map_err(|e| {
+ log::error!("Could not create group {:?}", e);
+ ServiceError::BadRequest("Could not create group".to_string())
+ })?
+ .pop();
+
+ let chunk_group = match chunk_group_option {
+ Some(group) => group,
+ None => {
+ return Err(ServiceError::BadRequest(
+ "Could not create group from file".to_string(),
+ ));
+ }
+ };
+
+ let group_id = chunk_group.id;
+
+ create_group_from_file_query(group_id, created_file.id, web_pool.clone())
+ .await
+ .map_err(|e| {
+ log::error!("Could not create group from file {:?}", e);
+ e
+ })?;
+
+ Some(group_id)
+ } else {
+ None
+ };
+
+ if file_name.ends_with(".pdf")
+ && file_worker_message
+ .upload_file_data
+ .pdf2md_options
+ .as_ref()
+ .is_some_and(|options| options.use_pdf2md_ocr)
+ {
+ log::info!("Using pdf2md for OCR for file");
+ let pdf2md_url = std::env::var("PDF2MD_URL")
+ .expect("PDF2MD_URL must be set")
+ .to_string();
+
+ let pdf2md_auth = std::env::var("PDF2MD_AUTH").unwrap_or("".to_string());
+
+ let pdf2md_client = reqwest::Client::new();
+ let encoded_file = base64::prelude::BASE64_STANDARD.encode(file_data.clone());
+
+ let mut json_value = serde_json::json!({
+ "file_name": file_name,
+ "base64_file": encoded_file.clone()
+ });
+
+ if let Some(system_prompt) = file_worker_message
+ .upload_file_data
+ .pdf2md_options
+ .as_ref()
+ .and_then(|options| options.system_prompt.clone())
+ {
+ json_value["system_prompt"] = serde_json::json!(system_prompt);
+ }
+
+ if file_worker_message
+ .upload_file_data
+ .pdf2md_options
+ .as_ref()
+ .is_some_and(|options| options.split_headings.unwrap_or(false))
+ {
+ json_value["system_prompt"] = serde_json::json!(format!(
+ "{}\n\n{}",
+ json_value["system_prompt"].as_str().unwrap_or(""),
+ HEADING_CHUNKING_SYSTEM_PROMPT
+ ));
+ }
+
+ log::info!("Sending file to pdf2md");
+ let pdf2md_response = pdf2md_client
+ .post(format!("{}/api/task", pdf2md_url))
+ .header("Content-Type", "application/json")
+ .header("Authorization", &pdf2md_auth)
+ .json(&json_value)
+ .send()
+ .await
+ .map_err(|err| {
+ log::error!("Could not send file to pdf2md {:?}", err);
+ ServiceError::BadRequest("Could not send file to pdf2md".to_string())
+ })?;
+
+ let response = pdf2md_response.json::().await;
+
+ let task_id = match response {
+ Ok(response) => response.id,
+ Err(err) => {
+ log::error!("Could not parse id from pdf2md {:?}", err);
+ return Err(ServiceError::BadRequest(format!(
+ "Could not parse id from pdf2md {:?}",
+ err
+ )));
+ }
+ };
+
+ log::info!("Waiting on Task {}", task_id);
+ let mut processed_pages = std::collections::HashSet::new();
+ let mut pagination_token: Option<u32> = None;
+ let mut completed = false;
+ const PAGE_SIZE: u32 = 20;
+
+ loop {
+ if completed {
+ break;
+ }
+
+ let request = if let Some(pagination_token) = &pagination_token {
+ log::info!(
+ "Polling on task {} with pagination token {}",
+ task_id,
+ pagination_token
+ );
+ pdf2md_client
+ .get(
+ format!(
+ "{}/api/task/{}?pagination_token={}",
+ pdf2md_url, task_id, pagination_token
)
- .header("Content-Type", "application/json")
- .header("Authorization", &pdf2md_auth)
- .send()
- .await
- .map_err(|err| {
- log::error!("Could not send poll request to pdf2md {:?}", err);
- ServiceError::BadRequest(format!(
- "Could not send request to pdf2md {:?}",
- err
- ))
- })?
- } else {
- log::info!("Waiting on task {}", task_id);
- pdf2md_client
- .get(format!("{}/api/task/{}", pdf2md_url, task_id).as_str())
- .header("Content-Type", "application/json")
- .header("Authorization", &pdf2md_auth)
- .send()
- .await
- .map_err(|err| {
- log::error!("Could not send poll request to pdf2md {:?}", err);
- ServiceError::BadRequest(format!(
- "Could not send request to pdf2md {:?}",
- err
- ))
- })?
- };
-
- let task_response = request.json::<PollTaskResponse>().await.map_err(|err| {
- log::error!("Could not parse response from pdf2md {:?}", err);
- ServiceError::BadRequest(format!(
- "Could not parse response from pdf2md {:?}",
- err
- ))
- })?;
-
- if task_response.status == "Completed" && task_response.pages.is_some() {
- completed = true;
- pagination_token = task_response.pagination_token.clone();
- if let Some(pages) = task_response.pages {
- log::info!("Got {} pages from task {}", pages.len(), task_id);
- for page in pages {
- log::info!(".");
- chunk_htmls.push(page.content.clone());
- }
+ .as_str(),
+ )
+ .header("Content-Type", "application/json")
+ .header("Authorization", &pdf2md_auth)
+ .send()
+ .await
+ .map_err(|err| {
+ log::error!("Could not send poll request to pdf2md {:?}", err);
+ ServiceError::BadRequest(format!(
+ "Could not send request to pdf2md {:?}",
+ err
+ ))
+ })?
+ } else {
+ log::info!("Waiting on task {}", task_id);
+ pdf2md_client
+ .get(format!("{}/api/task/{}", pdf2md_url, task_id).as_str())
+ .header("Content-Type", "application/json")
+ .header("Authorization", &pdf2md_auth)
+ .send()
+ .await
+ .map_err(|err| {
+ log::error!("Could not send poll request to pdf2md {:?}", err);
+ ServiceError::BadRequest(format!(
+ "Could not send request to pdf2md {:?}",
+ err
+ ))
+ })?
+ };
+
+ let task_response = request.json::<PollTaskResponse>().await.map_err(|err| {
+ log::error!("Could not parse response from pdf2md {:?}", err);
+ ServiceError::BadRequest(format!("Could not parse response from pdf2md {:?}", err))
+ })?;
+
+ let mut new_chunks = Vec::new();
+ if let Some(pages) = task_response.pages {
+ log::info!("Got {} pages from task {}", pages.len(), task_id);
+
+ for page in pages {
+ let page_id = format!("{}", page.page_num);
+
+ if !processed_pages.contains(&page_id) {
+ processed_pages.insert(page_id);
+ let metadata = file_worker_message
+ .upload_file_data
+ .metadata
+ .clone()
+ .map(|mut metadata| {
+ metadata["page_num"] = serde_json::json!(page.page_num);
+ metadata["file_name"] = serde_json::json!(task_response.file_name);
+ metadata
+ })
+ .or(Some(serde_json::json!({
+ "page_num": page.page_num,
+ "file_name": task_response.file_name
+ })));
+
+ let create_chunk_data = ChunkReqPayload {
+ chunk_html: Some(page.content.clone()),
+ semantic_content: None,
+ link: file_worker_message.upload_file_data.link.clone(),
+ tag_set: file_worker_message.upload_file_data.tag_set.clone(),
+ metadata,
+ group_ids: None,
+ group_tracking_ids: None,
+ location: None,
+ tracking_id: file_worker_message
+ .upload_file_data
+ .group_tracking_id
+ .clone()
+ .map(|tracking_id| format!("{}|{}", tracking_id, page.page_num)),
+ upsert_by_tracking_id: None,
+ time_stamp: file_worker_message.upload_file_data.time_stamp.clone(),
+ weight: None,
+ split_avg: None,
+ convert_html_to_text: None,
+ image_urls: None,
+ num_value: None,
+ fulltext_boost: None,
+ semantic_boost: None,
+ };
+ new_chunks.push(create_chunk_data);
}
+ }
- continue;
- } else {
- log::info!("Task {} not ready", task_id);
- tokio::time::sleep(std::time::Duration::from_secs(5)).await;
- continue;
+ if !new_chunks.is_empty() {
+ create_file_chunks(
+ created_file.id,
+ file_worker_message.upload_file_data.clone(),
+ new_chunks.clone(),
+ dataset_org_plan_sub.clone(),
+ group_id,
+ web_pool.clone(),
+ event_queue.clone(),
+ redis_conn.clone(),
+ )
+ .await?;
}
}
- // Poll Chunks from pdf chunks from service
- let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
- let created_file = create_file_query(
- file_id,
- file_size_mb,
- file_worker_message.upload_file_data.clone(),
- file_worker_message.dataset_id,
- web_pool.clone(),
- )
- .await?;
-
- create_file_chunks(
- created_file.id,
- file_worker_message.upload_file_data,
- chunk_htmls,
- dataset_org_plan_sub,
- web_pool.clone(),
- event_queue.clone(),
- redis_conn,
- )
- .await?;
-
- return Ok(Some(file_id));
+ completed = task_response.status == "Completed";
+
+ let page_start = pagination_token.unwrap_or(0);
+
+ let has_complete_range = (page_start..page_start + PAGE_SIZE)
+ .all(|page_num| processed_pages.contains(&(page_num + 1).to_string()));
+
+ if let Some(token) = task_response.pagination_token {
+ if has_complete_range || completed {
+ pagination_token = Some(token);
+ }
+ }
+
+ if new_chunks.is_empty() {
+ tokio::time::sleep(std::time::Duration::from_secs(5)).await;
+ } else if !has_complete_range && !completed {
+ tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+ }
}
+
+ return Ok(Some(file_id));
}
let tika_url = std::env::var("TIKA_URL")
@@ -480,17 +627,6 @@ async fn upload_file(
));
}
- let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
-
- let created_file = create_file_query(
- file_id,
- file_size_mb,
- file_worker_message.upload_file_data.clone(),
- file_worker_message.dataset_id,
- web_pool.clone(),
- )
- .await?;
-
if file_worker_message
.upload_file_data
.create_chunks
@@ -512,11 +648,41 @@ async fn upload_file(
return Err(ServiceError::BadRequest("Could not parse file".to_string()));
};
+ let chunks = chunk_htmls
+ .into_iter()
+ .enumerate()
+ .map(|(i, chunk_html)| ChunkReqPayload {
+ chunk_html: Some(chunk_html),
+ semantic_content: None,
+ link: file_worker_message.upload_file_data.link.clone(),
+ tag_set: file_worker_message.upload_file_data.tag_set.clone(),
+ metadata: file_worker_message.upload_file_data.metadata.clone(),
+ group_ids: None,
+ group_tracking_ids: None,
+ location: None,
+ tracking_id: file_worker_message
+ .upload_file_data
+ .group_tracking_id
+ .clone()
+ .map(|tracking_id| format!("{}|{}", tracking_id, i)),
+ upsert_by_tracking_id: None,
+ time_stamp: file_worker_message.upload_file_data.time_stamp.clone(),
+ weight: None,
+ split_avg: None,
+ convert_html_to_text: None,
+ image_urls: None,
+ num_value: None,
+ fulltext_boost: None,
+ semantic_boost: None,
+ })
+ .collect::<Vec<ChunkReqPayload>>();
+
create_file_chunks(
created_file.id,
file_worker_message.upload_file_data,
- chunk_htmls,
+ chunks,
dataset_org_plan_sub,
+ group_id,
web_pool.clone(),
event_queue.clone(),
redis_conn,
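One subtle piece of the polling loop above is the completeness check: the worker only advances `pagination_token` once every page in the current window has actually been ingested, so a partially returned batch is re-polled rather than skipped. A standalone sketch of that check, assuming 1-indexed page numbers and the 20-page window used above; `window_is_complete` is an illustrative name.

// The cursor (page_start) is the last page number of the previous batch, so
// the current window covers pages page_start + 1 ..= page_start + page_size.
fn window_is_complete(
    processed_pages: &std::collections::HashSet<String>,
    page_start: u32,
    page_size: u32,
) -> bool {
    (page_start..page_start + page_size)
        .all(|page_num| processed_pages.contains(&(page_num + 1).to_string()))
}

// For a fresh task (page_start == 0) the window is pages 1..=20; the token
// only advances once all twenty of those pages have been received.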
diff --git a/server/src/handlers/file_handler.rs b/server/src/handlers/file_handler.rs
index 353a78ff8..2d17431a6 100644
--- a/server/src/handlers/file_handler.rs
+++ b/server/src/handlers/file_handler.rs
@@ -83,7 +83,17 @@ pub struct UploadFileReqPayload {
/// Group tracking id is an optional field which allows you to specify the tracking id of the group that is created from the file. Chunks created from the file will be assigned a tracking id of `group_tracking_id|<index of chunk>`.
pub group_tracking_id: Option<String>,
/// Parameter to use pdf2md_ocr. If true, the file will be converted to markdown using gpt-4o. Default is false.
- pub use_pdf2md_ocr: Option<bool>,
+ pub pdf2md_options: Option<Pdf2MdOptions>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
+pub struct Pdf2MdOptions {
+ /// Parameter to use pdf2md_ocr. If true, the file will be converted to markdown using gpt-4o. Default is false.
+ pub use_pdf2md_ocr: bool,
+ /// Prompt to use for the gpt-4o model. Default is None.
+ pub system_prompt: Option<String>,
+ /// Split headings is an optional field which allows you to specify whether or not to split headings into separate chunks. Default is false.
+ pub split_headings: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
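To illustrate the new payload shape, here is a sketch of the `pdf2md_options` block as it could appear in an upload request body, built with serde_json; the surrounding `UploadFileReqPayload` fields (file_name, base64_file, etc.) are omitted, and the prompt text is only an example.

fn main() {
    // Only the new field is shown; merge this into the full upload payload.
    let body_fragment = serde_json::json!({
        "pdf2md_options": {
            "use_pdf2md_ocr": true,
            "system_prompt": "Convert tables to GitHub-flavored markdown tables.",
            "split_headings": true
        }
    });
    println!("{}", body_fragment);
}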
diff --git a/server/src/lib.rs b/server/src/lib.rs
index a5ecfa173..3cbc8778b 100644
--- a/server/src/lib.rs
+++ b/server/src/lib.rs
@@ -357,6 +357,7 @@ impl Modify for SecurityAddon {
handlers::file_handler::CreatePresignedUrlForCsvJsonlReqPayload,
handlers::file_handler::CreatePresignedUrlForCsvJsonResponseBody,
handlers::file_handler::UploadHtmlPageReqPayload,
+ handlers::file_handler::Pdf2MdOptions,
handlers::invitation_handler::InvitationData,
handlers::event_handler::GetEventsData,
handlers::organization_handler::CreateOrganizationReqPayload,
diff --git a/server/src/operators/file_operator.rs b/server/src/operators/file_operator.rs
index 601448f3c..a5bf10649 100644
--- a/server/src/operators/file_operator.rs
+++ b/server/src/operators/file_operator.rs
@@ -122,83 +122,145 @@ pub fn preprocess_file_to_chunks(
Ok(chunk_htmls)
}
+pub fn split_markdown_by_headings(markdown_text: &str) -> Vec<String> {
+ let lines: Vec<&str> = markdown_text
+ .trim()
+ .lines()
+ .filter(|x| !x.trim().is_empty())
+ .collect();
+ let mut chunks = Vec::new();
+ let mut current_content = Vec::new();
+ let mut pending_heading: Option = None;
+
+ fn is_heading(line: &str) -> bool {
+ line.trim().starts_with('#')
+ }
+
+ fn save_chunk(chunks: &mut Vec<String>, content: &[String]) {
+ if !content.is_empty() {
+ chunks.push(content.join("\n").trim().to_string());
+ }
+ }
+
+ for (i, line) in lines.iter().enumerate() {
+ if is_heading(line) {
+ if !current_content.is_empty() {
+ save_chunk(&mut chunks, ¤t_content);
+ current_content.clear();
+ }
+
+ if i + 1 < lines.len() && !is_heading(lines[i + 1]) {
+ if let Some(heading) = pending_heading.take() {
+ current_content.push(heading);
+ }
+ current_content.push(line.to_string());
+ } else {
+ pending_heading = Some(line.to_string());
+ }
+ } else if !line.trim().is_empty() || !current_content.is_empty() {
+ current_content.push(line.to_string());
+ }
+ }
+
+ if !current_content.is_empty() {
+ save_chunk(&mut chunks, ¤t_content);
+ }
+
+ if let Some(heading) = pending_heading {
+ chunks.push(heading);
+ }
+
+ if chunks.is_empty() && !lines.is_empty() {
+ chunks.push(lines.join("\n").trim().to_string());
+ }
+
+ chunks
+}
+
#[allow(clippy::too_many_arguments)]
pub async fn create_file_chunks(
created_file_id: uuid::Uuid,
upload_file_data: UploadFileReqPayload,
- chunk_htmls: Vec<String>,
+ mut chunks: Vec<ChunkReqPayload>,
dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan,
+ group_id: Option<uuid::Uuid>,
pool: web::Data,
event_queue: web::Data,
mut redis_conn: MultiplexedConnection,
) -> Result<(), ServiceError> {
- let mut chunks: Vec<ChunkReqPayload> = [].to_vec();
-
let name = upload_file_data.file_name.clone();
- let chunk_group = ChunkGroup::from_details(
- Some(name.clone()),
- upload_file_data.description.clone(),
- dataset_org_plan_sub.dataset.id,
- upload_file_data.group_tracking_id.clone(),
- None,
- upload_file_data
- .tag_set
- .clone()
- .map(|tag_set| tag_set.into_iter().map(Some).collect()),
- );
-
- let chunk_group_option = create_groups_query(vec![chunk_group], true, pool.clone())
- .await
- .map_err(|e| {
- log::error!("Could not create group {:?}", e);
- ServiceError::BadRequest("Could not create group".to_string())
- })?
- .pop();
-
- let chunk_group = match chunk_group_option {
- Some(group) => group,
- None => {
- return Err(ServiceError::BadRequest(
- "Could not create group from file".to_string(),
- ));
+ if upload_file_data
+ .pdf2md_options
+ .is_some_and(|x| x.split_headings.unwrap_or(false))
+ {
+ let mut new_chunks = Vec::new();
+
+ for chunk in chunks {
+ let chunk_group = ChunkGroup::from_details(
+ Some(format!(
+ "{}-page-{}",
+ name,
+ chunk.metadata.as_ref().unwrap_or(&serde_json::json!({
+ "page_num": 0
+ }))["page_num"]
+ .as_i64()
+ .unwrap_or(0)
+ )),
+ upload_file_data.description.clone(),
+ dataset_org_plan_sub.dataset.id,
+ upload_file_data.group_tracking_id.clone(),
+ chunk.metadata.clone(),
+ upload_file_data
+ .tag_set
+ .clone()
+ .map(|tag_set| tag_set.into_iter().map(Some).collect()),
+ );
+
+ let chunk_group_option = create_groups_query(vec![chunk_group], true, pool.clone())
+ .await
+ .map_err(|e| {
+ log::error!("Could not create group {:?}", e);
+ ServiceError::BadRequest("Could not create group".to_string())
+ })?
+ .pop();
+
+ let chunk_group = match chunk_group_option {
+ Some(group) => group,
+ None => {
+ return Err(ServiceError::BadRequest(
+ "Could not create group from file".to_string(),
+ ));
+ }
+ };
+
+ let group_id = chunk_group.id;
+
+ create_group_from_file_query(group_id, created_file_id, pool.clone())
+ .await
+ .map_err(|e| {
+ log::error!("Could not create group from file {:?}", e);
+ e
+ })?;
+
+ let split_chunks =
+ split_markdown_by_headings(chunk.chunk_html.as_ref().unwrap_or(&String::new()));
+
+ for (i, split_chunk) in split_chunks.into_iter().enumerate() {
+ new_chunks.push(ChunkReqPayload {
+ chunk_html: Some(split_chunk),
+ tracking_id: chunk.tracking_id.clone().map(|x| format!("{}-{}", x, i)),
+ group_ids: Some(vec![group_id]),
+ ..chunk.clone()
+ });
+ }
}
- };
-
- let group_id = chunk_group.id;
-
- create_group_from_file_query(group_id, created_file_id, pool.clone())
- .await
- .map_err(|e| {
- log::error!("Could not create group from file {:?}", e);
- e
- })?;
- for (i, chunk_html) in chunk_htmls.iter().enumerate() {
- let create_chunk_data = ChunkReqPayload {
- chunk_html: Some(chunk_html.clone()),
- semantic_content: None,
- link: upload_file_data.link.clone(),
- tag_set: upload_file_data.tag_set.clone(),
- metadata: upload_file_data.metadata.clone(),
- group_ids: Some(vec![group_id]),
- group_tracking_ids: None,
- location: None,
- tracking_id: upload_file_data
- .group_tracking_id
- .clone()
- .map(|tracking_id| format!("{}|{}", tracking_id, i)),
- upsert_by_tracking_id: None,
- time_stamp: upload_file_data.time_stamp.clone(),
- weight: None,
- split_avg: None,
- convert_html_to_text: None,
- image_urls: None,
- num_value: None,
- fulltext_boost: None,
- semantic_boost: None,
- };
- chunks.push(create_chunk_data);
+ chunks = new_chunks;
+ } else {
+ chunks.iter_mut().for_each(|chunk| {
+ chunk.group_ids = group_id.map(|id| vec![id]);
+ });
}
let chunk_count = get_row_count_for_organization_id_query(
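A small usage sketch of the `split_markdown_by_headings` helper added in file_operator.rs above, assuming the function is in scope. Blank lines are dropped, and a heading that is immediately followed by another heading is folded into the next chunk rather than emitted on its own.

fn main() {
    let markdown = "# Title\n\n## Section A\nBody A1\nBody A2\n\n## Section B\nBody B";
    let chunks = split_markdown_by_headings(markdown);

    assert_eq!(
        chunks,
        vec![
            "# Title\n## Section A\nBody A1\nBody A2".to_string(),
            "## Section B\nBody B".to_string(),
        ]
    );
}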