feature: sort by top_score and hallucination_score for rag #2946

Merged: 1 commit on Dec 14, 2024
@@ -7,7 +7,11 @@ import { usePagination } from "../usePagination";
import { RAGAnalyticsFilter } from "shared/types";
import { DatasetContext } from "../../../contexts/DatasetContext";

export type RAGSortByCols = "created_at" | "latency";
export type RAGSortByCols =
| "created_at"
| "latency"
| "hallucination_score"
| "top_score";

export const useDataExplorerRag = () => {
const queryClient = useQueryClient();
@@ -111,9 +111,15 @@ export const RAGAnalyticsPage = () => {
);
},
},
{
accessorKey: "top_score",
header: "Top Score",
sortable: true,
},
{
accessorKey: "hallucination_score",
header: "Hallucination Score",
sortable: true,
},
{
accessorKey: "query_rating",
@@ -153,37 +159,35 @@ export const RAGAnalyticsPage = () => {
<div>
<div class="mt-4 pb-1 text-lg">All RAG Queries</div>
<div class="rounded-md bg-white">
<Show when={ragTableQuery.data}>
<Card>
<RAGFilterBar noPadding filters={filters} setFilters={setFilters} />
<div class="mt-4 overflow-x-auto">
<TanStackTable
pages={pages}
perPage={10}
table={table as unknown as Table<SearchQueryEvent>}
onRowClick={(row: SearchQueryEvent) =>
navigate(
`/dataset/${datasetContext.datasetId()}/analytics/rag/${
row.id
}`,
)
}
exportFn={(page: number) =>
getRAGQueries({
datasetId: datasetContext.datasetId(),
filter: filters,
page: page,
sort_by: sortBy(),
sort_order: sortOrder(),
})
}
/>
<Show when={ragTableQuery.data?.length === 0}>
<div class="py-8 text-center">No Data.</div>
</Show>
</div>
</Card>
</Show>
<Card>
<RAGFilterBar noPadding filters={filters} setFilters={setFilters} />
<div class="mt-4 overflow-x-auto">
<TanStackTable
pages={pages}
perPage={10}
table={table as unknown as Table<SearchQueryEvent>}
onRowClick={(row: SearchQueryEvent) =>
navigate(
`/dataset/${datasetContext.datasetId()}/analytics/rag/${
row.id
}`,
)
}
exportFn={(page: number) =>
getRAGQueries({
datasetId: datasetContext.datasetId(),
filter: filters,
page: page,
sort_by: sortBy(),
sort_order: sortOrder(),
})
}
/>
<Show when={ragTableQuery.data?.length === 0}>
<div class="py-8 text-center">No Data.</div>
</Show>
</div>
</Card>
</div>
<div class="my-4 border-b border-b-neutral-200 pt-2" />
<RagAnalyticsGraphs />
6 changes: 5 additions & 1 deletion frontends/shared/types.ts
@@ -616,7 +616,11 @@ export interface RecommendationsAnalyticsFilter {
recommendation_type?: "chunk" | "group";
}

export type RAGSortBy = "created_at" | "latency";
export type RAGSortBy =
| "created_at"
| "latency"
| "hallucination_score"
| "top_score";
export type SearchSortBy = "created_at" | "latency" | "top_score";

export type SortOrder = "desc" | "asc";
8 changes: 7 additions & 1 deletion server/src/data/models.rs
@@ -5037,6 +5037,7 @@ pub struct RagQueryEvent {
pub results: Vec<serde_json::Value>,
pub dataset_id: uuid::Uuid,
pub llm_response: String,
pub top_score: f32,
pub query_rating: Option<SearchQueryRating>,
pub hallucination_score: f64,
pub detected_hallucinations: Vec<String>,
@@ -5055,7 +5056,7 @@ impl From<String> for ClickhouseRagTypes {
}

impl RagQueryEventClickhouse {
pub async fn from_clickhouse(self, pool: web::Data<Pool>) -> RagQueryEvent {
pub async fn from_clickhouse(self, pool: web::Data<Pool>, top_score: f32) -> RagQueryEvent {
let chunk_ids = self
.results
.iter()
@@ -5093,6 +5094,7 @@ impl RagQueryEventClickhouse {
user_message: self.user_message,
search_id: uuid::Uuid::from_bytes(*self.search_id.as_bytes()),
results,
top_score,
query_rating,
dataset_id: uuid::Uuid::from_bytes(*self.dataset_id.as_bytes()),
llm_response: self.llm_response,
@@ -6277,6 +6279,10 @@ pub enum SearchSortBy {
#[derive(Debug, Serialize, Deserialize, ToSchema, Display, Clone)]
#[serde(rename_all = "snake_case")]
pub enum RAGSortBy {
#[display(fmt = "hallucination_score")]
HallucinationScore,
#[display(fmt = "top_score")]
TopScore,
#[display(fmt = "created_at")]
CreatedAt,
#[display(fmt = "latency")]
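Side note on the new enum variants above: because RAGSortBy derives Display with per-variant fmt strings and uses #[serde(rename_all = "snake_case")], the frontend's "top_score" / "hallucination_score" values deserialize directly into variants whose Display output matches the ClickHouse column names. The standalone sketch below mirrors the enum to show that round trip; it assumes derive_more 0.99-style #[display(fmt = ...)] plus serde/serde_json and is not code from this PR.

```rust
use derive_more::Display;
use serde::{Deserialize, Serialize};

// Hypothetical mirror of the enum added in server/src/data/models.rs,
// reduced to what is needed to demonstrate the string round trip.
#[allow(dead_code)]
#[derive(Debug, Serialize, Deserialize, Display, Clone)]
#[serde(rename_all = "snake_case")]
enum RAGSortBy {
    #[display(fmt = "hallucination_score")]
    HallucinationScore,
    #[display(fmt = "top_score")]
    TopScore,
    #[display(fmt = "created_at")]
    CreatedAt,
    #[display(fmt = "latency")]
    Latency,
}

fn main() {
    // "top_score" from the analytics frontend deserializes to TopScore,
    // and Display yields the same column name for use in an ORDER BY clause.
    let parsed: RAGSortBy = serde_json::from_str("\"top_score\"").unwrap();
    assert_eq!(parsed.to_string(), "top_score");
    assert_eq!(RAGSortBy::HallucinationScore.to_string(), "hallucination_score");
    println!("round trip ok");
}
```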
22 changes: 15 additions & 7 deletions server/src/operators/analytics_operator.rs
@@ -751,9 +751,11 @@ pub async fn get_rag_queries_query(
) -> Result<RagQueryResponse, ServiceError> {
let mut query_string = String::from(
"SELECT
?fields
?fields,
top_score,
FROM
rag_queries
JOIN search_queries ON rag_queries.search_id = search_queries.id
WHERE dataset_id = ?",
);

@@ -775,7 +777,7 @@
.query(query_string.as_str())
.bind(dataset_id)
.bind((page.unwrap_or(1) - 1) * 10)
.fetch_all::<RagQueryEventClickhouse>()
.fetch_all::<(RagQueryEventClickhouse, f32)>()
.await
.map_err(|e| {
log::error!("Error fetching query: {:?}", e);
@@ -785,7 +787,7 @@
let queries: Vec<RagQueryEvent> = join_all(
clickhouse_query
.into_iter()
.map(|q| q.from_clickhouse(pool.clone())),
.map(|(q, score)| q.from_clickhouse(pool.clone(), score)),
)
.await;

@@ -888,18 +890,24 @@ pub async fn get_rag_query(
pool: web::Data<Pool>,
clickhouse_client: &clickhouse::Client,
) -> Result<RagQueryEvent, ServiceError> {
let clickhouse_query = clickhouse_client
.query("SELECT ?fields FROM rag_queries WHERE id = ? AND dataset_id = ?")
let (clickhouse_query, top_score) = clickhouse_client
.query(
"SELECT ?fields, top_score FROM rag_queries
JOIN search_queries ON rag_queries.search_id = search_queries.id
WHERE id = ? AND dataset_id = ?",
)
.bind(request_id)
.bind(dataset_id)
.fetch_one::<RagQueryEventClickhouse>()
.fetch_one::<(RagQueryEventClickhouse, f32)>()
.await
.map_err(|e| {
log::error!("Error fetching query: {:?}", e);
ServiceError::InternalServerError("Error fetching query".to_string())
})?;

let query: RagQueryEvent = clickhouse_query.from_clickhouse(pool.clone()).await;
let query: RagQueryEvent = clickhouse_query
.from_clickhouse(pool.clone(), top_score)
.await;

Ok(query)
}
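The hunks above elide the part of get_rag_queries_query that actually applies sort_by and sort_order, so the sketch below only illustrates one plausible way an ORDER BY clause could be appended to the new JOINed base query using the enums' Display strings. The append_order_by helper is hypothetical and not code from this PR.

```rust
// Hedged sketch under the assumption that sorting is done by string
// concatenation of the Display values before binding the OFFSET parameter.
fn append_order_by(
    mut query_string: String,
    sort_by: Option<&str>,    // e.g. "top_score" or "hallucination_score"
    sort_order: Option<&str>, // "asc" | "desc"
) -> String {
    query_string.push_str(&format!(
        " ORDER BY {} {} LIMIT 10 OFFSET ?",
        sort_by.unwrap_or("created_at"),
        sort_order.unwrap_or("desc"),
    ));
    query_string
}

fn main() {
    // Base query shape taken from the diff: rag_queries joined to
    // search_queries so that top_score can be selected and sorted on.
    let base = "SELECT ?fields, top_score FROM rag_queries \
                JOIN search_queries ON rag_queries.search_id = search_queries.id \
                WHERE dataset_id = ?"
        .to_string();
    println!("{}", append_order_by(base, Some("top_score"), Some("desc")));
}
```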
1 change: 0 additions & 1 deletion server/src/operators/message_operator.rs
@@ -522,7 +522,6 @@ pub async fn get_rag_chunks_query(
.get(0)
.map(|x| x.score as f32)
.unwrap_or(0.0),

latency: get_latency_from_header(search_timer.header_value()),
results: result_chunks
.score_chunks