feature: heading based chunking
densumesh committed Dec 14, 2024
1 parent a1f0d24 commit 5da1466
Showing 5 changed files with 271 additions and 66 deletions.
39 changes: 38 additions & 1 deletion frontends/search/src/components/UploadFile.tsx
@@ -44,7 +44,10 @@ export const UploadFile = () => {
const [targetSplitsPerChunk, setTargetSplitsPerChunk] = createSignal(20);
const [rebalanceChunks, setRebalanceChunks] = createSignal(false);
const [useGptChunking, setUseGptChunking] = createSignal(false);
const [useHeadingBasedChunking, setUseHeadingBasedChunking] =
createSignal(false);
const [groupTrackingId, setGroupTrackingId] = createSignal("");
const [systemPrompt, setSystemPrompt] = createSignal("");

const [showFileInput, setShowFileInput] = createSignal(true);
const [showFolderInput, setShowFolderInput] = createSignal(false);
@@ -149,7 +152,11 @@ export const UploadFile = () => {
split_delimiters: splitDelimiters(),
target_splits_per_chunk: targetSplitsPerChunk(),
rebalance_chunks: rebalanceChunks(),
pdf2md_options: { use_pdf2md_ocr: useGptChunking() },
pdf2md_options: {
use_pdf2md_ocr: useGptChunking(),
split_headings: useHeadingBasedChunking(),
system_prompt: systemPrompt(),
},
group_tracking_id:
groupTrackingId() === "" ? undefined : groupTrackingId(),
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
@@ -343,6 +350,36 @@ export const UploadFile = () => {
onInput={(e) => setUseGptChunking(e.currentTarget.checked)}
class="h-4 w-4 rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
/>
<div class="flex flex-row items-center space-x-2">
<div>Heading Based Chunking</div>
<Tooltip
body={<BsInfoCircle />}
tooltipText="If set to true, Trieve will use the headings in the document to chunk the text."
/>
</div>
<input
type="checkbox"
checked={useHeadingBasedChunking()}
onInput={(e) =>
setUseHeadingBasedChunking(e.currentTarget.checked)
}
class="h-4 w-4 rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
/>
<div class="flex flex-col space-y-2">
<div class="flex flex-row items-center space-x-2">
<div>System Prompt</div>
<Tooltip
body={<BsInfoCircle />}
tooltipText="System prompt to use when chunking. This is an optional field which allows you to specify the system prompt to use when chunking the text. If not specified, the default system prompt is used. However, you may want to use a different system prompt."
/>
</div>
<textarea
placeholder="optional system prompt to use when chunking"
value={systemPrompt()}
onInput={(e) => setSystemPrompt(e.target.value)}
class="w-full rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
/>
</div>
</div>
</Show>
<div class="m-1 mb-1 flex flex-row gap-2">
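For reference, a minimal sketch of the JSON body the updated form produces with the new `pdf2md_options` fields (values are illustrative; the endpoint and surrounding request code are unchanged by this commit):

```rust
use serde_json::json;

fn main() {
    // Illustrative upload body with the new pdf2md_options fields.
    // target_splits_per_chunk and rebalance_chunks use the component's
    // defaults shown above; split_delimiters is illustrative.
    let payload = json!({
        "split_delimiters": [".", "?", "\n"],
        "target_splits_per_chunk": 20,
        "rebalance_chunks": false,
        "pdf2md_options": {
            "use_pdf2md_ocr": false,
            "split_headings": true,                // useHeadingBasedChunking()
            "system_prompt": "Keep tables intact." // systemPrompt(), optional
        }
    });
    println!("{}", serde_json::to_string_pretty(&payload).unwrap());
}
```

Per the component code above, `group_tracking_id` is omitted entirely when the field is left blank, and an empty system prompt textarea sends an empty string.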
16 changes: 12 additions & 4 deletions pdf2md/server/src/operators/pdf_chunk.rs
@@ -20,9 +20,17 @@ use regex::Regex;
use s3::creds::time::OffsetDateTime;

const CHUNK_SYSTEM_PROMPT: &str = "
Convert the following PDF page to markdown.
Return only the markdown with no explanation text.
Do not exclude any content from the page.";
Convert this PDF page to markdown formatting, following these requirements:
1. Break the content into logical sections with clear markdown headings (# for main sections, ## for subsections, etc.)
2. Create section headers that accurately reflect the content and hierarchy of each part
3. Include all body content from the page
4. Exclude any PDF headers and footers
5. Return only the formatted markdown without any explanatory text
6. Match the original document's content organization but with explicit markdown structure
Please provide the markdown version using this structured approach.
";

fn get_data_url_from_image(img: DynamicImage) -> Result<String, ServiceError> {
let mut encoded = Vec::new();
@@ -108,7 +116,7 @@ async fn get_markdown_from_image(
if let Some(prev_md_doc) = prev_md_doc {
let prev_md_doc_message = ChatMessage::System {
content: ChatMessageContent::Text(format!(
"Markdown must maintain consistent formatting with the following page: \n\n {}",
"Markdown must maintain consistent formatting with the following page, DO NOT INCLUDE CONTENT FROM THIS PAGE IN YOUR RESPONSE: \n\n {}",
prev_md_doc
)),
name: None,
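To make the per-page flow concrete, here is a self-contained sketch of how the two system messages are assembled; the enum definitions are simplified stand-ins for the chat types pdf_chunk.rs actually imports, and the prompt constant is abbreviated:

```rust
// Simplified stand-ins for the chat types pdf_chunk.rs imports (assumption:
// the real ones come from the OpenAI client crate used by the server).
enum ChatMessageContent {
    Text(String),
}

enum ChatMessage {
    System {
        content: ChatMessageContent,
        name: Option<String>,
    },
}

// Abbreviated stand-in for the constant defined in the diff above.
const CHUNK_SYSTEM_PROMPT: &str =
    "Convert this PDF page to markdown formatting, following these requirements: ...";

// The base conversion prompt is always sent; when a previous page exists, a
// second system message pins formatting to it while forbidding content reuse.
fn build_system_messages(prev_md_doc: Option<String>) -> Vec<ChatMessage> {
    let mut messages = vec![ChatMessage::System {
        content: ChatMessageContent::Text(CHUNK_SYSTEM_PROMPT.to_string()),
        name: None,
    }];
    if let Some(prev_md_doc) = prev_md_doc {
        messages.push(ChatMessage::System {
            content: ChatMessageContent::Text(format!(
                "Markdown must maintain consistent formatting with the following page, DO NOT INCLUDE CONTENT FROM THIS PAGE IN YOUR RESPONSE: \n\n {}",
                prev_md_doc
            )),
            name: None,
        });
    }
    messages
}

fn main() {
    let messages = build_system_messages(Some("# Prior page".to_string()));
    println!("{} system message(s) prepared", messages.len());
}
```

The all-caps warning added in the diff guards against the model re-emitting the previous page's content, since that page is supplied verbatim as context.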
116 changes: 94 additions & 22 deletions server/src/bin/file-worker.rs
@@ -7,7 +7,7 @@ use std::sync::{
Arc,
};
use trieve_server::{
data::models::{self, FileWorkerMessage},
data::models::{self, ChunkGroup, FileWorkerMessage},
errors::ServiceError,
establish_connection, get_env,
handlers::chunk_handler::ChunkReqPayload,
@@ -17,9 +17,23 @@ use trieve_server::{
file_operator::{
create_file_chunks, create_file_query, get_aws_bucket, preprocess_file_to_chunks,
},
group_operator::{create_group_from_file_query, create_groups_query},
},
};

const HEADING_CHUNKING_SYSTEM_PROMPT: &str = "
Analyze this PDF page and restructure it into clear markdown sections based on the content topics. For each distinct topic or theme discussed:
1. Create a meaningful section heading using markdown (# for main topics, ## for subtopics)
2. Group related content under each heading
3. Break up dense paragraphs into more readable chunks where appropriate
4. Maintain the key information but organize it by subject matter
5. Skip headers, footers, and page numbers
6. Focus on semantic organization rather than matching the original layout
Please provide just the reorganized markdown without any explanatory text
";

fn main() {
dotenvy::dotenv().ok();
env_logger::builder()
@@ -299,6 +313,70 @@ async fn upload_file(
)
.await?;

let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;

let created_file = create_file_query(
file_id,
file_size_mb,
file_worker_message.upload_file_data.clone(),
file_worker_message.dataset_id,
web_pool.clone(),
)
.await?;

let group_id = if !file_worker_message
.upload_file_data
.pdf2md_options
.as_ref()
.is_some_and(|options| options.split_headings.unwrap_or(false))
{
let chunk_group = ChunkGroup::from_details(
Some(file_worker_message.upload_file_data.file_name.clone()),
file_worker_message.upload_file_data.description.clone(),
dataset_org_plan_sub.dataset.id,
file_worker_message
.upload_file_data
.group_tracking_id
.clone(),
None,
file_worker_message
.upload_file_data
.tag_set
.clone()
.map(|tag_set| tag_set.into_iter().map(Some).collect()),
);

let chunk_group_option = create_groups_query(vec![chunk_group], true, web_pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group {:?}", e);
ServiceError::BadRequest("Could not create group".to_string())
})?
.pop();

let chunk_group = match chunk_group_option {
Some(group) => group,
None => {
return Err(ServiceError::BadRequest(
"Could not create group from file".to_string(),
));
}
};

let group_id = chunk_group.id;

create_group_from_file_query(group_id, created_file.id, web_pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group from file {:?}", e);
e
})?;

Some(group_id)
} else {
None
};

if file_name.ends_with(".pdf")
&& file_worker_message
.upload_file_data
@@ -330,6 +408,19 @@
json_value["system_prompt"] = serde_json::json!(system_prompt);
}

if file_worker_message
.upload_file_data
.pdf2md_options
.as_ref()
.is_some_and(|options| options.split_headings.unwrap_or(false))
{
json_value["system_prompt"] = serde_json::json!(format!(
"{}\n\n{}",
json_value["system_prompt"].as_str().unwrap_or(""),
HEADING_CHUNKING_SYSTEM_PROMPT
));
}

log::info!("Sending file to pdf2md");
let pdf2md_response = pdf2md_client
.post(format!("{}/api/task", pdf2md_url))
@@ -356,16 +447,6 @@
}
};

let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
let created_file = create_file_query(
file_id,
file_size_mb,
file_worker_message.upload_file_data.clone(),
file_worker_message.dataset_id,
web_pool.clone(),
)
.await?;

log::info!("Waiting on Task {}", task_id);
let mut processed_pages = std::collections::HashSet::new();
let mut pagination_token: Option<u32> = None;
@@ -481,6 +562,7 @@
file_worker_message.upload_file_data.clone(),
new_chunks.clone(),
dataset_org_plan_sub.clone(),
group_id,
web_pool.clone(),
event_queue.clone(),
redis_conn.clone(),
@@ -545,17 +627,6 @@
));
}

let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;

let created_file = create_file_query(
file_id,
file_size_mb,
file_worker_message.upload_file_data.clone(),
file_worker_message.dataset_id,
web_pool.clone(),
)
.await?;

if file_worker_message
.upload_file_data
.create_chunks
@@ -611,6 +682,7 @@
file_worker_message.upload_file_data,
chunks,
dataset_org_plan_sub,
group_id,
web_pool.clone(),
event_queue.clone(),
redis_conn,
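The worker appends the heading-chunking instructions to any user-supplied system prompt rather than replacing it; a runnable distillation of that merge (prompt strings abbreviated, flag value illustrative):

```rust
use serde_json::json;

// Abbreviated stand-in for the constant defined at the top of file-worker.rs.
const HEADING_CHUNKING_SYSTEM_PROMPT: &str =
    "Analyze this PDF page and restructure it into clear markdown sections ...";

fn main() {
    // Mirrors the merge in upload_file: a user-supplied system_prompt is
    // kept, and the heading-chunking instructions are appended after it.
    let mut json_value = json!({ "system_prompt": "Focus on tables." });
    let split_headings = true; // pdf2md_options.split_headings.unwrap_or(false)

    if split_headings {
        json_value["system_prompt"] = json!(format!(
            "{}\n\n{}",
            json_value["system_prompt"].as_str().unwrap_or(""),
            HEADING_CHUNKING_SYSTEM_PROMPT
        ));
    }
    println!("{}", json_value["system_prompt"]);
}
```

If no user prompt was set, `as_str()` yields `None` and the heading instructions simply follow an empty string.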
1 change: 1 addition & 0 deletions server/src/lib.rs
@@ -357,6 +357,7 @@ impl Modify for SecurityAddon {
handlers::file_handler::CreatePresignedUrlForCsvJsonlReqPayload,
handlers::file_handler::CreatePresignedUrlForCsvJsonResponseBody,
handlers::file_handler::UploadHtmlPageReqPayload,
handlers::file_handler::Pdf2MdOptions,
handlers::invitation_handler::InvitationData,
handlers::event_handler::GetEventsData,
handlers::organization_handler::CreateOrganizationReqPayload,
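The `Pdf2MdOptions` definition itself is outside this diff; the following is a plausible shape inferred from how the frontend and worker use it, where the derives, doc comments, and exact optionality are assumptions:

```rust
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

// Inferred from usage in UploadFile.tsx and file-worker.rs; the actual
// definition lives in handlers::file_handler and may differ.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Pdf2MdOptions {
    /// Route the PDF through the pdf2md OCR pipeline.
    pub use_pdf2md_ocr: bool,
    /// Split chunks on the markdown headings pdf2md produces.
    pub split_headings: Option<bool>,
    /// Optional extra system prompt forwarded to pdf2md.
    pub system_prompt: Option<String>,
}
```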