Commit 7efd2b5

Merge pull request #12 from codertimo/alpha0.0.1a3
alpha0.0.1a3 version update
2 parents 5b9f139 + b8f27e3 commit 7efd2b5

File tree

12 files changed: +153 -220 lines


README.md (+16 -12)
````diff
@@ -1,9 +1,9 @@
 # BERT-pytorch
 
-[![LICENSE](https://img.shields.io/github/license/codertimo/BERT-pytorch.svg)](https://github.com/kor2vec/kor2vec/blob/master/LICENSE)
+[![LICENSE](https://img.shields.io/github/license/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/blob/master/LICENSE)
 ![GitHub issues](https://img.shields.io/github/issues/codertimo/BERT-pytorch.svg)
-[![GitHub stars](https://img.shields.io/github/stars/codertimo/BERT-pytorch.svg)](https://github.com/kor2vec/kor2vec/stargazers)
-[![CircleCI](https://circleci.com/gh/codertimo/BERT-pytorch.svg?style=shield)](https://circleci.com/gh/kor2vec/kor2vec)
+[![GitHub stars](https://img.shields.io/github/stars/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/stargazers)
+[![CircleCI](https://circleci.com/gh/codertimo/BERT-pytorch.svg?style=shield)](https://circleci.com/gh/codertimo/BERT-pytorch)
 [![PyPI](https://img.shields.io/pypi/v/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/)
 [![PyPI - Status](https://img.shields.io/pypi/status/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/)
 [![Documentation Status](https://readthedocs.org/projects/bert-pytorch/badge/?version=latest)](https://bert-pytorch.readthedocs.io/en/latest/?badge=latest)
@@ -39,24 +39,28 @@ pip install bert-pytorch
 ## Quickstart
 
 **NOTICE : Your corpus should be prepared with two sentences in one line with tab(\t) separator**
+
+### 0. Prepare your corpus
 ```
-Welcome to the \t the jungle \n
-I can stay \t here all night \n
+Welcome to the \t the jungle\n
+I can stay \t here all night\n
 ```
 
-### 1. Building vocab based on your corpus
-```shell
-bert-vocab -c data/corpus.small -o data/corpus.small.vocab
+or tokenized corpus (tokenization is not in package)
 ```
+Wel_ _come _to _the \t _the _jungle\n
+_I _can _stay \t _here _all _night\n
+```
+
 
-### 2. Building BERT train dataset with your corpus
+### 1. Building vocab based on your corpus
 ```shell
-bert-dataset -d data/corpus.small -v data/corpus.small.vocab -o data/dataset.small
+bert-vocab -c data/corpus.small -o data/vocab.small
 ```
 
-### 3. Train your own BERT model
+### 2. Train your own BERT model
 ```shell
-bert -d data/dataset.small -v data/corpus.small.vocab -o output/bert.model
+bert -c data/dataset.small -v data/vocab.small -o output/bert.model
 ```
 
 ## Language Model Pre-training
````
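Note: the Quickstart's corpus format (two sentences per line, separated by a tab) is easy to generate yourself. A minimal sketch, assuming illustrative paths and toy sentence pairs:

```python
# Sketch: write a toy corpus in the tab-separated, two-sentences-per-line
# format the Quickstart expects. Paths and sentences are illustrative.
pairs = [
    ("Welcome to the", "the jungle"),
    ("I can stay", "here all night"),
]
with open("data/corpus.small", "w", encoding="utf-8") as f:
    for first, second in pairs:
        f.write(first + "\t" + second + "\n")
```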

bert_pytorch/__main__.py (+67)
```diff
@@ -0,0 +1,67 @@
+import argparse
+
+from torch.utils.data import DataLoader
+
+from .model import BERT
+from .trainer import BERTTrainer
+from .dataset import BERTDataset, WordVocab
+
+
+def train():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("-c", "--train_dataset", required=True, type=str)
+    parser.add_argument("-t", "--test_dataset", type=str, default=None)
+    parser.add_argument("-v", "--vocab_path", required=True, type=str)
+    parser.add_argument("-o", "--output_path", required=True, type=str)
+
+    parser.add_argument("-hs", "--hidden", type=int, default=256)
+    parser.add_argument("-l", "--layers", type=int, default=8)
+    parser.add_argument("-a", "--attn_heads", type=int, default=8)
+    parser.add_argument("-s", "--seq_len", type=int, default=20)
+
+    parser.add_argument("-b", "--batch_size", type=int, default=64)
+    parser.add_argument("-e", "--epochs", type=int, default=10)
+    parser.add_argument("-w", "--num_workers", type=int, default=5)
+    parser.add_argument("--with_cuda", type=bool, default=True)
+    parser.add_argument("--log_freq", type=int, default=10)
+    parser.add_argument("--corpus_lines", type=int, default=None)
+
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
+    parser.add_argument("--adam_beta1", type=float, default=0.9)
+    parser.add_argument("--adam_beta2", type=float, default=0.999)
+
+    args = parser.parse_args()
+
+    print("Loading Vocab", args.vocab_path)
+    vocab = WordVocab.load_vocab(args.vocab_path)
+    print("Vocab Size: ", len(vocab))
+
+    print("Loading Train Dataset", args.train_dataset)
+    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)
+
+    print("Loading Test Dataset", args.test_dataset)
+    test_dataset = BERTDataset(args.test_dataset, vocab,
+                               seq_len=args.seq_len) if args.test_dataset is not None else None
+
+    print("Creating Dataloader")
+    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
+        if test_dataset is not None else None
+
+    print("Building BERT model")
+    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+
+    print("Creating BERT Trainer")
+    trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
+                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+                          with_cuda=args.with_cuda, log_freq=args.log_freq)
+
+    print("Training Start")
+    for epoch in range(args.epochs):
+        trainer.train(epoch)
+        trainer.save(epoch, args.output_path)
+
+        if test_data_loader is not None:
+            trainer.test(epoch)
```
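The new `train()` function backs the `bert` console command. Assuming the package layout shown in this commit, a rough programmatic equivalent that bypasses argparse might look like the sketch below; the keyword arguments mirror the CLI defaults above, and the paths are illustrative:

```python
# Sketch: drive training directly instead of via the `bert` CLI.
# All class names and keyword arguments are taken from __main__.py above;
# file paths and hyperparameter values are illustrative assumptions.
from torch.utils.data import DataLoader

from bert_pytorch.model import BERT
from bert_pytorch.trainer import BERTTrainer
from bert_pytorch.dataset import BERTDataset, WordVocab

vocab = WordVocab.load_vocab("data/vocab.small")
train_dataset = BERTDataset("data/dataset.small", vocab, seq_len=20)
train_loader = DataLoader(train_dataset, batch_size=64, num_workers=5)

bert = BERT(len(vocab), hidden=256, n_layers=8, attn_heads=8)
trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_loader, test_dataloader=None,
                      lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01,
                      with_cuda=True, log_freq=10)

for epoch in range(10):
    trainer.train(epoch)
    trainer.save(epoch, "output/bert.model")
```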

bert_pytorch/build_dataset.py (-41)

This file was deleted.

bert_pytorch/build_vocab.py (-19)

This file was deleted.

bert_pytorch/dataset/__init__.py (-1)
```diff
@@ -1,3 +1,2 @@
 from .dataset import BERTDataset
-from .creator import BERTDatasetCreator
 from .vocab import WordVocab
```

bert_pytorch/dataset/creator.py (-61)

This file was deleted.

bert_pytorch/dataset/dataset.py (+46 -12)
```diff
@@ -1,32 +1,32 @@
 from torch.utils.data import Dataset
 import tqdm
 import torch
+import random
 
 
 class BERTDataset(Dataset):
     def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None):
         self.vocab = vocab
         self.seq_len = seq_len
 
-        self.datas = []
         with open(corpus_path, "r", encoding=encoding) as f:
-            for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
-                t1, t2, t1_l, t2_l, is_next = line[:-1].split("\t")
-                t1, t2 = [[int(token) for token in t.split(",")] for t in [t1, t2]]
-                t1_l, t2_l = [[int(token) for token in label.split(",")] for label in [t1_l, t2_l]]
-                is_next = int(is_next)
-                self.datas.append({"t1": t1, "t2": t2, "t1_label": t1_l, "t2_label": t2_l, "is_next": is_next})
+            self.datas = [line[:-1].split("\t")
+                          for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
 
     def __len__(self):
         return len(self.datas)
 
     def __getitem__(self, item):
+        t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)
+        t1_random, t1_label = self.random_word(t1)
+        t2_random, t2_label = self.random_word(t2)
+
         # [CLS] tag = SOS tag, [SEP] tag = EOS tag
-        t1 = [self.vocab.sos_index] + self.datas[item]["t1"] + [self.vocab.eos_index]
-        t2 = self.datas[item]["t2"] + [self.vocab.eos_index]
+        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
+        t2 = t2_random + [self.vocab.eos_index]
 
-        t1_label = [0] + self.datas[item]["t1_label"] + [0]
-        t2_label = self.datas[item]["t2_label"] + [0]
+        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
+        t2_label = t2_label + [self.vocab.pad_index]
 
         segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
         bert_input = (t1 + t2)[:self.seq_len]
@@ -38,6 +38,40 @@ def __getitem__(self, item):
         output = {"bert_input": bert_input,
                   "bert_label": bert_label,
                   "segment_label": segment_label,
-                  "is_next": self.datas[item]["is_next"]}
+                  "is_next": is_next_label}
 
         return {key: torch.tensor(value) for key, value in output.items()}
+
+    def random_word(self, sentence):
+        tokens = sentence.split()
+        output_label = []
+
+        for i, token in enumerate(tokens):
+            prob = random.random()
+            if prob < 0.15:
+                # 80%: replace the token with the mask token
+                if prob < 0.15 * 0.8:
+                    tokens[i] = self.vocab.mask_index
+
+                # 10%: replace the token with a random token
+                elif prob < 0.15 * 0.9:
+                    tokens[i] = random.randrange(len(self.vocab))
+
+                # 10%: keep the current token
+                else:
+                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
+
+                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
+
+            else:
+                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
+                output_label.append(0)
+
+        return tokens, output_label
+
+    def random_sent(self, index):
+        # output_text, label (isNotNext: 0, isNext: 1)
+        if random.random() > 0.5:
+            return self.datas[index][1], 1
+        else:
+            return self.datas[random.randrange(len(self.datas))][1], 0
```
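`random_word` applies BERT's masked-LM corruption: each token is selected with probability 0.15; a selected token becomes the mask token 80% of the time, a random token 10% of the time, and stays unchanged 10% of the time, and only selected positions get a non-zero prediction label. A standalone sketch of the same rule, with hypothetical names and no vocab object:

```python
import random

# Sketch of the 15% / 80-10-10 masking rule used by random_word above.
# `token_ids`, `mask_id`, and `vocab_size` are hypothetical stand-ins.
def mask_tokens(token_ids, mask_id, vocab_size):
    masked = list(token_ids)
    labels = []
    for i, tid in enumerate(token_ids):
        prob = random.random()
        if prob < 0.15:
            prob /= 0.15
            if prob < 0.8:       # 80%: replace with the mask token
                masked[i] = mask_id
            elif prob < 0.9:     # 10%: replace with a random token
                masked[i] = random.randrange(vocab_size)
            # else (10%): keep the original token
            labels.append(tid)   # predict the original id at this position
        else:
            labels.append(0)     # 0 = ignore this position in the loss
    return masked, labels
```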

bert_pytorch/dataset/vocab.py (+18)
```diff
@@ -165,3 +165,21 @@ def from_seq(self, seq, join=False, with_pad=False):
     def load_vocab(vocab_path: str) -> 'WordVocab':
         with open(vocab_path, "rb") as f:
             return pickle.load(f)
+
+
+def build():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--corpus_path", required=True, type=str)
+    parser.add_argument("-o", "--output_path", required=True, type=str)
+    parser.add_argument("-s", "--vocab_size", type=int, default=None)
+    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
+    parser.add_argument("-m", "--min_freq", type=int, default=1)
+    args = parser.parse_args()
+
+    with open(args.corpus_path, "r", encoding=args.encoding) as f:
+        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
+
+    print("VOCAB SIZE:", len(vocab))
+    vocab.save_vocab(args.output_path)
```
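The new `build()` function backs the `bert-vocab` command. The same flow can be driven from Python; a minimal sketch, assuming illustrative paths and using only `WordVocab`, `save_vocab`, and `load_vocab` as they appear in the code above:

```python
from bert_pytorch.dataset import WordVocab

# Build a vocabulary from the corpus, mirroring what build() does
# with its argparse defaults (no size cap, min_freq=1).
with open("data/corpus.small", "r", encoding="utf-8") as f:
    vocab = WordVocab(f, max_size=None, min_freq=1)

print("VOCAB SIZE:", len(vocab))
vocab.save_vocab("data/vocab.small")

# Reload it later, as __main__.py does before training.
vocab = WordVocab.load_vocab("data/vocab.small")
```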
