-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add retrieval results combination and deduplication
- Loading branch information
Joel Koch
committed
Dec 9, 2024
1 parent
c313623
commit d3c5ca4
Showing
2 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
defmodule Rag.Retrieval do | ||
@moduledoc """ | ||
Functions to transform retrieval results. | ||
""" | ||
|
||
@doc """ | ||
Pops the retrieval result for each key in `retrieval_result_keys` from `rag_state`. | ||
Then, appends the retrieval result to the list at `output_key`. | ||
""" | ||
@spec combine_retrieval_results(map(), list(atom()), atom()) :: map() | ||
def combine_retrieval_results(rag_state, retrieval_result_keys, output_key) do | ||
rag_state = Map.put_new(rag_state, output_key, []) | ||
|
||
for retrieval_result_key <- retrieval_result_keys, reduce: rag_state do | ||
state -> | ||
{retrieval_result, state} = Map.pop!(state, retrieval_result_key) | ||
|
||
Map.update!(state, output_key, fn combined_results -> | ||
combined_results ++ retrieval_result | ||
end) | ||
end | ||
end | ||
|
||
@doc """ | ||
Deduplicates entries at `entries_key` in `rag_state`. | ||
Two entries are considered duplicates if they hold the same value at **all** `unique_by_keys`. | ||
In case of duplicates, the first entry is kept. | ||
""" | ||
@spec deduplicate(map(), atom(), list(atom())) :: map() | ||
def deduplicate(rag_state, entries_key, unique_by_keys) do | ||
if unique_by_keys == [] do | ||
raise ArgumentError, "unique_by_keys must not be empty" | ||
end | ||
|
||
Map.update!(rag_state, entries_key, fn entries -> | ||
Enum.uniq_by(entries, &Map.take(&1, unique_by_keys)) | ||
end) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
defmodule Rag.RetrievalTest do | ||
use ExUnit.Case | ||
use Mimic | ||
|
||
alias Rag.Retrieval | ||
|
||
describe "combine_retrieval_results/3" do | ||
test "pops the results at retrieval_result_keys and combines them into a list at output_key" do | ||
foo_results = [%{id: 0, text: "something"}, %{id: 1, text: "something else"}] | ||
bar_results = [%{id: 0, text: "bar"}, %{id: 1, text: "bar else"}] | ||
|
||
rag_state = %{foo: foo_results, bar: bar_results} | ||
|
||
retrieval_result_keys = [:foo, :bar] | ||
output_key = :combined_result | ||
|
||
output_rag_state = | ||
Retrieval.combine_retrieval_results(rag_state, retrieval_result_keys, output_key) | ||
|
||
for key <- retrieval_result_keys do | ||
refute Map.has_key?(output_rag_state, key) | ||
end | ||
|
||
assert Map.fetch!(output_rag_state, output_key) == foo_results ++ bar_results | ||
end | ||
|
||
test "keeps existing results at output_key" do | ||
existing_results = [%{id: 0, text: "existing"}] | ||
new_results = [%{id: 1001, text: "new result"}] | ||
rag_state = %{results: existing_results, new: new_results} | ||
|
||
output_rag_state = Retrieval.combine_retrieval_results(rag_state, [:new], :results) | ||
|
||
assert Map.fetch!(output_rag_state, :results) == existing_results ++ new_results | ||
end | ||
|
||
test "errors if one of retrieval_result_keys is not in rag_state" do | ||
rag_state = %{text: "hello"} | ||
|
||
assert_raise KeyError, fn -> | ||
Retrieval.combine_retrieval_results(rag_state, [:foo], :results) | ||
end | ||
end | ||
end | ||
|
||
describe "deduplicate_results/3" do | ||
test "keeps only first result for entries with same values at all unique_by_keys" do | ||
results = [%{id: 0, value: "hello"}, %{id: 1, value: "hola"}, %{id: 0, value: "something"}] | ||
rag_state = %{results: results} | ||
|
||
assert Retrieval.deduplicate(rag_state, :results, [:id]) == %{ | ||
results: [%{id: 0, value: "hello"}, %{id: 1, value: "hola"}] | ||
} | ||
end | ||
|
||
test "errors if one of entries_key is not in rag_state" do | ||
rag_state = %{text: "hello"} | ||
|
||
assert_raise KeyError, fn -> | ||
Retrieval.deduplicate(rag_state, :results, [:foo]) | ||
end | ||
end | ||
|
||
test "errors if unique_by_keys is empty" do | ||
rag_state = %{text: "hello"} | ||
|
||
assert_raise ArgumentError, fn -> | ||
Retrieval.deduplicate(rag_state, :text, []) | ||
end | ||
end | ||
end | ||
end |