Merge pull request #5 from VowpalWabbit/nosockettests
unit tests to use mock encoder
olgavrou authored Aug 29, 2023
2 parents 5de212d + f8b5c29 commit 42bdb00
Showing 1 changed file with 40 additions and 8 deletions.
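The change threads a feature_embedder built from a MockEncoder into every pick_best_chain.PickBest.from_llm(...) call, so the unit tests no longer depend on a real sentence-transformers model (or a socket-using encoder). The repository's actual MockEncoder is not part of this diff; the sketch below is only an illustration of what such a test double might look like, with the encode() signature and return format assumed rather than taken from the source.

# Hypothetical stand-in for a sentence-transformers encoder, used only in tests.
# The real MockEncoder in the test suite may differ; the point is a deterministic,
# dependency-free object exposing an encode() method.
from typing import Any


class MockEncoder:
    def encode(self, to_encode: Any) -> str:
        # Return a deterministic pseudo-embedding so feature construction in
        # PickBestFeatureEmbedder is reproducible without loading a real model.
        return "[encoded]" + str(to_encode)

In the diff below, an instance of it is passed to each chain as feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()).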
@@ -23,7 +23,11 @@ def setup() -> tuple:
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_multiple_ToSelectFrom_throws() -> None:
     llm, PROMPT = setup()
-    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(
+        llm=llm,
+        prompt=PROMPT,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
+    )
     actions = ["0", "1", "2"]
     with pytest.raises(ValueError):
         chain.run(
@@ -36,7 +40,11 @@ def test_multiple_ToSelectFrom_throws() -> None:
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_missing_basedOn_from_throws() -> None:
     llm, PROMPT = setup()
-    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(
+        llm=llm,
+        prompt=PROMPT,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
+    )
     actions = ["0", "1", "2"]
     with pytest.raises(ValueError):
         chain.run(action=rl_chain.ToSelectFrom(actions))
@@ -45,7 +53,11 @@ def test_missing_basedOn_from_throws() -> None:
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_ToSelectFrom_not_a_list_throws() -> None:
     llm, PROMPT = setup()
-    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(
+        llm=llm,
+        prompt=PROMPT,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
+    )
     actions = {"actions": ["0", "1", "2"]}
     with pytest.raises(ValueError):
         chain.run(
@@ -63,6 +75,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
         llm=llm,
         prompt=PROMPT,
         selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
     )
     actions = ["0", "1", "2"]
     response = chain.run(
@@ -85,6 +98,7 @@ def test_update_with_delayed_score_force() -> None:
         llm=llm,
         prompt=PROMPT,
         selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
     )
     actions = ["0", "1", "2"]
     response = chain.run(
@@ -104,7 +118,10 @@ def test_update_with_delayed_score_force() -> None:
 def test_update_with_delayed_score() -> None:
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(
-        llm=llm, prompt=PROMPT, selection_scorer=None
+        llm=llm,
+        prompt=PROMPT,
+        selection_scorer=None,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
     )
     actions = ["0", "1", "2"]
     response = chain.run(
@@ -128,7 +145,10 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
             return score

     chain = pick_best_chain.PickBest.from_llm(
-        llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
+        llm=llm,
+        prompt=PROMPT,
+        selection_scorer=CustomSelectionScorer(),
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
     )
     actions = ["0", "1", "2"]
     response = chain.run(
@@ -239,7 +259,11 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
 def test_default_no_scorer_specified() -> None:
     _, PROMPT = setup()
     chain_llm = FakeListChatModel(responses=[100])
-    chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(
+        llm=chain_llm,
+        prompt=PROMPT,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
+    )
     response = chain.run(
         User=rl_chain.BasedOn("Context"),
         action=rl_chain.ToSelectFrom(["0", "1", "2"]),
@@ -254,7 +278,10 @@ def test_default_no_scorer_specified() -> None:
 def test_explicitly_no_scorer() -> None:
     llm, PROMPT = setup()
     chain = pick_best_chain.PickBest.from_llm(
-        llm=llm, prompt=PROMPT, selection_scorer=None
+        llm=llm,
+        prompt=PROMPT,
+        selection_scorer=None,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
     )
     response = chain.run(
         User=rl_chain.BasedOn("Context"),
@@ -274,6 +301,7 @@ def test_auto_scorer_with_user_defined_llm() -> None:
         llm=llm,
         prompt=PROMPT,
         selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
     )
     response = chain.run(
         User=rl_chain.BasedOn("Context"),
@@ -288,7 +316,11 @@ def test_auto_scorer_with_user_defined_llm() -> None:
 @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
 def test_calling_chain_w_reserved_inputs_throws() -> None:
     llm, PROMPT = setup()
-    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(
+        llm=llm,
+        prompt=PROMPT,
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
+    )
     with pytest.raises(ValueError):
         chain.run(
             User=rl_chain.BasedOn("Context"),
