์ง๋ฌธ์ด ์ฃผ์ด์ง ์ํ์์ ์ง์์ ํด๋นํ๋ ๋ต์ ์ฐพ๋ Task๋ฅผ MRC(Machine Reading Comprehension)๋ผ๊ณ ํ๋ค.
ODQA๋ ์ง๋ฌธ์ด ์ฃผ์ด์ง ์ํ๊ฐ ์๋๋ผ wiki๋ ์น ์ ์ฒด ๋ฑ๊ณผ ๊ฐ์ ๋ค์ํ documents๋ค ์ค ์ ์ ํ ์ง๋ฌธ์ ์ฐพ๋ retrieval ๋จ๊ณ์ ์ถ์ถ๋ ์ง๋ฌธ๋ค ์ฌ์ด์์ ์ ์ ํ ๋ต์ ์ฐพ๋ reader ๋จ๊ณ, 2-stage๋ก ์ด๋ฃจ์ด์ง Task๋ฅผ ์๋ฏธํ๋ค.
๋ฌธ์๋ค์ ์ถ์ถํ๊ธฐ ์ํด ์ผ๋ จ์ ๋ฒกํฐ ํํ๋ก ํํํด์ผ ํ๋๋ฐ ๋ํ์ ์ผ๋ก Sparse Embedding ๋ฐฉ์๊ณผ Dense Embedding ๋ฐฉ์์ผ๋ก ๋๋์ด ์ง๋ค.
์ง์์ ๋ง๋ ์ ์ ํ ๋ต์ ์ถ์ถ๋ ๋ฌธ์๋ค ์ฌ์ด์์ ์ฐพ๋ ๋จ๊ณ๋ก, ์ ๋ต์ span์ ์์ธกํ๋ ๋ฐฉ์์ผ๋ก ํ์ต๋๋ค.
๊นํ์ | ์ด์ฑ๊ตฌ | ์ดํ์ค | ์กฐ๋ฌธ๊ธฐ | ์กฐ์ต๋ ธ |
---|---|---|---|---|
![]() |
![]() |
![]() |
![]() |
![]() |
Elasticsearch ๊ตฌ์ฑ KoELECTRA ํ์ต ๋ฐ ํ๊ฐ BERT(multilingual) ํ์ต ๋ฐ ํ๊ฐ |
ColBERT Retriever ์ ์ฉ ๋ฐ ๊ฐ์ ColBERT์ BM25 ์์๋ธ |
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ๋ฐ์ดํฐ ์ฆ๊ฐ klue/RoBERTa-large ํ์ต ๋ฐ ํ๊ฐ |
K-fold ๊ตฌํ ์ต์ข ์์๋ธ ๊ตฌํ |
BM25 ๊ตฌํ Elasticsearch ๊ตฌํ |
# data (51.2 MB)
tar -xzf data.tar.gz
# ํ์ํ ํ์ด์ฌ ํจํค์ง ์ค์น.
bash ./install/install_requirements.sh
.
|-- README.md
|
|-- arguments.py # data, model, training arguments ๊ด๋ฆฌ
|
|-- colbert # dense retrieval(ColBERT) ํ์ต ๋ฐ ์ถ๋ก
| |-- evaluate.py
| |-- inference.py
| |-- model.py
| |-- tokenizer.py
| `-- train.py
|
|-- es_retrieval.py # sparse retrieval(Elasticsearch) connetion
|-- retrieval.py # tfidf, bm25, elasticsearch retrieval class
|-- settings.json # elasticsearch settings
|
|-- kfold_ensemble_hard.py # k-fold hard voting
|-- kfold_ensemble_soft.py # k-fold soft voting
|-- make_folds.py
|
|-- models # model ์ ์ฅ์
| `-- model_folder
|
|-- outputs # matrix ์ ์ฅ์
| `-- output_folder
|
|-- train.py # reader ํ์ต
|-- train_kfold.py # reader ํ์ต(with. k-fold)
|-- inference.py # retrieval + reader (end-to-end) ํ๊ฐ ๋ฐ ์ถ๋ก
|-- trainer_qa.py # Trainer class
`-- utils_qa.py # utility function
./data/ # ์ ์ฒด ๋ฐ์ดํฐ
./wikipedia_documents.json # ์ํคํผ๋์ ๋ฌธ์ ์งํฉ. retrieval์ ์ํด ์ฐ์ด๋ corpus.
์ด ์ฝ 60000๊ฐ์ ์ํคํผ๋์ ๋ฌธ์์์ context ๊ฐ ์จ์ ํ ๋์ผํ ๋ฐ์ดํฐ์ ํํด์ ์ค๋ณต ์ ๊ฑฐ๋ฅผ ํตํด ์ฝ 56000๊ฐ์ ๋ฐ์ดํฐ๋ก ์งํ.
๋ฐ์ดํฐ์ ์ ํธ์์ฑ์ ์ํด Huggingface ์์ ์ ๊ณตํ๋ datasets๋ฅผ ์ด์ฉํ์ฌ pyarrow ํ์์ ๋ฐ์ดํฐ๋ก ์ ์ฅ๋์ด ์์. ๋ฐ์ดํฐ์ ์ ๊ตฌ์ฑ
./data/ # ์ ์ฒด ๋ฐ์ดํฐ
./train_dataset/ # ํ์ต์ ์ฌ์ฉํ ๋ฐ์ดํฐ์
. train ๊ณผ validation ์ผ๋ก ๊ตฌ์ฑ
./test_dataset/ # ์ ์ถ์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ์
. validation ์ผ๋ก ๊ตฌ์ฑ
์ธ๋ถ ๋ฐ์ดํฐ์ธ KorQuAD, Ko-WIKI๋ฅผ ์ถ๊ฐํ์ฌ ์ฝ 12๋ง๊ฐ์ ๋ฐ์ดํฐ์ ๊ตฌ์ฑ.
./data/ # ์ ์ฒด ๋ฐ์ดํฐ
./wiki_korQuAD_aug_dataset/ # ํ์ต์ ์ฌ์ฉํ ์ธ๋ถ ๋ฐ์ดํฐ์
.
data์ ๋ํ argument ๋ arguments.py
์ DataTrainingArguments
์์ ํ์ธ ๊ฐ๋ฅ.
๋ง์ฝ arguments ์ ๋ํ ์ธํ
์ ์ง์ ํ๊ณ ์ถ๋ค๋ฉด arguments.py
๋ฅผ ์ฐธ๊ณ .
roberta ๋ชจ๋ธ์ ์ฌ์ฉํ ๊ฒฝ์ฐ tokenizer ์ฌ์ฉ์ ์๋ ํจ์์ ์ต์
์ ์์ ํด์ผํจ.
tokenizer๋ train, validation (train.py), test(inference.py) ์ ์ฒ๋ฆฌ๋ฅผ ์ํด ํธ์ถ๋์ด ์ฌ์ฉ๋จ.
(tokenizer์ return_token_type_ids=False๋ก ์ค์ ํด์ฃผ์ด์ผ ํจ)
# train.py
def prepare_train_features(examples):
# truncation๊ณผ padding(length๊ฐ ์งง์๋๋ง)์ ํตํด toknization์ ์งํํ๋ฉฐ, stride๋ฅผ ์ด์ฉํ์ฌ overflow๋ฅผ ์ ์งํจ.
# ๊ฐ example๋ค์ ์ด์ ์ context์ ์กฐ๊ธ์ฉ ๊ฒน์ณ์ง.
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
# return_token_type_ids=False, # roberta๋ชจ๋ธ์ ์ฌ์ฉํ ๊ฒฝ์ฐ False, bert๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ True๋ก ํ๊ธฐํด์ผํจ.
padding="max_length" if data_args.pad_to_max_length else False,
)
# ํ์ต ์์ (train_dataset ์ฌ์ฉ)
python train.py --output_dir ./models/train_dataset --do_train
-
train.py
์์ sparse embedding ์ ํ๋ จํ๊ณ ์ ์ฅํ๋ ๊ณผ์ ์ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ์ง ์์ ๋ฐ๋ก argument ์ default ๊ฐ True๋ก ์ค์ ๋์ด ์์. ์คํ ํ sparse_embedding.bin ๊ณผ tfidfv.bin ์ด ์ ์ฅ๋จ. ๋ง์ฝ sparse retrieval ๊ด๋ จ ์ฝ๋๋ฅผ ์์ ํ๋ค๋ฉด, ๊ผญ ๋ ํ์ผ์ ์ง์ฐ๊ณ ๋ค์ ์คํ! ์๊ทธ๋ฌ๋ฉด ๊ธฐ์กด ํ์ผ์ด load ๋จ. -
๋ชจ๋ธ์ ๊ฒฝ์ฐ
--overwrite_cache
๋ฅผ ์ถ๊ฐํ์ง ์์ผ๋ฉด ๊ฐ์ ํด๋์ ์ ์ฅ๋์ง ์์. -
./outputs/
ํด๋ ๋ํ--overwrite_output_dir
์ ์ถ๊ฐํ์ง ์์ผ๋ฉด ๊ฐ์ ํด๋์ ์ ์ฅ๋์ง ์์.
MRC ๋ชจ๋ธ์ ํ๊ฐ๋(--do_eval
) ๋ฐ๋ก ์ค์ ํด์ผ ํจ. ์ ํ์ต ์์์ ๋จ์ํ --do_eval
์ ์ถ๊ฐ๋ก ์
๋ ฅํด์ ํ๋ จ ๋ฐ ํ๊ฐ๋ฅผ ๋์์ ์งํํ ์ ์์.
# mrc ๋ชจ๋ธ ํ๊ฐ (train_dataset ์ฌ์ฉ)
python train.py --output_dir ./outputs/train_dataset --model_name_or_path ./models/train_dataset/ --do_eval
retrieval ๊ณผ mrc ๋ชจ๋ธ์ ํ์ต์ด ์๋ฃ๋๋ฉด inference.py
๋ฅผ ์ด์ฉํด odqa ๋ฅผ ์งํํ ์ ์์.
-
ํ์ตํ ๋ชจ๋ธ์ test_dataset์ ๋ํ ๊ฒฐ๊ณผ๋ฅผ ์ ์ถํ๊ธฐ ์ํด์ ์ถ๋ก (
--do_predict
)๋ง ์งํ. -
ํ์ตํ ๋ชจ๋ธ์ด train_dataset ๋ํด์ ODQA ์ฑ๋ฅ์ด ์ด๋ป๊ฒ ๋์ค๋์ง ์๊ณ ์ถ๋ค๋ฉด ํ๊ฐ(
--do_eval
)๋ฅผ ์งํ.
# ODQA ์คํ (test_dataset ์ฌ์ฉ)
# wandb ๊ฐ ๋ก๊ทธ์ธ ๋์ด์๋ค๋ฉด ์๋์ผ๋ก ๊ฒฐ๊ณผ๊ฐ wandb ์ ์ ์ฅ. ์๋๋ฉด ๋จ์ํ ์ถ๋ ฅ๋จ
python inference.py --output_dir ./outputs/test_dataset/ --dataset_name ../data/test_dataset/ --model_name_or_path ./models/train_dataset/ --do_predict
- TF-IDF
๋จ์ด์ ๋ฑ์ฅ๋น๋(TF)์ ๋จ์ด๊ฐ ์ ๊ณตํ๋ ์ ๋ณด์ ์(IDF)๋ฅผ ์ด์ฉํ function
- BM25
๊ธฐ์กด TF-IDF๋ณด๋ค TF์ ์ํฅ๋ ฅ์ ์ค์ธ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก TF์ ํ๊ณ๋ฅผ ์ง์ ํ์ฌ ์ผ์ ๋ฒ์๋ฅผ ์ ์งํ๋๋ก ํ๋ค.
BM25๋ ๋ฌธ์์ ๊ธธ์ด๊ฐ ๋ ์์์๋ก ํฐ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ๋ function
- Elasticsearch
Elasticsearch๋ ๊ธฐ๋ณธ์ ์ผ๋ก scoring function์ผ๋ก bm25 ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๋ค.
์์ bm25์ ์ฐจ์ด์ ์ k=1.2
, b=0.75
๋ฅผ ์ฌ์ฉํ๋ค.
๊ฒ์์์ง์ Elasticsearch๋ ๋ค์ํ ํ๋ฌ๊ทธ์ธ์ ์ฌ์ฉํ ์ ์๋๋ฐ ์ฐ๋ฆฌ๋ ํํ์ ๋ถ์๊ธฐ์ธ nori-analyzer๋ฅผ ์ฌ์ฉํ๋ค. ์์ธํ setting์ settings.json
ํ์ผ์ ์ฐธ๊ณ ํ ์ ์๋ค.
- ColBERT
ColBERT๋ BERT ๊ธฐ๋ฐ์ Encoder๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ๋ก Query ์ธ์ฝ๋์
ํ๋์ BERT๋ชจ๋ธ์ ๊ณต์ ํ์ฌ ์ด๋ฅผ ๊ตฌ๋ถํ๊ธฐ ์ํด *[Q], [D]*์ ์คํ์
ํ ํฐ์ ์ฌ์ฉํ๋ฉฐ Query์ Document์ relevance score๋ฅผ ๊ตฌํ๋ ํจ์๋ Cosine similarity๋ฅผ ์ฌ์ฉํ๋ค.
ColBERT๋ ๊ฐ ๋ฌธ์์ ๋ํ ์ ์๋ฅผ ์์ฑํ๊ณ Cross-entropy loss๋ฅผ ์ด์ฉํ์ฌ ์ต์ ํ๋ฅผ ์งํํ๋ค.
๋ค์ํ ์คํ์ ํตํด ColBERT์ BM25๊ฐ์ ๊ฐ์ค์น๋ฅผ ๋ถ์ฌํ Ensemble๋ชจ๋ธ์ ์ฑํํ์ฌ ์ฌ์ฉํ์ผ๋ฉฐ ์์ธํ ์คํ ๋ด์ฉ์ ์ด๊ณณ์์ ํ์ธํ ์ ์๋ค.
Masked Language Modeling
๊ณผ Next Sentence Prediction
์ ํตํ ์ฌ์ ํ์ต
- klue/bert
- data : ๋ชจ๋์ ๋ง๋ญ์น, ์ํค, ๋ด์ค ๋ฑ ํ๊ตญ์ด ๋ฐ์ดํฐ 62GB
- vocab size : 32000
- wordpiece tokenizer
- bert-multilingual
- data : 104๊ฐ ์ธ์ด์ ์ํคํผ๋์ ๋ฐ์ดํฐ
- vocab size : 119547
- wordpiece tokenizer
Dynamic Masking
๊ธฐ๋ฒ๊ณผ ๋ ๊ธด ํ์ต ์๊ฐ๊ณผ ํฐ ๋ฐ์ดํฐ ์ฌ์ฉ
- klue/RoBERTa
- data : ์ํคํผ๋์, ๋ด์ค ๋ฑ
- vocab size : 32000
- wordpiece
- xlm/RoBERTa
- data : 100๊ฐ์ ์ธ์ด๊ฐ ํฌํจ๋๊ณ ํํฐ๋ง๋ CommonCrawl (2.5TB)
- vocab size : 250002
- sentencepiece
- Koelectra
- data : ๋ด์ค, ์ํค, ๋๋ฌด์ํค ๋ชจ๋์ ๋ง๋ญ์น ๋ฑ (20GB)
- vocab size : 35000
- wordpiece
๊ฐ ๋ชจ๋ธ๋ค์ ๊ฒฐ๊ณผ๊ฐ์ ํ๋ฅ ๋ถํฌ๋ฅผ ํตํด ๊ฐ์ฅ ํฐ ํ๋ฅ ์ ๋ํ๋ด๋ ๊ฐ์ ํด๋นํ๋ ๊ฐ์ ์ต์ข ๊ฒฐ๊ณผ๊ฐ์ผ๋ก ์ฑํํ๋ ๋ฐฉ์
๊ฐ ๋ชจ๋ธ๋ค์ด ์ถ๋ก ํ N๊ฐ์ ํ๋ฅ ๊ฐ์ ๋์ผํ ์ ๋ต์ ๊ฐ์ง ํ๋ฅ ์ ๋ชจ๋ ๋ํ์ฌ ๊ฐ์ฅ ๋์ ํ๋ฅ ์ ์ฑํํ๋ ๋ฐฉ์
Soft Voting์ ๊ธฐ๋ฐ์ผ๋ก ์ด ํฉ์ด 10์ด ๋๋๋ก ํ๋ ๊ฐ์ค์น๋ฅผ ๋ถ๋ฐฐํ์ฌ weighted voting์ ์ํ
Model | Exact Match | F1 score |
---|---|---|
klue/bert-base | 46.25 | 54.89 |
bert-base-multilingual | 49.58 | 55.50 |
klue/roberta-large | 69.58 | 76.84 |
xlm-roberta-large | 61.67 | 71.85 |
KoELECTRA | 57.5 | 63.11 |
Public & Private 1st