22 changes: 19 additions & 3 deletions backend/apps/ai/common/base/chunk_command.py
@@ -20,7 +20,7 @@ def help(self) -> str:
def process_chunks_batch(self, entities: list[Model]) -> int:
"""Process a batch of entities to create or update chunks."""
processed = 0
batch_chunks_to_create = []
batch_chunks_to_create = {}
content_type = ContentType.objects.get_for_model(self.model_class)

for entity in entities:
@@ -68,7 +68,9 @@ def process_chunks_batch(self, entities: list[Model]) -> int:
openai_client=self.openai_client,
save=False,
):
batch_chunks_to_create.extend(chunks)
for chunk in chunks:
key = (chunk.context_id, chunk.text)
batch_chunks_to_create[key] = chunk
processed += 1
self.stdout.write(
self.style.SUCCESS(f"Created {len(chunks)} new chunks for {entity_key}")
@@ -77,7 +79,21 @@ def process_chunks_batch(self, entities: list[Model]) -> int:
self.stdout.write(f"Chunks for {entity_key} are already up to date.")

if batch_chunks_to_create:
Chunk.bulk_save(batch_chunks_to_create)
context_ids = {context_id for context_id, _ in batch_chunks_to_create}
Collaborator:
Any reason not to use (again) our backend-wide approach for handling model data -- model.update_data + model.bulk_save -- instead?

Collaborator (Author):
I'm using bulk save; the integrity error was due to duplication. I had two options: one was to filter duplicates with this check, and the other was to set ignore_conflicts=True in the bulk save.

Collaborator:
You already have a check in the Chunk::update_data method. Why isn't it enough, and why can't this be handled there?

Collaborator (Author) -- @Dishant1804 Dishant1804, Oct 3, 2025:
The update_data method does check for duplicates, but the integrity error is caused by a race condition: while updating, we delete the chunks that are to be updated.

So to avoid the race condition we filter the duplicates here. Setting ignore_conflicts=True in bulk_save was an option too -- I read about it on Stack Overflow -- but I don't think that is viable. What do you think?
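
For reference, the ignore_conflicts route mentioned above would look roughly like the sketch below. This is a minimal illustration, not project code: it assumes Chunk carries a unique constraint on (context, text) and that bulk_save ultimately delegates to Django's bulk_create; the helper name is made up for the example.

# Sketch only -- assumes a unique constraint on (context, text); helper name is illustrative.
from apps.ai.models.chunk import Chunk


def bulk_save_ignoring_conflicts(chunks: list[Chunk]) -> None:
    """Insert chunks, silently skipping rows that would violate the unique constraint."""
    # ignore_conflicts=True avoids the IntegrityError from the delete/recreate race,
    # but it also swallows any other conflict and does not set primary keys on the
    # inserted instances.
    Chunk.objects.bulk_create(chunks, ignore_conflicts=True)

The trade-off is that conflicts are dropped silently, which is presumably why the explicit (context_id, text) filter was preferred in this PR.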

candidate_chunk_texts = {text for _, text in batch_chunks_to_create}

existing_keys = set(
Chunk.objects.filter(
context_id__in=context_ids, text__in=candidate_chunk_texts
).values_list("context_id", "text")
)

chunks_to_insert = [
chunk for key, chunk in batch_chunks_to_create.items() if key not in existing_keys
]

if chunks_to_insert:
Chunk.bulk_save(chunks_to_insert)

return processed

6 changes: 4 additions & 2 deletions backend/apps/ai/common/base/context_command.py
@@ -35,10 +35,12 @@ def process_context_batch(self, entities: list[Model]) -> int:
):
processed += 1
entity_key = self.get_entity_key(entity)
self.stdout.write(f"Created context for {entity_key}")
self.stdout.write(f"Created/updated context for {entity_key}")
else:
entity_key = self.get_entity_key(entity)
Collaborator:
This line is repetitive.

Collaborator (Author) -- @Dishant1804 Dishant1804, Oct 3, 2025:
Fixed.

self.stdout.write(self.style.ERROR(f"Failed to create context for {entity_key}"))
self.stdout.write(
self.style.ERROR(f"Failed to create/update context for {entity_key}")
)

return processed

@@ -10,26 +10,6 @@ class Command(BaseContextCommand):
key_field_name = "slack_message_id"
model_class = Message

def add_arguments(self, parser):
"""Override to use different default batch size for messages."""
super().add_arguments(parser)
parser.add_argument(
"--message-key",
type=str,
help="Process only the message with this key",
)
parser.add_argument(
"--all",
action="store_true",
help="Process all the messages",
)
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Number of messages to process in each batch",
)

def extract_content(self, entity: Message) -> tuple[str, str]:
"""Extract content from the message."""
return entity.cleaned_text or "", ""
76 changes: 70 additions & 6 deletions backend/tests/apps/ai/common/base/chunk_command_test.py
@@ -211,7 +211,13 @@ def test_process_chunks_batch_success(
mock_create_chunks.return_value = mock_chunks
command.openai_client = Mock()

with patch.object(command.stdout, "write") as mock_write:
with (
patch("apps.ai.models.chunk.Chunk.objects.filter") as mock_chunk_filter,
patch.object(command.stdout, "write") as mock_write,
):
mock_qs = Mock()
mock_qs.values_list.return_value = []
mock_chunk_filter.return_value = mock_qs
result = command.process_chunks_batch([mock_entity])

assert result == 1
@@ -261,14 +267,20 @@ def test_process_chunks_batch_multiple_entities(
mock_create_chunks.return_value = mock_chunks[:2]
command.openai_client = Mock()

with patch.object(command.stdout, "write"):
with (
patch("apps.ai.models.chunk.Chunk.objects.filter") as mock_chunk_filter,
patch.object(command.stdout, "write"),
):
mock_qs = Mock()
mock_qs.values_list.return_value = []
mock_chunk_filter.return_value = mock_qs
result = command.process_chunks_batch(entities)

assert result == 3
assert mock_create_chunks.call_count == 3
mock_bulk_save.assert_called_once()
bulk_save_args = mock_bulk_save.call_args[0][0]
assert len(bulk_save_args) == 6
assert len(bulk_save_args) == 2

@patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model")
@patch("apps.ai.common.base.chunk_command.Context.objects.filter")
@@ -325,14 +337,22 @@ def test_process_chunks_batch_content_combination(
"extract_content",
return_value=("prose", "metadata"),
):
command.process_chunks_batch([mock_entity])
with patch("apps.ai.models.chunk.Chunk.objects.filter") as mock_chunk_filter:
mock_qs = Mock()
mock_qs.values_list.return_value = []
mock_chunk_filter.return_value = mock_qs
command.process_chunks_batch([mock_entity])

expected_content = "metadata\n\nprose"
mock_split_text.assert_called_once_with(expected_content)

mock_split_text.reset_mock()
with patch.object(command, "extract_content", return_value=("prose", "")):
command.process_chunks_batch([mock_entity])
with patch("apps.ai.models.chunk.Chunk.objects.filter") as mock_chunk_filter:
mock_qs = Mock()
mock_qs.values_list.return_value = []
mock_chunk_filter.return_value = mock_qs
command.process_chunks_batch([mock_entity])

mock_split_text.assert_called_with("prose")

@@ -402,11 +422,55 @@ def test_process_chunks_batch_metadata_only_content(
"extract_content",
return_value=("", "metadata"),
):
command.process_chunks_batch([mock_entity])
with patch("apps.ai.models.chunk.Chunk.objects.filter") as mock_chunk_filter:
mock_qs = Mock()
mock_qs.values_list.return_value = []
mock_chunk_filter.return_value = mock_qs
command.process_chunks_batch([mock_entity])

mock_split_text.assert_called_once_with("metadata\n\n")
mock_bulk_save.assert_called_once()

@patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model")
@patch("apps.ai.common.base.chunk_command.Context.objects.filter")
@patch("apps.ai.models.chunk.Chunk.split_text")
@patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings")
@patch("apps.ai.models.chunk.Chunk.bulk_save")
def test_process_chunks_batch_with_duplicates(
self,
mock_bulk_save,
mock_create_chunks,
mock_split_text,
mock_context_filter,
mock_get_content_type,
command,
mock_entity,
mock_context,
mock_content_type,
mock_chunks,
):
"""Test that duplicate chunks are filtered out before bulk save."""
mock_get_content_type.return_value = mock_content_type
mock_context_filter.return_value.first.return_value = mock_context
mock_split_text.return_value = ["chunk1", "chunk2", "chunk3"]
mock_create_chunks.return_value = mock_chunks
command.openai_client = Mock()

with (
patch("apps.ai.models.chunk.Chunk.objects.filter") as mock_chunk_filter,
patch.object(command.stdout, "write"),
):
mock_qs = Mock()
mock_qs.values_list.return_value = [(1, "Chunk text 1")]
mock_chunk_filter.return_value = mock_qs

result = command.process_chunks_batch([mock_entity])

assert result == 1
mock_bulk_save.assert_called_once()
bulk_save_args = mock_bulk_save.call_args[0][0]
assert len(bulk_save_args) == 2

def test_process_chunks_batch_whitespace_only_content(
self, command, mock_entity, mock_context, mock_content_type
):
12 changes: 6 additions & 6 deletions backend/tests/apps/ai/common/base/context_command_test.py
@@ -116,7 +116,7 @@ def test_process_context_batch_success(
entity=mock_entity,
source="owasp_test_entity",
)
mock_write.assert_called_once_with("Created context for test-key-123")
mock_write.assert_called_once_with("Created/updated context for test-key-123")

@patch("apps.ai.common.base.context_command.Context")
def test_process_context_batch_creation_fails(self, mock_context_class, command, mock_entity):
@@ -130,7 +130,7 @@ def test_process_context_batch_creation_fails(self, mock_context_class, command,
mock_context_class.update_data.assert_called_once()
mock_write.assert_called_once()
call_args = mock_write.call_args[0][0]
assert "Failed to create context for test-key-123" in str(call_args)
assert "Failed to create/update context for test-key-123" in str(call_args)

@patch("apps.ai.common.base.context_command.Context")
def test_process_context_batch_multiple_entities(
@@ -184,9 +184,9 @@ def test_process_context_batch_mixed_success_failure(
assert mock_write.call_count == 3

write_calls = mock_write.call_args_list
assert "Created context for test-key-1" in str(write_calls[0])
assert "Failed to create context for test-key-2" in str(write_calls[1])
assert "Created context for test-key-3" in str(write_calls[2])
assert "Created/updated context for test-key-1" in str(write_calls[0])
assert "Failed to create/update context for test-key-2" in str(write_calls[1])
assert "Created/updated context for test-key-3" in str(write_calls[2])

def test_process_context_batch_content_combination(self, command, mock_entity, mock_context):
"""Test that metadata and prose content are properly combined."""
@@ -261,7 +261,7 @@ def test_get_entity_key_usage(self, command, mock_context):
with patch.object(command.stdout, "write") as mock_write:
command.process_context_batch([entity])

mock_write.assert_called_once_with("Created context for custom-entity-key")
mock_write.assert_called_once_with("Created/updated context for custom-entity-key")

def test_process_context_batch_empty_list(self, command):
"""Test process_context_batch with empty entity list."""
@@ -169,7 +169,9 @@ def test_process_context_batch_success(self, command, mock_committee):
entity=mock_committee,
source="owasp_committee",
)
mock_write.assert_called_once_with("Created context for test-committee")
mock_write.assert_called_once_with(
"Created/updated context for test-committee"
)

def test_process_context_batch_empty_content(self, command, mock_committee):
"""Test context batch processing with empty content."""
@@ -206,7 +208,7 @@ def test_process_context_batch_create_failure(self, command, mock_committee):

assert result == 0
mock_error.assert_called_once_with(
"Failed to create context for test-committee"
"Failed to create/update context for test-committee"
)
mock_write.assert_called_once_with("ERROR: Failed")

@@ -72,10 +72,9 @@ def test_add_arguments(self, command):
parser = Mock()
command.add_arguments(parser)

assert parser.add_argument.call_count == 6
assert parser.add_argument.call_count == 3
calls = parser.add_argument.call_args_list

# First 3 calls are from parent class (BaseAICommand)
assert calls[0][0] == ("--message-key",)
assert calls[0][1]["type"] is str
assert "Process only the message with this key" in calls[0][1]["help"]
Expand All @@ -86,19 +85,5 @@ def test_add_arguments(self, command):

assert calls[2][0] == ("--batch-size",)
assert calls[2][1]["type"] is int
assert calls[2][1]["default"] == 50 # Default from parent class
assert calls[2][1]["default"] == 50
assert "Number of messages to process in each batch" in calls[2][1]["help"]

# Next 3 calls are from the command itself (duplicates with different defaults)
assert calls[3][0] == ("--message-key",)
assert calls[3][1]["type"] is str
assert "Process only the message with this key" in calls[3][1]["help"]

assert calls[4][0] == ("--all",)
assert calls[4][1]["action"] == "store_true"
assert "Process all the messages" in calls[4][1]["help"]

assert calls[5][0] == ("--batch-size",)
assert calls[5][1]["type"] is int
assert calls[5][1]["default"] == 100 # Overridden default from command
assert "Number of messages to process in each batch" in calls[5][1]["help"]