Skip to content

Commit e7b1042

Browse files
authored
refactor: Reorganize matchers to clearly differentiate between one-to-one and top-k matching (#101)
1 parent 3ae3632 commit e7b1042

25 files changed

+409
-413
lines changed

CONTRIBUTING.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ Contributors can add new methods for schema and value matching by following thes
3737

3838
1. Create a Python module inside the corresponding task folder (e.g., `bdikit/value_matching`).
3939

40-
2. Define a class in the module that implements either `BaseValueMatcher` (for value matching) or `BaseSchemaMatcher` (for schema matching).
40+
2. Define a class in the module that extends the appropriate base class. For value matching, use `BaseOne2oneValueMatcher` or `BaseTopkValueMatcher`. For schema matching, use `BaseOne2oneSchemaMatcher` or `BaseTopkSchemaMatcher`.
4141

42-
3. Add a new entry to the Enum class (e.g. `ValueMatchers`) in `matcher_factory.py` (e.g., `bdikit/value_matching/matcher_factory.py`).
42+
3. Add a new entry to the appropriate Enum class (e.g., `One2OneValueMatchers`) in the task's `matcher_factory.py` (e.g., `bdikit/value_matching/matcher_factory.py`).
4343
Make sure to add the correct import path for your module to ensure it can be accessed without errors.
4444

4545

bdikit/api.py

+49-39
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,21 @@
77
import panel as pn
88
from IPython.display import display, Markdown
99

10-
from bdikit.schema_matching.one2one.base import BaseSchemaMatcher
11-
from bdikit.schema_matching.one2one.matcher_factory import SchemaMatchers
12-
from bdikit.schema_matching.topk.base import BaseTopkSchemaMatcher
13-
from bdikit.schema_matching.topk.matcher_factory import TopkMatchers
14-
from bdikit.value_matching.base import BaseValueMatcher, ValueMatch, ValueMatchingResult
15-
from bdikit.value_matching.matcher_factory import ValueMatchers
10+
from bdikit.schema_matching.base import BaseOne2oneSchemaMatcher, BaseTopkSchemaMatcher
11+
from bdikit.schema_matching.matcher_factory import (
12+
get_one2one_schema_matcher,
13+
get_topk_schema_matcher,
14+
)
15+
from bdikit.value_matching.base import (
16+
BaseOne2oneValueMatcher,
17+
BaseTopkValueMatcher,
18+
ValueMatch,
19+
ValueMatchingResult,
20+
)
21+
from bdikit.value_matching.matcher_factory import (
22+
get_one2one_value_matcher,
23+
get_topk_value_matcher,
24+
)
1625
from bdikit.standards.standard_factory import Standards
1726

1827
from bdikit.mapping_functions import (
@@ -43,7 +52,7 @@
4352
def match_schema(
4453
source: pd.DataFrame,
4554
target: Union[str, pd.DataFrame] = "gdc",
46-
method: Union[str, BaseSchemaMatcher] = DEFAULT_SCHEMA_MATCHING_METHOD,
55+
method: Union[str, BaseOne2oneSchemaMatcher] = DEFAULT_SCHEMA_MATCHING_METHOD,
4756
method_args: Optional[Dict[str, Any]] = None,
4857
standard_args: Optional[Dict[str, Any]] = None,
4958
) -> pd.DataFrame:
@@ -74,23 +83,22 @@ def match_schema(
7483
if isinstance(method, str):
7584
if method_args is None:
7685
method_args = {}
77-
matcher_instance = SchemaMatchers.get_matcher(method, **method_args)
78-
elif isinstance(method, BaseSchemaMatcher):
86+
matcher_instance = get_one2one_schema_matcher(method, **method_args)
87+
elif isinstance(method, BaseOne2oneSchemaMatcher):
7988
matcher_instance = method
8089
else:
8190
raise ValueError(
8291
"The method must be a string or an instance of BaseColumnMappingAlgorithm"
8392
)
8493

85-
matches = matcher_instance.map(source, target_table)
94+
matches = matcher_instance.get_one2one_match(source, target_table)
8695

8796
return pd.DataFrame(matches.items(), columns=["source", "target"])
8897

8998

9099
def _load_table_for_standard(name: str, standard_args: Dict[str, Any]) -> pd.DataFrame:
91100
"""
92-
Load the table for the given standard data vocabulary. Currently, only the
93-
GDC standard is supported.
101+
Load the table for the given standard data vocabulary.
94102
"""
95103
if standard_args is None:
96104
standard_args = {}
@@ -138,15 +146,15 @@ def top_matches(
138146
if isinstance(method, str):
139147
if method_args is None:
140148
method_args = {}
141-
topk_matcher = TopkMatchers.get_matcher(method, **method_args)
149+
topk_matcher = get_topk_schema_matcher(method, **method_args)
142150
elif isinstance(method, BaseTopkSchemaMatcher):
143151
topk_matcher = method
144152
else:
145153
raise ValueError(
146154
"The method must be a string or an instance of BaseTopkColumnMatcher"
147155
)
148156

149-
top_k_matches = topk_matcher.get_recommendations(
157+
top_k_matches = topk_matcher.get_topk_matches(
150158
selected_columns, target=target_table, top_k=top_k
151159
)
152160

@@ -164,7 +172,7 @@ def match_values(
164172
source: pd.DataFrame,
165173
target: Union[str, pd.DataFrame],
166174
column_mapping: Union[Tuple[str, str], pd.DataFrame],
167-
method: Union[str, BaseValueMatcher] = DEFAULT_VALUE_MATCHING_METHOD,
175+
method: Union[str, BaseOne2oneValueMatcher] = DEFAULT_VALUE_MATCHING_METHOD,
168176
method_args: Optional[Dict[str, Any]] = None,
169177
standard_args: Optional[Dict[str, Any]] = None,
170178
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
@@ -206,20 +214,19 @@ def match_values(
206214
ValueError: If the target is neither a DataFrame nor a standard vocabulary name.
207215
ValueError: If the source column is not present in the source dataset.
208216
"""
209-
if method_args is None:
210-
method_args = {}
211217

212218
if standard_args is None:
213219
standard_args = {}
214220

215-
if "top_k" in method_args and method_args["top_k"] > 1:
216-
logger.warning(
217-
f"Ignoring 'top_k' argument, use the 'top_value_matches()' method to get top-k value matches."
218-
)
219-
method_args["top_k"] = 1
221+
if isinstance(method, str):
222+
if method_args is None:
223+
method_args = {}
224+
matcher_instance = get_one2one_value_matcher(method, **method_args)
225+
elif isinstance(method, BaseOne2oneValueMatcher):
226+
matcher_instance = method
220227

221228
matches = _match_values(
222-
source, target, column_mapping, method, method_args, standard_args
229+
source, target, column_mapping, matcher_instance, standard_args
223230
)
224231

225232
if isinstance(column_mapping, tuple):
@@ -240,7 +247,7 @@ def top_value_matches(
240247
target: Union[str, pd.DataFrame],
241248
column_mapping: Union[Tuple[str, str], pd.DataFrame],
242249
top_k: int = 5,
243-
method: str = DEFAULT_VALUE_MATCHING_METHOD,
250+
method: Union[str, BaseTopkValueMatcher] = DEFAULT_VALUE_MATCHING_METHOD,
244251
method_args: Optional[Dict[str, Any]] = None,
245252
standard_args: Optional[Dict[str, Any]] = None,
246253
) -> List[pd.DataFrame]:
@@ -283,21 +290,19 @@ def top_value_matches(
283290
ValueError: If the target is neither a DataFrame nor a standard vocabulary name.
284291
ValueError: If the source column is not present in the source dataset.
285292
"""
286-
if method_args is None:
287-
method_args = {}
288293

289294
if standard_args is None:
290295
standard_args = {}
291296

292-
if "top_k" in method_args:
293-
logger.warning(
294-
f"Ignoring 'top_k' argument, using top_k argument instead (top_k={top_k})"
295-
)
296-
297-
method_args["top_k"] = top_k
297+
if isinstance(method, str):
298+
if method_args is None:
299+
method_args = {}
300+
matcher_instance = get_topk_value_matcher(method, **method_args)
301+
elif isinstance(method, BaseTopkValueMatcher):
302+
matcher_instance = method
298303

299304
matches = _match_values(
300-
source, target, column_mapping, method, method_args, standard_args
305+
source, target, column_mapping, matcher_instance, standard_args, top_k
301306
)
302307

303308
match_list = []
@@ -358,15 +363,15 @@ def _match_values(
358363
source: pd.DataFrame,
359364
target: Union[str, pd.DataFrame],
360365
column_mapping: Union[Tuple[str, str], pd.DataFrame],
361-
method: str,
362-
method_args: Dict[str, Any],
366+
value_matcher: Union[BaseOne2oneValueMatcher, BaseTopkValueMatcher],
363367
standard_args: Dict[str, Any],
368+
top_k: int = 1,
364369
) -> List[pd.DataFrame]:
365370

366371
target_domain, column_mapping_list = _format_value_matching_input(
367372
source, target, column_mapping, standard_args
368373
)
369-
value_matcher = ValueMatchers.get_matcher(method, **method_args)
374+
370375
mapping_results: List[ValueMatchingResult] = []
371376

372377
for mapping in column_mapping_list:
@@ -388,9 +393,14 @@ def _match_values(
388393
}
389394

390395
# 3. Apply the value matcher to create value mapping dictionaries
391-
raw_matches = value_matcher.match(
392-
list(source_values_dict.keys()), list(target_values_dict.keys())
393-
)
396+
if isinstance(value_matcher, BaseTopkValueMatcher):
397+
raw_matches = value_matcher.get_topk_matches(
398+
list(source_values_dict.keys()), list(target_values_dict.keys()), top_k
399+
)
400+
else:
401+
raw_matches = value_matcher.get_one2one_match(
402+
list(source_values_dict.keys()), list(target_values_dict.keys())
403+
)
394404

395405
# 4. Transform the matches to the original
396406
matches: List[ValueMatch] = []

bdikit/schema_matching/topk/base.py bdikit/schema_matching/base.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
1-
from bdikit.schema_matching.one2one.base import BaseSchemaMatcher
21
from typing import List, NamedTuple, TypedDict, Dict
32
import pandas as pd
43

54

5+
class BaseOne2oneSchemaMatcher:
6+
def get_one2one_match(
7+
self, source: pd.DataFrame, target: pd.DataFrame
8+
) -> Dict[str, str]:
9+
raise NotImplementedError("Subclasses must implement this method")
10+
11+
def _fill_missing_matches(
12+
self, dataset: pd.DataFrame, matches: Dict[str, str]
13+
) -> Dict[str, str]:
14+
for column in dataset.columns:
15+
if column not in matches:
16+
matches[column] = ""
17+
return matches
18+
19+
620
class ColumnScore(NamedTuple):
721
column_name: str
822
score: float
@@ -13,19 +27,19 @@ class TopkMatching(TypedDict):
1327
top_k_columns: List[ColumnScore]
1428

1529

16-
class BaseTopkSchemaMatcher(BaseSchemaMatcher):
30+
class BaseTopkSchemaMatcher(BaseOne2oneSchemaMatcher):
1731

18-
def get_recommendations(
32+
def get_topk_matches(
1933
self, source: pd.DataFrame, target: pd.DataFrame, top_k: int
2034
) -> List[TopkMatching]:
2135
raise NotImplementedError("Subclasses must implement this method")
2236

23-
def map(
37+
def get_one2one_match(
2438
self,
2539
source: pd.DataFrame,
2640
target: pd.DataFrame,
2741
) -> Dict[str, str]:
28-
top_matches = self.get_recommendations(source, target, 1)
42+
top_matches = self.get_topk_matches(source, target, 1)
2943
matches = {}
3044

3145
for top_match in top_matches:

bdikit/schema_matching/topk/contrastivelearning.py bdikit/schema_matching/contrastivelearning.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import numpy as np
33
from typing import List
4-
from bdikit.schema_matching.topk.base import (
4+
from bdikit.schema_matching.base import (
55
ColumnScore,
66
TopkMatching,
77
BaseTopkSchemaMatcher,
@@ -14,12 +14,12 @@
1414
from bdikit.models import ColumnEmbedder
1515

1616

17-
class EmbeddingSimilarityTopkSchemaMatcher(BaseTopkSchemaMatcher):
17+
class EmbeddingSimilarity(BaseTopkSchemaMatcher):
1818
def __init__(self, column_embedder: ColumnEmbedder, metric: str = "cosine"):
1919
self.api = column_embedder
2020
self.metric = metric
2121

22-
def get_recommendations(
22+
def get_topk_matches(
2323
self, source: pd.DataFrame, target: pd.DataFrame, top_k: int = 10
2424
) -> List[TopkMatching]:
2525
"""
@@ -54,7 +54,7 @@ def get_recommendations(
5454
return top_k_results
5555

5656

57-
class CLTopkSchemaMatcher(EmbeddingSimilarityTopkSchemaMatcher):
57+
class ContrastiveLearning(EmbeddingSimilarity):
5858
def __init__(self, model_name: str = DEFAULT_CL_MODEL, metric: str = "cosine"):
5959
super().__init__(
6060
column_embedder=ContrastiveLearningAPI(model_name=model_name), metric=metric

bdikit/schema_matching/one2one/gpt.py bdikit/schema_matching/gpt.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import pandas as pd
22
from openai import OpenAI
3-
from bdikit.schema_matching.one2one.base import BaseSchemaMatcher
3+
from bdikit.schema_matching.base import BaseOne2oneSchemaMatcher
44

55

6-
class GPTSchemaMatcher(BaseSchemaMatcher):
6+
class GPT(BaseOne2oneSchemaMatcher):
77
def __init__(self):
88
self.client = OpenAI()
99

10-
def map(self, source: pd.DataFrame, target: pd.DataFrame):
10+
def get_one2one_match(self, source: pd.DataFrame, target: pd.DataFrame):
1111
target_columns = target.columns
1212
labels = ", ".join(target_columns)
1313
candidate_columns = source.columns

bdikit/schema_matching/topk/magneto.py bdikit/schema_matching/magneto.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
import pandas as pd
22
from typing import Dict, Any, List
33
from magneto import Magneto as Magneto_Lib
4-
from bdikit.schema_matching.one2one.base import BaseSchemaMatcher
54
from bdikit.download import get_cached_model_or_download
6-
from bdikit.schema_matching.topk.base import (
7-
ColumnScore,
8-
TopkMatching,
9-
BaseTopkSchemaMatcher,
10-
)
5+
from bdikit.schema_matching.base import ColumnScore, TopkMatching, BaseTopkSchemaMatcher
116

127
DEFAULT_MAGNETO_MODEL = "magneto-gdc-v0.1"
138

@@ -18,7 +13,7 @@ def __init__(self, kwargs: Dict[str, Any] = None):
1813
kwargs = {}
1914
self.magneto = Magneto_Lib(**kwargs)
2015

21-
def get_recommendations(
16+
def get_topk_matches(
2217
self, source: pd.DataFrame, target: pd.DataFrame, top_k: int
2318
) -> List[TopkMatching]:
2419
self.magneto.params["topk"] = (

0 commit comments

Comments
 (0)