Skip to content

Commit 1cbf486

Browse files
committed
refactor: Implement BaseOne2oneValueMatcher and BaseTopkValueMatcher classes
1 parent db0ba8f commit 1cbf486

File tree

11 files changed

+208
-146
lines changed

11 files changed

+208
-146
lines changed

CONTRIBUTING.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ Contributors can add new methods for schema and value matching by following thes
3737

3838
1. Create a Python module inside the folder of the corresponding task (e.g., `bdikit/value_matching`).
3939

40-
2. Define a class in the module that implements either `BaseValueMatcher` (for value matching) or `BaseSchemaMatcher` (for schema matching).
40+
2. Define a class in the module that extends one of the base classes: `BaseOne2oneValueMatcher` or `BaseTopkValueMatcher` for value matching, and `BaseOne2oneSchemaMatcher` or `BaseTopkSchemaMatcher` for schema matching.
4141

42-
3. Add a new entry to the Enum class (e.g. `ValueMatchers`) in `matcher_factory.py` (e.g., `bdikit/value_matching/matcher_factory.py`).
42+
3. Add a new entry to the Enum class (e.g. `One2OneValueMatchers`) in `matcher_factory.py` (e.g., `bdikit/value_matching/matcher_factory.py`).
4343
Make sure to add the correct import path for your module to ensure it can be accessed without errors.
4444

4545

bdikit/api.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@
1212
get_one2one_schema_matcher,
1313
get_topk_schema_matcher,
1414
)
15-
from bdikit.value_matching.base import BaseValueMatcher, ValueMatch, ValueMatchingResult
16-
from bdikit.value_matching.matcher_factory import ValueMatchers
15+
from bdikit.value_matching.base import (
16+
BaseOne2oneValueMatcher,
17+
BaseTopkValueMatcher,
18+
ValueMatch,
19+
ValueMatchingResult,
20+
)
21+
from bdikit.value_matching.matcher_factory import (
22+
get_one2one_value_matcher,
23+
get_topk_value_matcher,
24+
)
1725
from bdikit.standards.standard_factory import Standards
1826

1927
from bdikit.mapping_functions import (
@@ -90,8 +98,7 @@ def match_schema(
9098

9199
def _load_table_for_standard(name: str, standard_args: Dict[str, Any]) -> pd.DataFrame:
92100
"""
93-
Load the table for the given standard data vocabulary. Currently, only the
94-
GDC standard is supported.
101+
Load the table for the given standard data vocabulary.
95102
"""
96103
if standard_args is None:
97104
standard_args = {}
@@ -165,7 +172,7 @@ def match_values(
165172
source: pd.DataFrame,
166173
target: Union[str, pd.DataFrame],
167174
column_mapping: Union[Tuple[str, str], pd.DataFrame],
168-
method: Union[str, BaseValueMatcher] = DEFAULT_VALUE_MATCHING_METHOD,
175+
method: Union[str, BaseOne2oneValueMatcher] = DEFAULT_VALUE_MATCHING_METHOD,
169176
method_args: Optional[Dict[str, Any]] = None,
170177
standard_args: Optional[Dict[str, Any]] = None,
171178
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
@@ -207,20 +214,19 @@ def match_values(
207214
ValueError: If the target is neither a DataFrame nor a standard vocabulary name.
208215
ValueError: If the source column is not present in the source dataset.
209216
"""
210-
if method_args is None:
211-
method_args = {}
212217

213218
if standard_args is None:
214219
standard_args = {}
215220

216-
if "top_k" in method_args and method_args["top_k"] > 1:
217-
logger.warning(
218-
f"Ignoring 'top_k' argument, use the 'top_value_matches()' method to get top-k value matches."
219-
)
220-
method_args["top_k"] = 1
221+
if isinstance(method, str):
222+
if method_args is None:
223+
method_args = {}
224+
matcher_instance = get_one2one_value_matcher(method, **method_args)
225+
elif isinstance(method, BaseOne2oneValueMatcher):
226+
matcher_instance = method
221227

222228
matches = _match_values(
223-
source, target, column_mapping, method, method_args, standard_args
229+
source, target, column_mapping, matcher_instance, standard_args
224230
)
225231

226232
if isinstance(column_mapping, tuple):
@@ -241,7 +247,7 @@ def top_value_matches(
241247
target: Union[str, pd.DataFrame],
242248
column_mapping: Union[Tuple[str, str], pd.DataFrame],
243249
top_k: int = 5,
244-
method: str = DEFAULT_VALUE_MATCHING_METHOD,
250+
method: Union[str, BaseTopkValueMatcher] = DEFAULT_VALUE_MATCHING_METHOD,
245251
method_args: Optional[Dict[str, Any]] = None,
246252
standard_args: Optional[Dict[str, Any]] = None,
247253
) -> List[pd.DataFrame]:
@@ -284,21 +290,19 @@ def top_value_matches(
284290
ValueError: If the target is neither a DataFrame nor a standard vocabulary name.
285291
ValueError: If the source column is not present in the source dataset.
286292
"""
287-
if method_args is None:
288-
method_args = {}
289293

290294
if standard_args is None:
291295
standard_args = {}
292296

293-
if "top_k" in method_args:
294-
logger.warning(
295-
f"Ignoring 'top_k' argument, using top_k argument instead (top_k={top_k})"
296-
)
297-
298-
method_args["top_k"] = top_k
297+
if isinstance(method, str):
298+
if method_args is None:
299+
method_args = {}
300+
matcher_instance = get_topk_value_matcher(method, **method_args)
301+
elif isinstance(method, BaseTopkValueMatcher):
302+
matcher_instance = method
299303

300304
matches = _match_values(
301-
source, target, column_mapping, method, method_args, standard_args
305+
source, target, column_mapping, matcher_instance, standard_args, top_k
302306
)
303307

304308
match_list = []
@@ -359,15 +363,15 @@ def _match_values(
359363
source: pd.DataFrame,
360364
target: Union[str, pd.DataFrame],
361365
column_mapping: Union[Tuple[str, str], pd.DataFrame],
362-
method: str,
363-
method_args: Dict[str, Any],
366+
value_matcher: Union[BaseOne2oneValueMatcher, BaseTopkValueMatcher],
364367
standard_args: Dict[str, Any],
368+
top_k: int = 1,
365369
) -> List[pd.DataFrame]:
366370

367371
target_domain, column_mapping_list = _format_value_matching_input(
368372
source, target, column_mapping, standard_args
369373
)
370-
value_matcher = ValueMatchers.get_matcher(method, **method_args)
374+
371375
mapping_results: List[ValueMatchingResult] = []
372376

373377
for mapping in column_mapping_list:
@@ -389,9 +393,14 @@ def _match_values(
389393
}
390394

391395
# 3. Apply the value matcher to create value mapping dictionaries
392-
raw_matches = value_matcher.match(
393-
list(source_values_dict.keys()), list(target_values_dict.keys())
394-
)
396+
if isinstance(value_matcher, BaseTopkValueMatcher):
397+
raw_matches = value_matcher.get_topk_matches(
398+
list(source_values_dict.keys()), list(target_values_dict.keys()), top_k
399+
)
400+
else:
401+
raw_matches = value_matcher.get_one2one_match(
402+
list(source_values_dict.keys()), list(target_values_dict.keys())
403+
)
395404

396405
# 4. Transform the matches to the original
397406
matches: List[ValueMatch] = []

bdikit/schema_matching/matcher_factory.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import importlib
21
from enum import Enum
3-
from typing import Mapping, Dict, Any
2+
from typing import Mapping, Any
43
from bdikit.schema_matching.base import BaseOne2oneSchemaMatcher, BaseTopkSchemaMatcher
4+
from bdikit.utils import create_matcher
55

66

77
class One2oneSchemaMatchers(Enum):
@@ -82,24 +82,6 @@ def __init__(self, matcher_name: str, matcher_path: str):
8282
one2one_schema_matchers.update(topk_schema_matchers)
8383

8484

85-
def create_matcher(
86-
matcher_name: str,
87-
available_matchers: Dict[str, str],
88-
**matcher_kwargs: Mapping[str, Any],
89-
):
90-
if matcher_name not in available_matchers:
91-
names = ", ".join(list(available_matchers.keys()))
92-
raise ValueError(
93-
f"The {matcher_name} algorithm is not supported. "
94-
f"Supported algorithms are: {names}"
95-
)
96-
# Load the class dynamically
97-
module_path, class_name = available_matchers[matcher_name].rsplit(".", 1)
98-
module = importlib.import_module(module_path)
99-
100-
return getattr(module, class_name)(**matcher_kwargs)
101-
102-
10385
def get_one2one_schema_matcher(
10486
matcher_name: str, **matcher_kwargs: Mapping[str, Any]
10587
) -> BaseOne2oneSchemaMatcher:

bdikit/schema_matching/maxvalsim.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
ColumnScore,
88
)
99
from bdikit.schema_matching.contrastivelearning import ContrastiveLearning
10-
from bdikit.value_matching.polyfuzz import TFIDFValueMatcher
11-
from bdikit.value_matching.base import BaseValueMatcher
10+
from bdikit.value_matching.polyfuzz import TFIDF
11+
from bdikit.value_matching.base import BaseOne2oneValueMatcher
1212

1313

1414
class MaxValSim(BaseTopkSchemaMatcher):
@@ -17,7 +17,7 @@ def __init__(
1717
top_k: int = 20,
1818
contribution_factor: float = 0.5,
1919
top_k_matcher: Optional[BaseTopkSchemaMatcher] = None,
20-
value_matcher: Optional[BaseValueMatcher] = None,
20+
value_matcher: Optional[BaseOne2oneValueMatcher] = None,
2121
):
2222
if top_k_matcher is None:
2323
self.api = ContrastiveLearning(DEFAULT_CL_MODEL)
@@ -30,13 +30,13 @@ def __init__(
3030
)
3131

3232
if value_matcher is None:
33-
self.value_matcher = TFIDFValueMatcher()
34-
elif isinstance(value_matcher, BaseValueMatcher):
33+
self.value_matcher = TFIDF()
34+
elif isinstance(value_matcher, BaseOne2oneValueMatcher):
3535
self.value_matcher = value_matcher
3636
else:
3737
raise ValueError(
3838
f"Invalid value_matcher type: {type(value_matcher)}. "
39-
"Must be a subclass of {BaseValueMatcher.__name__}"
39+
"Must be a subclass of {BaseOne2oneValueMatcher.__name__}"
4040
)
4141

4242
self.top_k = top_k
@@ -76,7 +76,9 @@ def get_topk_matches(
7676
target_column_name = top_column.column_name
7777
target_column = target[target_column_name]
7878
target_values = self.unique_string_values(target_column).to_list()
79-
value_matches = self.value_matcher.match(source_values, target_values)
79+
value_matches = self.value_matcher.get_one2one_match(
80+
source_values, target_values
81+
)
8082
if len(target_values) == 0:
8183
value_score = 0.0
8284
else:

bdikit/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22
import hashlib
3+
import importlib
34
import pandas as pd
45
from os.path import join, dirname, isfile
6+
from typing import Mapping, Dict, Any
57
from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR
68

79

@@ -58,3 +60,21 @@ def check_embedding_cache(table: pd.DataFrame, model_path: str):
5860
embeddings = None
5961

6062
return embedding_file, embeddings
63+
64+
65+
def create_matcher(
    matcher_name: str,
    available_matchers: Dict[str, str],
    **matcher_kwargs: Any,
):
    """
    Instantiate the matcher class registered under ``matcher_name``.

    Args:
        matcher_name: Key identifying the matcher in ``available_matchers``.
        available_matchers: Mapping from matcher name to the dotted import
            path of its class (``"package.module.ClassName"``).
        **matcher_kwargs: Keyword arguments forwarded unchanged to the class
            constructor.

    Returns:
        A new instance of the class that ``matcher_name`` resolves to.

    Raises:
        ValueError: If ``matcher_name`` is not a key of ``available_matchers``.
    """
    if matcher_name not in available_matchers:
        names = ", ".join(available_matchers)
        raise ValueError(
            f"The {matcher_name} algorithm is not supported. "
            f"Supported algorithms are: {names}"
        )

    # Import lazily so that a matcher's module (and its dependencies) is only
    # loaded when that matcher is actually requested.
    module_path, class_name = available_matchers[matcher_name].rsplit(".", 1)
    module = importlib.import_module(module_path)
    matcher_class = getattr(module, class_name)

    return matcher_class(**matcher_kwargs)

bdikit/value_matching/base.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,27 @@ class ValueMatchingResult(TypedDict):
2525
unmatch_values: Set[str]
2626

2727

28-
class BaseValueMatcher:
28+
class BaseOne2oneValueMatcher:
2929
"""
3030
Base class for value matching algorithms, i.e., algorithms that match
3131
values from a source domain to values from a target domain.
3232
"""
3333

34-
def match(
34+
def get_one2one_match(
3535
self, source_values: List[str], target_values: List[str]
3636
) -> List[ValueMatch]:
3737
raise NotImplementedError("Subclasses must implement this method")
38+
39+
40+
class BaseTopkValueMatcher(BaseOne2oneValueMatcher):
    """
    Base class for value matchers that expose a top-k interface.

    Subclasses implement ``get_topk_matches``; ``get_one2one_match`` is
    derived from it by fixing ``top_k`` to 1.
    """

    def get_topk_matches(
        self, source_values: List[str], target_values: List[str], top_k: int
    ) -> List[ValueMatch]:
        """Match source values against target values for the given ``top_k``.

        Must be overridden; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_one2one_match(
        self, source_values: List[str], target_values: List[str]
    ) -> List[ValueMatch]:
        """Return the result of ``get_topk_matches`` called with ``top_k`` = 1."""
        return self.get_topk_matches(source_values, target_values, 1)

bdikit/value_matching/gpt.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import ast
22
from typing import List
33
from openai import OpenAI
4-
from bdikit.value_matching.base import BaseValueMatcher, ValueMatch
4+
from bdikit.value_matching.base import BaseOne2oneValueMatcher, ValueMatch
55
from bdikit.config import VALUE_MATCHING_THRESHOLD
66

77

8-
class GPTValueMatcher(BaseValueMatcher):
8+
class GPT(BaseOne2oneValueMatcher):
99
def __init__(
1010
self,
1111
threshold: float = VALUE_MATCHING_THRESHOLD,
1212
):
1313
self.client = OpenAI()
1414
self.threshold = threshold
1515

16-
def match(
16+
def get_one2one_match(
1717
self,
1818
source_values: List[str],
1919
target_values: List[str],

0 commit comments

Comments
 (0)