Skip to content

Commit

Permalink
fixed tests now: issue was set_classifier was not being caught by mag…
Browse files Browse the repository at this point in the history
…ic_mock thus it outputted a huggingface pipeline
  • Loading branch information
J-Dymond committed Dec 5, 2024
1 parent 382f567 commit d983357
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,20 @@ def test_pipeline_inputs(dummy_data, dummy_metadata):
"arc_spice.variational_pipelines.RTC_variational_pipeline.pipeline",
):
with patch(
(
"arc_spice.variational_pipelines.RTC_variational_pipeline."
"RTCVariationalPipeline._init_semantic_density"
),
return_value=None,
"arc_spice.variational_pipelines.utils.pipeline",
):
pipeline = RTCVariationalPipeline(
model_pars=pipeline_config,
data_pars=dummy_metadata,
translation_batch_size=1,
)
with patch(
(
"arc_spice.variational_pipelines.RTC_variational_pipeline."
"RTCVariationalPipeline._init_semantic_density"
),
return_value=None,
):
pipeline = RTCVariationalPipeline(
model_pars=pipeline_config,
data_pars=dummy_metadata,
translation_batch_size=1,
)

dummy_recognise_output = {"outputs": "rec text"}
dummy_translate_output = {"outputs": ["translate text"]}
Expand Down Expand Up @@ -86,23 +89,26 @@ def test_single_component_inputs(dummy_data, dummy_metadata):
"arc_spice.variational_pipelines.RTC_single_component_pipeline.pipeline"
):
with patch(
(
"arc_spice.variational_pipelines.RTC_single_component_pipeline."
"RTCSingleComponentPipeline._init_semantic_density"
),
return_value=None,
"arc_spice.variational_pipelines.utils.pipeline",
):
recognise_pipeline = RecognitionVariationalPipeline(
model_pars=pipeline_config,
)
translate_pipeline = TranslationVariationalPipeline(
model_pars=pipeline_config,
translation_batch_size=1,
)
classify_pipeline = ClassificationVariationalPipeline(
model_pars=pipeline_config,
data_pars=dummy_metadata,
)
with patch(
(
"arc_spice.variational_pipelines.RTC_single_component_pipeline."
"RTCSingleComponentPipeline._init_semantic_density"
),
return_value=None,
):
recognise_pipeline = RecognitionVariationalPipeline(
model_pars=pipeline_config,
)
translate_pipeline = TranslationVariationalPipeline(
model_pars=pipeline_config,
translation_batch_size=1,
)
classify_pipeline = ClassificationVariationalPipeline(
model_pars=pipeline_config,
data_pars=dummy_metadata,
)

recognise_pipeline.forward_function = MagicMock(return_value=dummy_recognise_output)
translate_pipeline.forward_function = MagicMock(return_value=dummy_translate_output)
Expand Down

0 comments on commit d983357

Please sign in to comment.