Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Small clean-up (#417)
Browse files Browse the repository at this point in the history
* Remove duplicated file

* Move serve tests

* Fix

* Update tests/conftest.py

Co-authored-by: Jirka Borovec <[email protected]>

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
ethanwharris and Borda authored Jun 16, 2021
1 parent 6ea2a3c commit 146e05d
Show file tree
Hide file tree
Showing 30 changed files with 21 additions and 64 deletions.
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _generate_sequential_uuid():

@pytest.fixture(scope="session")
def original_global_datadir():
return pathlib.Path(os.path.realpath(__file__)).parent.joinpath("serve").joinpath("data")
return pathlib.Path(os.path.realpath(__file__)).parent / "core" / "serve" / "data"


def prep_global_datadir(tmp_path_factory, original_global_datadir):
Expand Down Expand Up @@ -69,7 +69,7 @@ def squeezenet1_1_model():

@pytest.fixture(scope="session")
def lightning_squeezenet1_1_obj():
from tests.serve.models import LightningSqueezenet
from tests.core.serve.models import LightningSqueezenet

model = LightningSqueezenet()
model.eval()
Expand All @@ -88,7 +88,7 @@ def squeezenet_gridmodel(squeezenet1_1_model, session_global_datadir):

@pytest.fixture()
def lightning_squeezenet_checkpoint_path(tmp_path):
from tests.serve.models import LightningSqueezenet
from tests.core.serve.models import LightningSqueezenet

model = LightningSqueezenet()
state_dict = {"state_dict": model.state_dict()}
Expand Down
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from flash.core.serve.types import Label
from flash.core.utilities.imports import _SERVE_AVAILABLE, _TORCHVISION_AVAILABLE
from tests.serve.models import ClassificationInferenceComposable, LightningSqueezenet
from tests.core.serve.models import ClassificationInferenceComposable, LightningSqueezenet


@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_composit_endpoint_data(lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

comp = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
composit = Composition(comp=comp)
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj):

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

Expand Down Expand Up @@ -149,7 +149,7 @@ def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj):
_ = Composition(hello="world")

# no endpoints multiple components
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
Expand All @@ -160,7 +160,7 @@ def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj):

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_gridmodel_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gridmodel):
from tests.serve.models import ClassificationInferenceModelSequence
from tests.core.serve.models import ClassificationInferenceModelSequence

squeezenet_gm, _ = squeezenet_gridmodel
model_seq = [squeezenet_gm, squeezenet_gm]
Expand All @@ -174,7 +174,7 @@ def test_gridmodel_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gr

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_gridmodel_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gridmodel):
from tests.serve.models import ClassificationInferenceModelMapping
from tests.core.serve.models import ClassificationInferenceModelMapping

squeezenet_gm, _ = squeezenet_gridmodel
model_map = {"model_one": squeezenet_gm, "model_two": squeezenet_gm}
Expand All @@ -188,7 +188,7 @@ def test_gridmodel_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gri

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_invalid_gridmodel_composition(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gridmodel):
from tests.serve.models import ClassificationInferenceModelMapping
from tests.core.serve.models import ClassificationInferenceModelMapping

squeezenet_gm, _ = squeezenet_gridmodel

Expand All @@ -202,7 +202,7 @@ def test_invalid_gridmodel_composition(tmp_path, lightning_squeezenet1_1_obj, sq

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj):

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_complex_spec_multiple_endpoints(tmp_path, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_complex_spec_multiple_endpoints(tmp_path, lightning_squeezenet1_1_obj):

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_start_server_from_composition(tmp_path, squeezenet_gridmodel, session_global_datadir):
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

squeezenet_gm, _ = squeezenet_gridmodel
comp1 = ClassificationInferenceComposable(squeezenet_gm)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_metho
This is noted because it differes from some of the other metaclass validations
which will raise an exception at class defiition time.
"""
from tests.serve.models import ClassificationInference
from tests.core.serve.models import ClassificationInference

class FailedExposedDecorator(ModelComponent):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_resnet_18_inference_class(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInference
from tests.core.serve.models import ClassificationInference

comp = ClassificationInference(lightning_squeezenet1_1_obj)
composit = Composition(comp=comp, TESTING=True, DEBUG=True)
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_resnet_18_inference_class(session_global_datadir, lightning_squeezenet1

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_start_server_with_repeated_exposed(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInferenceRepeated
from tests.core.serve.models import ClassificationInferenceRepeated

comp = ClassificationInferenceRepeated(lightning_squeezenet1_1_obj)
composit = Composition(comp=comp, TESTING=True, DEBUG=True)
Expand All @@ -65,7 +65,7 @@ def test_start_server_with_repeated_exposed(session_global_datadir, lightning_sq

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_serving_single_component_and_endpoint_no_composition(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInference
from tests.core.serve.models import ClassificationInference

comp = ClassificationInference(lightning_squeezenet1_1_obj)
assert hasattr(comp.inputs, "img")
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInference, SeatClassifier
from tests.core.serve.models import ClassificationInference, SeatClassifier

resnet_comp = ClassificationInference(lightning_squeezenet1_1_obj)
seat_comp = SeatClassifier(lightning_squeezenet1_1_obj, config={"sport": "football"})
Expand Down Expand Up @@ -237,7 +237,7 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj):

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_composed_does_not_eliminate_endpoint_serialization(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInference, SeatClassifier
from tests.core.serve.models import ClassificationInference, SeatClassifier

resnet_comp = ClassificationInference(lightning_squeezenet1_1_obj)
seat_comp = SeatClassifier(lightning_squeezenet1_1_obj, config={"sport": "football"})
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInference, SeatClassifier
from tests.core.serve.models import ClassificationInference, SeatClassifier

resnet_comp = ClassificationInference(lightning_squeezenet1_1_obj)
seat_comp = SeatClassifier(lightning_squeezenet1_1_obj, config={"sport": "football"})
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ

@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.")
def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1_1_obj):
from tests.serve.models import ClassificationInferenceComposable
from tests.core.serve.models import ClassificationInferenceComposable

c1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)

Expand Down
40 changes: 0 additions & 40 deletions tests/core/test_integrations.py

This file was deleted.

3 changes: 0 additions & 3 deletions tests/examples/test_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Tuple
from unittest import mock

import pytest
Expand Down

0 comments on commit 146e05d

Please sign in to comment.