From b31756db6a9b2533fb985e226953ca9593c2fe74 Mon Sep 17 00:00:00 2001 From: Jegor Kitskerkin Date: Wed, 7 Dec 2022 18:32:57 +0100 Subject: [PATCH 1/3] Add SetFitModelConfig --- src/setfit/config.py | 20 ++++++++++++++++++++ src/setfit/modeling.py | 16 ++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 src/setfit/config.py diff --git a/src/setfit/config.py b/src/setfit/config.py new file mode 100644 index 00000000..922f4879 --- /dev/null +++ b/src/setfit/config.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Any, Optional + +from transformers import PretrainedConfig + + +@dataclass +class SetFitModelConfig: + """ + Config for SetFitModel. + + Parameters: + model_body (`PretrainedConfig`): + Config of the model_body transformer. + model_head (`Optional[Any]`): + Config of the model_head. + """ + + model_body: PretrainedConfig + model_head: Optional[Any] diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 49c5b0c4..0cc355b1 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Union +from .config import SetFitModelConfig + # Google Colab runs on Python 3.7, so we need this to be compatible try: @@ -211,6 +213,20 @@ def __init__( self.model_original_state = copy.deepcopy(self.model_body.state_dict()) self.normalize_embeddings = normalize_embeddings + @property + def config(self) -> SetFitModelConfig: + model_body_config = self.model_body._modules["0"]._modules["auto_model"].config + + model_head_config = None + if isinstance(self.model_head, SetFitHead): + model_head_config = self.model_head.get_config_dict() + + return SetFitModelConfig(model_body=model_body_config, model_head=model_head_config) + + @config.setter + def config(self, config: SetFitModelConfig) -> None: + self.model_body._modules["0"]._modules["auto_model"].config = config.model_body + def fit( self, x_train: List[str], From a9f88df61dd91e6bcf10e96677f3ac12da930727 Mon Sep 17 00:00:00 2001 From: Jegor Kitskerkin Date: Sat, 21 Jan 2023 20:03:43 +0100 Subject: [PATCH 2/3] Replace SetFitModelConfig -> SetFitConfig --- src/setfit/config.py | 2 +- src/setfit/modeling.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/setfit/config.py b/src/setfit/config.py index 922f4879..f3a9c496 100644 --- a/src/setfit/config.py +++ b/src/setfit/config.py @@ -5,7 +5,7 @@ @dataclass -class SetFitModelConfig: +class SetFitConfig: """ Config for SetFitModel. diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 0cc355b1..b3b49ffa 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Union -from .config import SetFitModelConfig +from .config import SetFitConfig # Google Colab runs on Python 3.7, so we need this to be compatible @@ -214,17 +214,17 @@ def __init__( self.normalize_embeddings = normalize_embeddings @property - def config(self) -> SetFitModelConfig: + def config(self) -> SetFitConfig: model_body_config = self.model_body._modules["0"]._modules["auto_model"].config model_head_config = None if isinstance(self.model_head, SetFitHead): model_head_config = self.model_head.get_config_dict() - return SetFitModelConfig(model_body=model_body_config, model_head=model_head_config) + return SetFitConfig(model_body=model_body_config, model_head=model_head_config) @config.setter - def config(self, config: SetFitModelConfig) -> None: + def config(self, config: SetFitConfig) -> None: self.model_body._modules["0"]._modules["auto_model"].config = config.model_body def fit( From ef87e51dff0a543317339ee0a882d01e452e7493 Mon Sep 17 00:00:00 2001 From: Jegor Kitskerkin Date: Sat, 21 Jan 2023 20:04:20 +0100 Subject: [PATCH 3/3] Fix SetFitConfig docstring --- src/setfit/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/setfit/config.py b/src/setfit/config.py index f3a9c496..f7f102e5 100644 --- a/src/setfit/config.py +++ b/src/setfit/config.py @@ -7,13 +7,13 @@ @dataclass class SetFitConfig: """ - Config for SetFitModel. + Configuration for SetFitModel. - Parameters: + Args: model_body (`PretrainedConfig`): - Config of the model_body transformer. + Configuration of the Sentence Transformer body. model_head (`Optional[Any]`): - Config of the model_head. + Configuration of the head. """ model_body: PretrainedConfig