1/25/23
- Add Hyper label model; please find more details in our paper.
4/20/22
- Add WS explainer; please find more details in our paper.
4/20/22
- We have updated `setup.py` to make installation more flexible. Please use `pip install ws-benchmark==1.1.2rc0` to install the latest version. We strongly suggest creating a new environment to install Wrench. We will bring better compatibility in the next stable release. If you have any problems with installation, please let us know.
Known incompatibilities: `tensorflow==2.8.0`, `albumentations==0.1.12`
3/18/22
- Wrench is now available as the `ws-benchmark` package; use `pip install ws-benchmark` for a quick install.
2/13/22
- Add script to generate LFs for any tabular dataset as well as 5 new tabular datasets, namely, mushroom, spambase, PhishingWebsites, Bioresponse, and bank-marketing.
11/04/21
- (beta) Add `parallel_fit` for torch models to support PyTorch DistributedDataParallel (example)
10/15/21
- A bunch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
- Support image classification (dataset class / torchvision backbone) as well as DomainNet/Animals-with-Attributes2 datasets (check out the `datasets` folder)
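The known incompatibilities listed in the 4/20/22 note can be checked before installing into an existing environment. The following is a minimal sketch using only the standard library; the function name `find_incompatible` and its `installed_version` hook are hypothetical helpers introduced for illustration and are not part of Wrench.

```python
from importlib import metadata

# Versions listed as known incompatibilities in the 4/20/22 note.
KNOWN_INCOMPATIBLE = {
    "tensorflow": "2.8.0",
    "albumentations": "0.1.12",
}


def find_incompatible(installed_version=None):
    """Return {package: version} for installed packages pinned to a known-bad version.

    `installed_version` is a hypothetical hook (package name -> version string
    or None) so the check can be exercised without touching the real environment.
    """
    if installed_version is None:
        def installed_version(name):
            try:
                return metadata.version(name)
            except metadata.PackageNotFoundError:
                return None  # package is not installed, so it cannot clash

    clashes = {}
    for name, bad_version in KNOWN_INCOMPATIBLE.items():
        found = installed_version(name)
        if found == bad_version:
            clashes[name] = found
    return clashes
```

Calling `find_incompatible()` with no arguments inspects the current environment; an empty dictionary means none of the listed versions is installed.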
Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development and evaluation of your own weak supervision models within the benchmark.
For more information, check out our publications:
- WRENCH: A Comprehensive Benchmark for Weak Supervision (NeurIPS 2021)
- A Survey on Programmatic Weak Supervision
If you find this repository helpful, feel free to cite our publication:
@inproceedings{
zhang2021wrench,
title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2021},
url={https://openreview.net/forum?id=Q9SKS5k8io}
}
Weak Supervision is a paradigm for automated training data creation without manual annotations.
For a brief overview, please check out this blog.
For more context, please check out this survey.
To track recent advances in weak supervision, please follow this repo.
[1] Install anaconda: Instructions here: https://www.anaconda.com/download/
[2] Clone the repository:
git clone https://github.com/JieyuZ2/wrench.git
cd wrench
[3] Create virtual environment:
conda env create -f environment.yml
source activate wrench
If this does not work, or you want to use only a subset of Wrench's modules, check out this wiki page.
[4] Download datasets:
from huggingface_hub import snapshot_download
path = "path to local dir"
snapshot_download(repo_id="jieyuz2/WRENCH", repo_type="dataset", local_dir=path)
Note that some datasets may have more training examples than reported in the README/paper because we include the dev set, whose indices can be found in labeled_id.json if it exists.
Documentation of the dataset format and usage can be found on this wiki page.
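To make the note about labeled_id.json concrete, here is a hedged sketch of separating the dev-set examples from the rest of the training split. It assumes that train.json maps example ids to records and that labeled_id.json holds a list of ids; both layouts are assumptions made for this sketch, so check the wiki page for the actual format. The helper names are hypothetical.

```python
import json
from pathlib import Path


def split_by_labeled_ids(train, labeled_ids):
    """Split a {id: record} mapping into (labeled, unlabeled) parts.

    Ids are compared as strings because JSON object keys are always strings.
    """
    wanted = {str(i) for i in labeled_ids}
    labeled = {k: v for k, v in train.items() if k in wanted}
    unlabeled = {k: v for k, v in train.items() if k not in wanted}
    return labeled, unlabeled


def load_and_split(dataset_dir):
    """Load train.json and, if present, labeled_id.json from `dataset_dir`.

    The file layout is assumed, not taken from the Wrench documentation.
    """
    dataset_dir = Path(dataset_dir)
    train = json.loads((dataset_dir / "train.json").read_text())
    id_file = dataset_dir / "labeled_id.json"
    if not id_file.exists():
        return {}, train  # no dev-set indices shipped with this dataset
    return split_by_labeled_ids(train, json.loads(id_file.read_text()))
```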
Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
---|---|---|---|---|---|---|---|---|
Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
SemEval | relation classification | 9 | 164 | 1749 | 178 | 600 | link | link |
CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |
DomainNet | image classification | - | - | - | - | - | link | link |
Name | # class | # LF | # train | # validation | # test | data source | LF source |
---|---|---|---|---|---|---|---|
CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |
The detailed documentation is coming soon.
If you find that any of the implementations is wrong or problematic, don't hesitate to raise an issue or open a pull request; we really appreciate it!
TODO-list: check this out!
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Majority Voting | Label Model | -- | link |
Weighted Majority Voting | Label Model | -- | link |
Dawid-Skene | Label Model | link | link |
Data Programming | Label Model | link | link |
MeTaL | Label Model | link | link |
FlyingSquid | Label Model | link | link |
EBCC | Label Model | link | link |
IBCC | Label Model | link | link |
FABLE | Label Model | link | link |
Hyper Label Model | Label Model | link | link |
Logistic Regression | End Model | -- | link |
MLP | End Model | -- | link |
BERT | End Model | link | link |
COSINE | End Model | link | link |
ARS2 | End Model | link | link |
Denoise | Joint Model | link | link |
WeaSEL | Joint Model | link | link |
SepLL | Joint Model | link | link |
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Hidden Markov Model | Label Model | link | link |
Conditional Hidden Markov Model | Label Model | link | link |
LSTM-CNNs-CRF | End Model | link | link |
BERT-CRF | End Model | link | link |
LSTM-ConNet | Joint Model | link | link |
BERT-ConNet | Joint Model | link | link |
Wrench also provides a `SeqLabelModelWrapper` that adapts a label model for classification tasks to sequence tagging tasks.
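One plausible way to adapt a classification label model to sequence tagging is to treat every token as its own classification instance: flatten the sentences, aggregate the votes per token, then regroup the predictions by sentence. The sketch below illustrates that idea with a plain majority vote; it is not the code of `SeqLabelModelWrapper`, and all names in it are hypothetical.

```python
from collections import Counter

ABSTAIN = -1


def majority_vote(votes):
    """Classification-style label model: most common non-abstain vote, or ABSTAIN."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    if not counts:
        return ABSTAIN
    best = max(counts.values())
    # Break ties by the smallest label id so the output is deterministic.
    return min(label for label, c in counts.items() if c == best)


def tag_sequences(weak_labels):
    """weak_labels[s][t] is the list of labeling-function votes for token t of sentence s."""
    lengths = [len(sentence) for sentence in weak_labels]
    # Flatten: every token becomes one classification instance.
    flat = [votes for sentence in weak_labels for votes in sentence]
    flat_pred = [majority_vote(votes) for votes in flat]
    # Regroup the token-level predictions into sentences.
    tagged, start = [], 0
    for n in lengths:
        tagged.append(flat_pred[start:start + n])
        start += n
    return tagged
```

For example, `tag_sequences([[[1, 1, -1], [0, -1, -1]], [[-1, -1, -1]]])` yields one tag per token, with `ABSTAIN` for the token on which every labeling function abstained.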
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
Meta-Weight-Net | End Model | link | link |
Learning2ReWeight | End Model | link | link |
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
MeanTeacher | End Model | link | link |
Model | Model Type | Reference | Link to Wrench |
---|---|---|---|
ImplyLoss | Joint Model | link | link |
ASTRA | Joint Model | link | link |
import logging
import numpy as np
import pprint
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel
from wrench.evaluation import AverageMeter
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)
#### Specify the hyper-parameter search space for grid search
search_space = {
'Snorkel': {
'lr': np.logspace(-5, -1, num=5, base=10),
'l2': np.logspace(-5, -1, num=5, base=10),
'n_epochs': [5, 10, 50, 100, 200],
}
}
#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)
#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
metric=target, direction='auto', search_space=search_space[label_model_name],
n_repeats=n_repeats, n_trials=n_trials, parallel=True)
#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # Key the update by the metric name registered in `names` (here 'acc')
    meter.update(**{target: metric_value})
metrics = meter.get_results()
pprint.pprint(metrics)
For detailed guidance on `grid_search`, please check out this wiki page.
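As a rough illustration of the size of the search space in the example above, the sketch below enumerates the same grid with the standard library. The three hyper-parameters have five values each, giving 125 combinations, while the example sets `n_trials = 100`; how `grid_search` handles that gap is described in the wiki page and is not assumed here. `enumerate_grid` is a hypothetical helper, not a Wrench function.

```python
from itertools import product

# Same values as np.logspace(-5, -1, num=5, base=10) in the example above.
search_space = {
    "lr": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "l2": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "n_epochs": [5, 10, 50, 100, 200],
}


def enumerate_grid(space):
    """Return every combination of the search space as a list of dicts."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in product(*(space[k] for k in keys))]


grid = enumerate_grid(search_space)
print(len(grid))  # 125
```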
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
cache_name=extract_fn, model_name=model_name)
#### Train a MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target='acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target,
patience=patience, evaluation_step=evaluation_step)
#### Evaluate the trained model
metric_value = model.test(test_data, target)
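The `patience` and `evaluation_step` arguments in the example above suggest validation-based early stopping. The sketch below shows one common reading of those two parameters: evaluate every `evaluation_step` steps and stop after `patience` consecutive evaluations without improvement. Whether Wrench counts patience in evaluations or in steps is an assumption here, and `run_with_early_stopping` is a hypothetical function written for illustration.

```python
def run_with_early_stopping(evaluate, n_steps, evaluation_step, patience):
    """Return (best_score, best_step, stopped_step).

    `evaluate(step)` is a caller-supplied function returning the validation
    metric at that step (higher is better).
    """
    best_score, best_step, stale = float("-inf"), 0, 0
    for step in range(1, n_steps + 1):
        # A real training loop would take one optimisation step here.
        if step % evaluation_step:
            continue  # only evaluate every `evaluation_step` steps
        score = evaluate(step)
        if score > best_score:
            best_score, best_step, stale = score, step, 0
        else:
            stale += 1
            if stale >= patience:
                return best_score, best_step, step  # stopped early
    return best_score, best_step, n_steps
```

Under this reading, `patience=200` with `evaluation_step=50` would tolerate 200 evaluations, that is 10000 steps, without improvement before stopping.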
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
cache_name=extract_fn, model_name=model_name)
#### Generate soft training label via a label model
#### The weak labels provided by supervision sources are already encoded in dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)
#### Train a MLP classifier with soft label
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target='acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label,
device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
#### Evaluate the trained model
metric_value = model.test(test_data, target)
#### We can also train a MLP classifier with hard label
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label,
device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
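To illustrate the difference between the soft and hard labels used above, here is a small standard-library sketch of majority voting. A soft label is the normalised vote count per class and a hard label is its arg-max. This is an illustration only: the tie-breaking and abstain handling in Wrench's `MajorityVoting` and Snorkel's `probs_to_preds` may differ, and the function names are hypothetical.

```python
ABSTAIN = -1


def vote_proba(votes, n_class):
    """Soft label for one example: normalised vote counts, uniform if every LF abstains."""
    counts = [0] * n_class
    for v in votes:
        if v != ABSTAIN:
            counts[v] += 1
    total = sum(counts)
    if total == 0:
        return [1.0 / n_class] * n_class
    return [c / total for c in counts]


def to_hard(proba):
    """Hard label: index of the largest probability (the first index wins ties)."""
    return max(range(len(proba)), key=proba.__getitem__)


# Three examples, three labeling functions, -1 means abstain.
L = [[1, 1, 0], [0, -1, -1], [-1, -1, -1]]
soft = [vote_proba(row, n_class=2) for row in L]
hard = [to_hard(p) for p in soft]
print(hard)  # [1, 0, 0]
```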
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
n_class=2,
n_lfs=10,
alpha=0.75, # mean accuracy
beta=0.1, # mean propensity
alpha_radius=0.2, # radius of accuracy
beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)
#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')
#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)
#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)
#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
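The parameters of the synthetic generator above (mean accuracy, mean propensity, and their radii) can be read as follows: each labeling function gets its own accuracy and propensity, fires independently given the true label, and is correct with its accuracy when it fires. The sketch below implements that reading with the standard library. Drawing the per-LF values uniformly within the radius is an assumption, and this is not the code of `ConditionalIndependentGenerator`.

```python
import random

ABSTAIN = -1


def generate(n_data, n_class=2, n_lfs=10, alpha=0.75, beta=0.1,
             alpha_radius=0.2, beta_radius=0.1, seed=0):
    """Return (labels, weak_labels) for conditionally independent labeling functions."""
    rng = random.Random(seed)
    # Assumption: per-LF accuracy and propensity are uniform within the given radius.
    acc = [rng.uniform(alpha - alpha_radius, alpha + alpha_radius) for _ in range(n_lfs)]
    prop = [rng.uniform(beta - beta_radius, beta + beta_radius) for _ in range(n_lfs)]
    labels, weak_labels = [], []
    for _ in range(n_data):
        y = rng.randrange(n_class)
        row = []
        for a, p in zip(acc, prop):
            if rng.random() >= p:
                row.append(ABSTAIN)  # the LF does not fire on this example
            elif rng.random() < a:
                row.append(y)  # the LF fires and is correct
            else:
                # The LF fires and picks a wrong class uniformly at random.
                row.append(rng.choice([c for c in range(n_class) if c != y]))
        labels.append(y)
        weak_labels.append(row)
    return labels, weak_labels
```

With the default values most entries are abstains, since the mean propensity is only 0.1.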
Contact person: Jieyu Zhang, [email protected]
Don't hesitate to send us an e-mail if you have any questions.
We're also open to any collaboration!
We sincerely welcome any contribution to the datasets or models!