Skip to content

Commit

Permalink
feature: sort by top_score and hallucination_score for rag
Browse files Browse the repository at this point in the history
  • Loading branch information
densumesh authored and cdxker committed Dec 14, 2024
1 parent b162e1e commit 086a433
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import { usePagination } from "../usePagination";
import { RAGAnalyticsFilter } from "shared/types";
import { DatasetContext } from "../../../contexts/DatasetContext";

export type RAGSortByCols = "created_at" | "latency";
/** Column keys the RAG data-explorer table accepts for sorting. */
export type RAGSortByCols = "created_at" | "latency" | "hallucination_score" | "top_score";

export const useDataExplorerRag = () => {
const queryClient = useQueryClient();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,15 @@ export const RAGAnalyticsPage = () => {
);
},
},
{
accessorKey: "top_score",
header: "Top Score",
sortable: true,
},
{
accessorKey: "hallucination_score",
header: "Hallucination Score",
sortable: true,
},
{
accessorKey: "query_rating",
Expand Down Expand Up @@ -153,37 +159,35 @@ export const RAGAnalyticsPage = () => {
<div>
<div class="mt-4 pb-1 text-lg">All RAG Queries</div>
<div class="rounded-md bg-white">
<Show when={ragTableQuery.data}>
<Card>
<RAGFilterBar noPadding filters={filters} setFilters={setFilters} />
<div class="mt-4 overflow-x-auto">
<TanStackTable
pages={pages}
perPage={10}
table={table as unknown as Table<SearchQueryEvent>}
onRowClick={(row: SearchQueryEvent) =>
navigate(
`/dataset/${datasetContext.datasetId()}/analytics/rag/${
row.id
}`,
)
}
exportFn={(page: number) =>
getRAGQueries({
datasetId: datasetContext.datasetId(),
filter: filters,
page: page,
sort_by: sortBy(),
sort_order: sortOrder(),
})
}
/>
<Show when={ragTableQuery.data?.length === 0}>
<div class="py-8 text-center">No Data.</div>
</Show>
</div>
</Card>
</Show>
<Card>
<RAGFilterBar noPadding filters={filters} setFilters={setFilters} />
<div class="mt-4 overflow-x-auto">
<TanStackTable
pages={pages}
perPage={10}
table={table as unknown as Table<SearchQueryEvent>}
onRowClick={(row: SearchQueryEvent) =>
navigate(
`/dataset/${datasetContext.datasetId()}/analytics/rag/${
row.id
}`,
)
}
exportFn={(page: number) =>
getRAGQueries({
datasetId: datasetContext.datasetId(),
filter: filters,
page: page,
sort_by: sortBy(),
sort_order: sortOrder(),
})
}
/>
<Show when={ragTableQuery.data?.length === 0}>
<div class="py-8 text-center">No Data.</div>
</Show>
</div>
</Card>
</div>
<div class="my-4 border-b border-b-neutral-200 pt-2" />
<RagAnalyticsGraphs />
Expand Down
6 changes: 5 additions & 1 deletion frontends/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,11 @@ export interface RecommendationsAnalyticsFilter {
recommendation_type?: "chunk" | "group";
}

export type RAGSortBy = "created_at" | "latency";
/** Sortable fields for RAG analytics queries sent to the server. */
export type RAGSortBy = "created_at" | "latency" | "hallucination_score" | "top_score";
export type SearchSortBy = "created_at" | "latency" | "top_score";

export type SortOrder = "desc" | "asc";
Expand Down
8 changes: 7 additions & 1 deletion server/src/data/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5040,6 +5040,7 @@ pub struct RagQueryEvent {
pub results: Vec<serde_json::Value>,
pub dataset_id: uuid::Uuid,
pub llm_response: String,
pub top_score: f32,
pub query_rating: Option<SearchQueryRating>,
pub hallucination_score: f64,
pub detected_hallucinations: Vec<String>,
Expand All @@ -5058,7 +5059,7 @@ impl From<String> for ClickhouseRagTypes {
}

impl RagQueryEventClickhouse {
pub async fn from_clickhouse(self, pool: web::Data<Pool>) -> RagQueryEvent {
pub async fn from_clickhouse(self, pool: web::Data<Pool>, top_score: f32) -> RagQueryEvent {
let chunk_ids = self
.results
.iter()
Expand Down Expand Up @@ -5096,6 +5097,7 @@ impl RagQueryEventClickhouse {
user_message: self.user_message,
search_id: uuid::Uuid::from_bytes(*self.search_id.as_bytes()),
results,
top_score,
query_rating,
dataset_id: uuid::Uuid::from_bytes(*self.dataset_id.as_bytes()),
llm_response: self.llm_response,
Expand Down Expand Up @@ -6280,6 +6282,10 @@ pub enum SearchSortBy {
#[derive(Debug, Serialize, Deserialize, ToSchema, Display, Clone)]
#[serde(rename_all = "snake_case")]
pub enum RAGSortBy {
#[display(fmt = "hallucination_score")]
HallucinationScore,
#[display(fmt = "top_score")]
TopScore,
#[display(fmt = "created_at")]
CreatedAt,
#[display(fmt = "latency")]
Expand Down
22 changes: 15 additions & 7 deletions server/src/operators/analytics_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,11 @@ pub async fn get_rag_queries_query(
) -> Result<RagQueryResponse, ServiceError> {
let mut query_string = String::from(
"SELECT
?fields
?fields,
top_score,
FROM
rag_queries
JOIN search_queries ON rag_queries.search_id = search_queries.id
WHERE dataset_id = ?",
);

Expand All @@ -775,7 +777,7 @@ pub async fn get_rag_queries_query(
.query(query_string.as_str())
.bind(dataset_id)
.bind((page.unwrap_or(1) - 1) * 10)
.fetch_all::<RagQueryEventClickhouse>()
.fetch_all::<(RagQueryEventClickhouse, f32)>()
.await
.map_err(|e| {
log::error!("Error fetching query: {:?}", e);
Expand All @@ -785,7 +787,7 @@ pub async fn get_rag_queries_query(
let queries: Vec<RagQueryEvent> = join_all(
clickhouse_query
.into_iter()
.map(|q| q.from_clickhouse(pool.clone())),
.map(|(q, score)| q.from_clickhouse(pool.clone(), score)),
)
.await;

Expand Down Expand Up @@ -888,18 +890,24 @@ pub async fn get_rag_query(
pool: web::Data<Pool>,
clickhouse_client: &clickhouse::Client,
) -> Result<RagQueryEvent, ServiceError> {
let clickhouse_query = clickhouse_client
.query("SELECT ?fields FROM rag_queries WHERE id = ? AND dataset_id = ?")
let (clickhouse_query, top_score) = clickhouse_client
.query(
"SELECT ?fields, top_score FROM rag_queries
JOIN search_queries ON rag_queries.search_id = search_queries.id
WHERE id = ? AND dataset_id = ?",
)
.bind(request_id)
.bind(dataset_id)
.fetch_one::<RagQueryEventClickhouse>()
.fetch_one::<(RagQueryEventClickhouse, f32)>()
.await
.map_err(|e| {
log::error!("Error fetching query: {:?}", e);
ServiceError::InternalServerError("Error fetching query".to_string())
})?;

let query: RagQueryEvent = clickhouse_query.from_clickhouse(pool.clone()).await;
let query: RagQueryEvent = clickhouse_query
.from_clickhouse(pool.clone(), top_score)
.await;

Ok(query)
}
Expand Down
1 change: 0 additions & 1 deletion server/src/operators/message_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,6 @@ pub async fn get_rag_chunks_query(
.get(0)
.map(|x| x.score as f32)
.unwrap_or(0.0),

latency: get_latency_from_header(search_timer.header_value()),
results: result_chunks
.score_chunks
Expand Down

0 comments on commit 086a433

Please sign in to comment.