2121class Retriever :
2222 """A class for retrieving relevant text chunks for a RAG."""
2323
24- SUPPORTED_CONTENT_TYPES = ["event" , "project" , "chapter" , "committee" , "message" ]
24+ SUPPORTED_ENTITY_TYPES = (
25+ "chapter" ,
26+ "committee" ,
27+ "event" ,
28+ "message" ,
29+ "project" ,
30+ )
2531
2632 def __init__ (self , embedding_model : str = "text-embedding-3-small" ):
2733 """Initialize the Retriever.
@@ -36,7 +42,6 @@ def __init__(self, embedding_model: str = "text-embedding-3-small"):
3642 if not (openai_api_key := os .getenv ("DJANGO_OPEN_AI_SECRET_KEY" )):
3743 error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set"
3844 raise ValueError (error_msg )
39-
4045 self .openai_client = openai .OpenAI (api_key = openai_api_key )
4146 self .embedding_model = embedding_model
4247 logger .info ("Retriever initialized with embedding model: %s" , self .embedding_model )
@@ -64,121 +69,116 @@ def get_query_embedding(self, query: str) -> list[float]:
6469 logger .exception ("Unexpected error while generating embedding" )
6570 raise
6671
def get_source_name(self, entity) -> str:
    """Get a human-readable name/identifier for the source entity.

    Args:
        entity: The source object.

    Returns:
        The string form of the first truthy attribute among ``name``,
        ``title``, ``login``, ``key`` and ``summary`` (checked in that
        order), or ``str(entity)`` when none of them is set.

    """
    for attr in ("name", "title", "login", "key", "summary"):
        # Read the attribute once; missing and falsy values are both skipped.
        if value := getattr(entity, attr, None):
            return str(value)
    return str(entity)
7478
def get_additional_context(self, entity) -> dict[str, Any]:
    """Get additional context information based on the entity's type.

    The entity type is the lowercased class name of ``entity``. The types
    handled here are ``chapter``, ``project``, ``event``, ``committee`` and
    ``message``; any other type yields an empty dictionary.

    Args:
        entity: The source object.

    Returns:
        A dictionary with additional context information. Entries whose
        value is ``None`` are dropped; empty lists are kept.

    """

    def read(attr: str, default: Any = None) -> Any:
        """Return ``entity.<attr>``, or ``default`` if the attribute is missing."""
        return getattr(entity, attr, default)

    def read_related(relation: str, attr: str) -> Any:
        """Return ``entity.<relation>.<attr>``, or None if either part is unset."""
        related = getattr(entity, relation, None)
        return getattr(related, attr, None) if related else None

    entity_type = entity.__class__.__name__.lower()
    context: dict[str, Any] = {}

    if entity_type == "chapter":
        context = {
            "location": read("suggested_location"),
            "region": read("region"),
            "country": read("country"),
            "postal_code": read("postal_code"),
            "currency": read("currency"),
            "meetup_group": read("meetup_group"),
            "tags": read("tags", []),
            "topics": read("topics", []),
            "leaders": read("leaders_raw", []),
            "related_urls": read("related_urls", []),
            "is_active": read("is_active"),
            "url": read("url"),
        }
    elif entity_type == "project":
        context = {
            "level": read("level"),
            "project_type": read("type"),
            "languages": read("languages", []),
            "topics": read("topics", []),
            "licenses": read("licenses", []),
            "tags": read("tags", []),
            "custom_tags": read("custom_tags", []),
            "stars_count": read("stars_count"),
            "forks_count": read("forks_count"),
            "contributors_count": read("contributors_count"),
            "releases_count": read("releases_count"),
            "open_issues_count": read("open_issues_count"),
            "leaders": read("leaders_raw", []),
            "related_urls": read("related_urls", []),
            "created_at": read("created_at"),
            "updated_at": read("updated_at"),
            "released_at": read("released_at"),
            "health_score": read("health_score"),
            "is_active": read("is_active"),
            "track_issues": read("track_issues"),
            "url": read("url"),
        }
    elif entity_type == "event":
        context = {
            "start_date": read("start_date"),
            "end_date": read("end_date"),
            "location": read("suggested_location"),
            "category": read("category"),
            "latitude": read("latitude"),
            "longitude": read("longitude"),
            "url": read("url"),
            "description": read("description"),
            "summary": read("summary"),
        }
    elif entity_type == "committee":
        context = {
            "is_active": read("is_active"),
            # Unlike chapters and projects, committees read ``leaders`` directly
            # rather than ``leaders_raw``.
            "leaders": read("leaders", []),
            "url": read("url"),
            "description": read("description"),
            "summary": read("summary"),
            "tags": read("tags", []),
            "topics": read("topics", []),
            "related_urls": read("related_urls", []),
        }
    elif entity_type == "message":
        context = {
            "channel": read_related("conversation", "slack_channel_id"),
            "thread_ts": read_related("parent_message", "ts"),
            "ts": read("ts"),
            "user": read_related("author", "name"),
        }

    # Only None is filtered out; falsy values such as [] or False are kept.
    return {key: value for key, value in context.items() if value is not None}
183183
184184 def retrieve (
@@ -201,51 +201,43 @@ def retrieve(
201201
202202 """
203203 query_embedding = self .get_query_embedding (query )
204-
205204 if not content_types :
206205 content_types = self .extract_content_types_from_query (query )
207-
208206 queryset = Chunk .objects .annotate (
209207 similarity = 1 - CosineDistance ("embedding" , query_embedding )
210208 ).filter (similarity__gte = similarity_threshold )
211-
212209 if content_types :
213210 content_type_query = Q ()
214211 for name in content_types :
215212 lower_name = name .lower ()
216213 if "." in lower_name :
217214 app_label , model = lower_name .split ("." , 1 )
218215 content_type_query |= Q (
219- content_type__app_label = app_label , content_type__model = model
216+ context__entity_type__app_label = app_label ,
217+ context__entity_type__model = model ,
220218 )
221219 else :
222- content_type_query |= Q (content_type__model = lower_name )
220+ content_type_query |= Q (context__entity_type__model = lower_name )
223221 queryset = queryset .filter (content_type_query )
224222
225- chunks = (
226- queryset .select_related ("content_type" )
227- .prefetch_related ("content_object" )
228- .order_by ("-similarity" )[:limit ]
229- )
223+ chunks = queryset .select_related ("context__entity_type" ).order_by ("-similarity" )[:limit ]
230224
231225 results = []
232226 for chunk in chunks :
233- if not chunk .content_object :
227+ if not chunk .context or not chunk . context . entity :
234228 logger .warning ("Content object is None for chunk %s. Skipping." , chunk .id )
235229 continue
236230
237- source_name = self .get_source_name (chunk .content_object )
238- additional_context = self .get_additional_context (
239- chunk .content_object , chunk .content_type .model
240- )
231+ source_name = self .get_source_name (chunk .context .entity )
232+ additional_context = self .get_additional_context (chunk .context .entity )
241233
242234 results .append (
243235 {
244236 "text" : chunk .text ,
245237 "similarity" : float (chunk .similarity ),
246- "source_type" : chunk .content_type .model ,
238+ "source_type" : chunk .context . entity_type .model ,
247239 "source_name" : source_name ,
248- "source_id" : chunk .object_id ,
240+ "source_id" : chunk .context . entity_id ,
249241 "additional_context" : additional_context ,
250242 }
251243 )
@@ -262,13 +254,12 @@ def extract_content_types_from_query(self, query: str) -> list[str]:
262254 A list of detected content type names.
263255
264256 """
265- detected_types = []
266257 query_words = set (re .findall (r"\b\w+\b" , query .lower ()))
267258
268259 detected_types = [
269- content_type
270- for content_type in self .SUPPORTED_CONTENT_TYPES
271- if content_type in query_words or f"{ content_type } s" in query_words
260+ entity_type
261+ for entity_type in self .SUPPORTED_ENTITY_TYPES
262+ if entity_type in query_words or f"{ entity_type } s" in query_words
272263 ]
273264
274265 if detected_types :
0 commit comments