Skip to content

Commit 18d6f8b

Browse files
NestBot AI Assistant Contexts (#1891)
* rag tool for agent * code rabbit suggestions implemented * suggestions implemented * code rabbit suggestion * added context model * retrieving data from context model * removed try except * Suggestions implemented * code rabbit suggestion * removed default * updated tests * decoupled context and chunks * update method for context * major revamp and test cases * code rabbit suggestions * major revamp * suggestions implemented * refactoring * more tests * more refactoring * suggestions implemented * chunk model update * update logic and suggestions * code rabbit suggestions * before tests and question * suggestions and decoupling with tests * suggestions implemented * Update code * updated code * spelling fixes * test changes * Update tests --------- Co-authored-by: Arkadii Yakovets <[email protected]> Co-authored-by: Arkadii Yakovets <[email protected]>
1 parent 7d81085 commit 18d6f8b

File tree

68 files changed

+6295
-973
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+6295
-973
lines changed

backend/apps/ai/Makefile

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,43 @@
1-
ai-create-chapter-chunks:
2-
@echo "Creating chapter chunks"
3-
@CMD="python manage.py ai_create_chapter_chunks" $(MAKE) exec-backend-command
1+
ai-run-rag-tool:
2+
@echo "Running RAG tool"
3+
@CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command
44

5-
ai-create-committee-chunks:
6-
@echo "Creating committee chunks"
7-
@CMD="python manage.py ai_create_committee_chunks" $(MAKE) exec-backend-command
5+
ai-update-chapter-chunks:
6+
@echo "Updating chapter chunks"
7+
@CMD="python manage.py ai_update_chapter_chunks" $(MAKE) exec-backend-command
88

9-
ai-create-event-chunks:
10-
@echo "Creating event chunks"
11-
@CMD="python manage.py ai_create_event_chunks" $(MAKE) exec-backend-command
9+
ai-update-chapter-context:
10+
@echo "Updating chapter context"
11+
@CMD="python manage.py ai_update_chapter_context" $(MAKE) exec-backend-command
1212

13-
ai-create-project-chunks:
14-
@echo "Creating project chunks"
15-
@CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command
13+
ai-update-committee-chunks:
14+
@echo "Updating committee chunks"
15+
@CMD="python manage.py ai_update_committee_chunks" $(MAKE) exec-backend-command
1616

17-
ai-create-slack-message-chunks:
18-
@echo "Creating Slack message chunks"
19-
@CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command
17+
ai-update-committee-context:
18+
@echo "Updating committee context"
19+
@CMD="python manage.py ai_update_committee_context" $(MAKE) exec-backend-command
2020

21-
ai-run-rag-tool:
22-
@echo "Running RAG tool"
23-
@CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command
21+
ai-update-event-chunks:
22+
@echo "Updating event chunks"
23+
@CMD="python manage.py ai_update_event_chunks" $(MAKE) exec-backend-command
24+
25+
ai-update-event-context:
26+
@echo "Updating event context"
27+
@CMD="python manage.py ai_update_event_context" $(MAKE) exec-backend-command
28+
29+
ai-update-project-chunks:
30+
@echo "Updating project chunks"
31+
@CMD="python manage.py ai_update_project_chunks" $(MAKE) exec-backend-command
32+
33+
ai-update-project-context:
34+
@echo "Updating project context"
35+
@CMD="python manage.py ai_update_project_context" $(MAKE) exec-backend-command
36+
37+
ai-update-slack-message-chunks:
38+
@echo "Updating Slack message chunks"
39+
@CMD="python manage.py ai_update_slack_message_chunks" $(MAKE) exec-backend-command
40+
41+
ai-update-slack-message-context:
42+
@echo "Updating Slack message context"
43+
@CMD="python manage.py ai_update_slack_message_context" $(MAKE) exec-backend-command

backend/apps/ai/admin.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.contrib import admin
44

55
from apps.ai.models.chunk import Chunk
6+
from apps.ai.models.context import Context
67

78

89
class ChunkAdmin(admin.ModelAdmin):
@@ -11,9 +12,25 @@ class ChunkAdmin(admin.ModelAdmin):
1112
list_display = (
1213
"id",
1314
"text",
14-
"content_type",
15+
"context",
1516
)
16-
search_fields = ("text", "object_id")
17+
list_filter = ("context__entity_type",)
18+
search_fields = ("text",)
19+
20+
21+
class ContextAdmin(admin.ModelAdmin):
22+
"""Admin for Context model."""
23+
24+
list_display = (
25+
"id",
26+
"content",
27+
"entity_type",
28+
"entity_id",
29+
"source",
30+
)
31+
list_filter = ("entity_type", "source")
32+
search_fields = ("content", "source")
1733

1834

1935
admin.site.register(Chunk, ChunkAdmin)
36+
admin.site.register(Context, ContextAdmin)

backend/apps/ai/agent/tools/rag/rag_tool.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,8 @@ def __init__(
2828
ValueError: If the OpenAI API key is not set.
2929
3030
"""
31-
try:
32-
self.retriever = Retriever(embedding_model=embedding_model)
33-
self.generator = Generator(chat_model=chat_model)
34-
except Exception:
35-
logger.exception("Failed to initialize RAG tool")
36-
raise
31+
self.retriever = Retriever(embedding_model=embedding_model)
32+
self.generator = Generator(chat_model=chat_model)
3733

3834
def query(
3935
self,

backend/apps/ai/agent/tools/rag/retriever.py

Lines changed: 83 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
class Retriever:
2222
"""A class for retrieving relevant text chunks for a RAG."""
2323

24-
SUPPORTED_CONTENT_TYPES = ["event", "project", "chapter", "committee", "message"]
24+
SUPPORTED_ENTITY_TYPES = (
25+
"chapter",
26+
"committee",
27+
"event",
28+
"message",
29+
"project",
30+
)
2531

2632
def __init__(self, embedding_model: str = "text-embedding-3-small"):
2733
"""Initialize the Retriever.
@@ -36,7 +42,6 @@ def __init__(self, embedding_model: str = "text-embedding-3-small"):
3642
if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")):
3743
error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set"
3844
raise ValueError(error_msg)
39-
4045
self.openai_client = openai.OpenAI(api_key=openai_api_key)
4146
self.embedding_model = embedding_model
4247
logger.info("Retriever initialized with embedding model: %s", self.embedding_model)
@@ -64,121 +69,116 @@ def get_query_embedding(self, query: str) -> list[float]:
6469
logger.exception("Unexpected error while generating embedding")
6570
raise
6671

67-
def get_source_name(self, content_object) -> str:
72+
def get_source_name(self, entity) -> str:
6873
"""Get the name/identifier for the content object."""
6974
for attr in ("name", "title", "login", "key", "summary"):
70-
if getattr(content_object, attr, None):
71-
return str(getattr(content_object, attr))
72-
73-
return str(content_object)
75+
if getattr(entity, attr, None):
76+
return str(getattr(entity, attr))
77+
return str(entity)
7478

75-
def get_additional_context(self, content_object, content_type: str) -> dict[str, Any]:
79+
def get_additional_context(self, entity) -> dict[str, Any]:
7680
"""Get additional context information based on content type.
7781
7882
Args:
79-
content_object: The source object.
80-
content_type: The model name of the content object.
83+
entity: The source object.
8184
8285
Returns:
8386
A dictionary with additional context information.
8487
8588
"""
8689
context = {}
87-
clean_content_type = content_type.split(".")[-1] if "." in content_type else content_type
88-
90+
clean_content_type = entity.__class__.__name__.lower()
8991
if clean_content_type == "chapter":
9092
context.update(
9193
{
92-
"location": getattr(content_object, "suggested_location", None),
93-
"region": getattr(content_object, "region", None),
94-
"country": getattr(content_object, "country", None),
95-
"postal_code": getattr(content_object, "postal_code", None),
96-
"currency": getattr(content_object, "currency", None),
97-
"meetup_group": getattr(content_object, "meetup_group", None),
98-
"tags": getattr(content_object, "tags", []),
99-
"topics": getattr(content_object, "topics", []),
100-
"leaders": getattr(content_object, "leaders_raw", []),
101-
"related_urls": getattr(content_object, "related_urls", []),
102-
"is_active": getattr(content_object, "is_active", None),
103-
"url": getattr(content_object, "url", None),
94+
"location": getattr(entity, "suggested_location", None),
95+
"region": getattr(entity, "region", None),
96+
"country": getattr(entity, "country", None),
97+
"postal_code": getattr(entity, "postal_code", None),
98+
"currency": getattr(entity, "currency", None),
99+
"meetup_group": getattr(entity, "meetup_group", None),
100+
"tags": getattr(entity, "tags", []),
101+
"topics": getattr(entity, "topics", []),
102+
"leaders": getattr(entity, "leaders_raw", []),
103+
"related_urls": getattr(entity, "related_urls", []),
104+
"is_active": getattr(entity, "is_active", None),
105+
"url": getattr(entity, "url", None),
104106
}
105107
)
106108
elif clean_content_type == "project":
107109
context.update(
108110
{
109-
"level": getattr(content_object, "level", None),
110-
"project_type": getattr(content_object, "type", None),
111-
"languages": getattr(content_object, "languages", []),
112-
"topics": getattr(content_object, "topics", []),
113-
"licenses": getattr(content_object, "licenses", []),
114-
"tags": getattr(content_object, "tags", []),
115-
"custom_tags": getattr(content_object, "custom_tags", []),
116-
"stars_count": getattr(content_object, "stars_count", None),
117-
"forks_count": getattr(content_object, "forks_count", None),
118-
"contributors_count": getattr(content_object, "contributors_count", None),
119-
"releases_count": getattr(content_object, "releases_count", None),
120-
"open_issues_count": getattr(content_object, "open_issues_count", None),
121-
"leaders": getattr(content_object, "leaders_raw", []),
122-
"related_urls": getattr(content_object, "related_urls", []),
123-
"created_at": getattr(content_object, "created_at", None),
124-
"updated_at": getattr(content_object, "updated_at", None),
125-
"released_at": getattr(content_object, "released_at", None),
126-
"health_score": getattr(content_object, "health_score", None),
127-
"is_active": getattr(content_object, "is_active", None),
128-
"track_issues": getattr(content_object, "track_issues", None),
129-
"url": getattr(content_object, "url", None),
111+
"level": getattr(entity, "level", None),
112+
"project_type": getattr(entity, "type", None),
113+
"languages": getattr(entity, "languages", []),
114+
"topics": getattr(entity, "topics", []),
115+
"licenses": getattr(entity, "licenses", []),
116+
"tags": getattr(entity, "tags", []),
117+
"custom_tags": getattr(entity, "custom_tags", []),
118+
"stars_count": getattr(entity, "stars_count", None),
119+
"forks_count": getattr(entity, "forks_count", None),
120+
"contributors_count": getattr(entity, "contributors_count", None),
121+
"releases_count": getattr(entity, "releases_count", None),
122+
"open_issues_count": getattr(entity, "open_issues_count", None),
123+
"leaders": getattr(entity, "leaders_raw", []),
124+
"related_urls": getattr(entity, "related_urls", []),
125+
"created_at": getattr(entity, "created_at", None),
126+
"updated_at": getattr(entity, "updated_at", None),
127+
"released_at": getattr(entity, "released_at", None),
128+
"health_score": getattr(entity, "health_score", None),
129+
"is_active": getattr(entity, "is_active", None),
130+
"track_issues": getattr(entity, "track_issues", None),
131+
"url": getattr(entity, "url", None),
130132
}
131133
)
132134
elif clean_content_type == "event":
133135
context.update(
134136
{
135-
"start_date": getattr(content_object, "start_date", None),
136-
"end_date": getattr(content_object, "end_date", None),
137-
"location": getattr(content_object, "suggested_location", None),
138-
"category": getattr(content_object, "category", None),
139-
"latitude": getattr(content_object, "latitude", None),
140-
"longitude": getattr(content_object, "longitude", None),
141-
"url": getattr(content_object, "url", None),
142-
"description": getattr(content_object, "description", None),
143-
"summary": getattr(content_object, "summary", None),
137+
"start_date": getattr(entity, "start_date", None),
138+
"end_date": getattr(entity, "end_date", None),
139+
"location": getattr(entity, "suggested_location", None),
140+
"category": getattr(entity, "category", None),
141+
"latitude": getattr(entity, "latitude", None),
142+
"longitude": getattr(entity, "longitude", None),
143+
"url": getattr(entity, "url", None),
144+
"description": getattr(entity, "description", None),
145+
"summary": getattr(entity, "summary", None),
144146
}
145147
)
146148
elif clean_content_type == "committee":
147149
context.update(
148150
{
149-
"is_active": getattr(content_object, "is_active", None),
150-
"leaders": getattr(content_object, "leaders", []),
151-
"url": getattr(content_object, "url", None),
152-
"description": getattr(content_object, "description", None),
153-
"summary": getattr(content_object, "summary", None),
154-
"tags": getattr(content_object, "tags", []),
155-
"topics": getattr(content_object, "topics", []),
156-
"related_urls": getattr(content_object, "related_urls", []),
151+
"is_active": getattr(entity, "is_active", None),
152+
"leaders": getattr(entity, "leaders", []),
153+
"url": getattr(entity, "url", None),
154+
"description": getattr(entity, "description", None),
155+
"summary": getattr(entity, "summary", None),
156+
"tags": getattr(entity, "tags", []),
157+
"topics": getattr(entity, "topics", []),
158+
"related_urls": getattr(entity, "related_urls", []),
157159
}
158160
)
159161
elif clean_content_type == "message":
160162
context.update(
161163
{
162164
"channel": (
163-
getattr(content_object.conversation, "slack_channel_id", None)
164-
if hasattr(content_object, "conversation") and content_object.conversation
165+
getattr(entity.conversation, "slack_channel_id", None)
166+
if hasattr(entity, "conversation") and entity.conversation
165167
else None
166168
),
167169
"thread_ts": (
168-
getattr(content_object.parent_message, "ts", None)
169-
if hasattr(content_object, "parent_message")
170-
and content_object.parent_message
170+
getattr(entity.parent_message, "ts", None)
171+
if hasattr(entity, "parent_message") and entity.parent_message
171172
else None
172173
),
173-
"ts": getattr(content_object, "ts", None),
174+
"ts": getattr(entity, "ts", None),
174175
"user": (
175-
getattr(content_object.author, "name", None)
176-
if hasattr(content_object, "author") and content_object.author
176+
getattr(entity.author, "name", None)
177+
if hasattr(entity, "author") and entity.author
177178
else None
178179
),
179180
}
180181
)
181-
182182
return {k: v for k, v in context.items() if v is not None}
183183

184184
def retrieve(
@@ -201,51 +201,43 @@ def retrieve(
201201
202202
"""
203203
query_embedding = self.get_query_embedding(query)
204-
205204
if not content_types:
206205
content_types = self.extract_content_types_from_query(query)
207-
208206
queryset = Chunk.objects.annotate(
209207
similarity=1 - CosineDistance("embedding", query_embedding)
210208
).filter(similarity__gte=similarity_threshold)
211-
212209
if content_types:
213210
content_type_query = Q()
214211
for name in content_types:
215212
lower_name = name.lower()
216213
if "." in lower_name:
217214
app_label, model = lower_name.split(".", 1)
218215
content_type_query |= Q(
219-
content_type__app_label=app_label, content_type__model=model
216+
context__entity_type__app_label=app_label,
217+
context__entity_type__model=model,
220218
)
221219
else:
222-
content_type_query |= Q(content_type__model=lower_name)
220+
content_type_query |= Q(context__entity_type__model=lower_name)
223221
queryset = queryset.filter(content_type_query)
224222

225-
chunks = (
226-
queryset.select_related("content_type")
227-
.prefetch_related("content_object")
228-
.order_by("-similarity")[:limit]
229-
)
223+
chunks = queryset.select_related("context__entity_type").order_by("-similarity")[:limit]
230224

231225
results = []
232226
for chunk in chunks:
233-
if not chunk.content_object:
227+
if not chunk.context or not chunk.context.entity:
234228
logger.warning("Content object is None for chunk %s. Skipping.", chunk.id)
235229
continue
236230

237-
source_name = self.get_source_name(chunk.content_object)
238-
additional_context = self.get_additional_context(
239-
chunk.content_object, chunk.content_type.model
240-
)
231+
source_name = self.get_source_name(chunk.context.entity)
232+
additional_context = self.get_additional_context(chunk.context.entity)
241233

242234
results.append(
243235
{
244236
"text": chunk.text,
245237
"similarity": float(chunk.similarity),
246-
"source_type": chunk.content_type.model,
238+
"source_type": chunk.context.entity_type.model,
247239
"source_name": source_name,
248-
"source_id": chunk.object_id,
240+
"source_id": chunk.context.entity_id,
249241
"additional_context": additional_context,
250242
}
251243
)
@@ -262,13 +254,12 @@ def extract_content_types_from_query(self, query: str) -> list[str]:
262254
A list of detected content type names.
263255
264256
"""
265-
detected_types = []
266257
query_words = set(re.findall(r"\b\w+\b", query.lower()))
267258

268259
detected_types = [
269-
content_type
270-
for content_type in self.SUPPORTED_CONTENT_TYPES
271-
if content_type in query_words or f"{content_type}s" in query_words
260+
entity_type
261+
for entity_type in self.SUPPORTED_ENTITY_TYPES
262+
if entity_type in query_words or f"{entity_type}s" in query_words
272263
]
273264

274265
if detected_types:

backend/apps/ai/common/base/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)