Skip to content

Commit

Permalink
Add retrieval results combination and deduplication
Browse files Browse the repository at this point in the history
  • Loading branch information
Joel Koch committed Dec 9, 2024
1 parent c313623 commit d3c5ca4
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
39 changes: 39 additions & 0 deletions lib/rag/retrieval.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
defmodule Rag.Retrieval do
  @moduledoc """
  Functions to transform retrieval results.
  """

  @doc """
  Pops the retrieval result for each key in `retrieval_result_keys` from `rag_state`.
  Then, appends the retrieval results to the list at `output_key`.

  Results are appended in the order of `retrieval_result_keys`, after any entries
  that are already stored at `output_key`. If `output_key` is not present in
  `rag_state`, it starts out as an empty list.

  The values at `retrieval_result_keys` and at `output_key` must be lists.

  Raises `KeyError` if one of `retrieval_result_keys` is not in `rag_state`. This
  includes a key that is listed twice, and `output_key` itself being listed in
  `retrieval_result_keys`.

  ## Examples

      iex> Rag.Retrieval.combine_retrieval_results(%{a: [1], b: [2]}, [:a, :b], :all)
      %{all: [1, 2]}

      iex> Rag.Retrieval.combine_retrieval_results(%{all: [0], a: [1]}, [:a], :all)
      %{all: [0, 1]}

  """
  @spec combine_retrieval_results(map(), list(atom()), atom()) :: map()
  def combine_retrieval_results(rag_state, retrieval_result_keys, output_key) do
    rag_state = Map.put_new(rag_state, output_key, [])

    # `Map.pop!/2` returns `{value, map}`, which is exactly the `{mapped, acc}` shape
    # `Enum.map_reduce/3` expects: the popped lists are collected in key order while
    # the shrinking state is threaded through.
    {popped_results, rag_state} =
      Enum.map_reduce(retrieval_result_keys, rag_state, fn retrieval_result_key, state ->
        Map.pop!(state, retrieval_result_key)
      end)

    # Concatenate once instead of `acc ++ result` per key, which would re-copy the
    # growing accumulator on every iteration.
    Map.update!(rag_state, output_key, fn existing_results ->
      Enum.concat([existing_results | popped_results])
    end)
  end

  @doc """
  Deduplicates entries at `entries_key` in `rag_state`.
  Two entries are considered duplicates if they hold the same value at **all** `unique_by_keys`.
  In case of duplicates, the first entry is kept.

  Entries are compared with `Map.take/2`, so an entry that lacks some of
  `unique_by_keys` is compared only on the keys it does have.

  Raises `ArgumentError` if `unique_by_keys` is empty and `KeyError` if
  `entries_key` is not in `rag_state`.

  ## Examples

      iex> Rag.Retrieval.deduplicate(%{results: [%{id: 0, v: "a"}, %{id: 0, v: "b"}]}, :results, [:id])
      %{results: [%{id: 0, v: "a"}]}

  """
  @spec deduplicate(map(), atom(), list(atom())) :: map()
  def deduplicate(_rag_state, _entries_key, []) do
    # With no keys to compare, every entry would count as a duplicate of the first.
    raise ArgumentError, "unique_by_keys must not be empty"
  end

  def deduplicate(rag_state, entries_key, unique_by_keys) do
    Map.update!(rag_state, entries_key, fn entries ->
      Enum.uniq_by(entries, &Map.take(&1, unique_by_keys))
    end)
  end
end
72 changes: 72 additions & 0 deletions test/rag/retrieval_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
defmodule Rag.RetrievalTest do
  # The functions under test are pure map transformations, so the cases can run
  # concurrently with other test modules.
  use ExUnit.Case, async: true
  use Mimic

  alias Rag.Retrieval

  describe "combine_retrieval_results/3" do
    test "pops the results at retrieval_result_keys and combines them into a list at output_key" do
      foo_results = [%{id: 0, text: "something"}, %{id: 1, text: "something else"}]
      bar_results = [%{id: 0, text: "bar"}, %{id: 1, text: "bar else"}]

      rag_state = %{foo: foo_results, bar: bar_results}

      retrieval_result_keys = [:foo, :bar]
      output_key = :combined_result

      output_rag_state =
        Retrieval.combine_retrieval_results(rag_state, retrieval_result_keys, output_key)

      for key <- retrieval_result_keys do
        refute Map.has_key?(output_rag_state, key)
      end

      assert Map.fetch!(output_rag_state, output_key) == foo_results ++ bar_results
    end

    test "keeps existing results at output_key" do
      existing_results = [%{id: 0, text: "existing"}]
      new_results = [%{id: 1001, text: "new result"}]
      rag_state = %{results: existing_results, new: new_results}

      output_rag_state = Retrieval.combine_retrieval_results(rag_state, [:new], :results)

      assert Map.fetch!(output_rag_state, :results) == existing_results ++ new_results
    end

    test "errors if one of retrieval_result_keys is not in rag_state" do
      rag_state = %{text: "hello"}

      assert_raise KeyError, fn ->
        Retrieval.combine_retrieval_results(rag_state, [:foo], :results)
      end
    end
  end

  # Named after the function actually under test, `Retrieval.deduplicate/3`.
  describe "deduplicate/3" do
    test "keeps only first result for entries with same values at all unique_by_keys" do
      results = [%{id: 0, value: "hello"}, %{id: 1, value: "hola"}, %{id: 0, value: "something"}]
      rag_state = %{results: results}

      assert Retrieval.deduplicate(rag_state, :results, [:id]) == %{
               results: [%{id: 0, value: "hello"}, %{id: 1, value: "hola"}]
             }
    end

    test "errors if entries_key is not in rag_state" do
      rag_state = %{text: "hello"}

      assert_raise KeyError, fn ->
        Retrieval.deduplicate(rag_state, :results, [:foo])
      end
    end

    test "errors if unique_by_keys is empty" do
      rag_state = %{text: "hello"}

      assert_raise ArgumentError, fn ->
        Retrieval.deduplicate(rag_state, :text, [])
      end
    end
  end
end

0 comments on commit d3c5ca4

Please sign in to comment.