forked from huggingface/transfer-learning-conv-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
90 lines (76 loc) · 3.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) 2019-present, HuggingFace Inc.
# All rights reserved. This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import os
import tarfile
import tempfile
import torch
from pytorch_pretrained_bert import cached_path
# Default location of the PERSONACHAT JSON, used when no dataset path is given.
PERSONACHAT_URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"
# Archive containing the fine-tuned GPT chatbot model (a gzip-compressed tar).
HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/finetuned_chatbot_gpt.tar.gz"

# Module logger named after the import path (e.g. "utils") so it fits the
# standard dotted logger hierarchy; getLogger(__file__) would name it after the
# file's filesystem path instead.
logger = logging.getLogger(__name__)
def download_pretrained_model():
    """Download and extract the fine-tuned model archive from S3.

    The archive at ``HF_FINETUNED_MODEL`` is resolved through ``cached_path``,
    opened as a gzip-compressed tar file and unpacked into a new directory
    created with ``tempfile.mkdtemp``.

    Returns:
        str: path of the temporary directory holding the extracted files. The
        directory is not removed here; the caller owns it.

    Raises:
        tarfile.FilterError: on Pythons with tar extraction filters, if the
            ``"data"`` filter rejects an unsafe member.
        ValueError: on older Pythons, if a member (or a link member's target)
            would be placed outside the extraction directory.
    """
    resolved_archive_file = cached_path(HF_FINETUNED_MODEL)
    tempdir = tempfile.mkdtemp()

    logger.info("extracting archive file %s to temp dir %s", resolved_archive_file, tempdir)
    with tarfile.open(resolved_archive_file, "r:gz") as archive:
        # The archive comes from the network: never let a member escape tempdir.
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ and the 3.8-3.11 security backports: the "data"
            # filter refuses absolute paths, ".." traversal, links pointing
            # outside the destination and special files.
            archive.extractall(tempdir, filter="data")
        else:
            _check_tar_members(archive, tempdir)
            archive.extractall(tempdir)
    return tempdir


def _check_tar_members(archive, destination):
    """Raise ValueError if any member of ``archive`` would land outside ``destination``."""
    root = os.path.realpath(destination)

    def inside(path):
        resolved = os.path.realpath(path)
        return os.path.commonpath([root, resolved]) == root

    for member in archive.getmembers():
        member_path = os.path.join(root, member.name)
        if not inside(member_path):
            raise ValueError("Archive member {!r} escapes {}".format(member.name, destination))
        if member.issym():
            # Symlink targets are relative to the directory containing the link.
            link_target = os.path.join(os.path.dirname(member_path), member.linkname)
        elif member.islnk():
            # Hard-link targets are relative to the archive root.
            link_target = os.path.join(root, member.linkname)
        else:
            continue
        if not inside(link_target):
            raise ValueError("Link member {!r} points outside {}".format(member.name, destination))
def get_dataset(tokenizer, dataset_path, dataset_cache=None):
    """Get PERSONACHAT from S3, tokenized with ``tokenizer``.

    If a cache file for this tokenizer exists it is loaded with ``torch.load``.
    Otherwise the JSON is fetched through ``cached_path``, every string in it is
    replaced by its token ids, and the result is saved to the cache file when
    caching is enabled.

    Args:
        tokenizer: object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        dataset_path: path or URL of the JSON dataset; a falsy value falls back
            to ``PERSONACHAT_URL``.
        dataset_cache: cache-file prefix. The tokenizer class name is appended
            so caches built with different tokenizers are kept apart. ``None``
            (the default) disables caching.

    Returns:
        The decoded JSON structure with every string replaced by the list of
        token ids returned by the tokenizer; dict keys and nesting are kept.
    """
    dataset_path = dataset_path or PERSONACHAT_URL
    if dataset_cache is not None:
        # Avoid using the GPT cache for GPT-2 and vice versa.
        dataset_cache = dataset_cache + '_' + type(tokenizer).__name__
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        # NOTE(review): torch.load unpickles the file; only point dataset_cache
        # at files this code wrote itself.
        dataset = torch.load(dataset_cache)
    else:
        logger.info("Download dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            # Recursively replace each string with its token ids.
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return {n: tokenize(o) for n, o in obj.items()}
            # Anything else is iterated as a sequence; assumes the JSON holds
            # only strings, dicts and lists (numbers/null would raise here).
            return [tokenize(o) for o in obj]

        dataset = tokenize(dataset)
        if dataset_cache:
            torch.save(dataset, dataset_cache)
    return dataset
def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
    """Get the tokenized personalities from PERSONACHAT.

    Loads the tokenized dataset from the per-tokenizer cache file when it
    exists, otherwise downloads and tokenizes the JSON (saving it to the cache
    when caching is enabled), then collects the ``"personality"`` entry of
    every dialog in every split.

    Args:
        tokenizer: object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        dataset_path: path or URL of the JSON dataset; a falsy value falls back
            to ``PERSONACHAT_URL``.
        dataset_cache: cache-file prefix. The tokenizer class name is appended
            so caches built with different tokenizers are kept apart. ``None``
            (the default) disables caching.

    Returns:
        list: one tokenized ``"personality"`` value per dialog, in split order.
    """
    dataset_path = dataset_path or PERSONACHAT_URL
    if dataset_cache is not None:
        # Avoid using the GPT cache for GPT-2 and vice versa.
        dataset_cache = dataset_cache + '_' + type(tokenizer).__name__
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        # NOTE(review): torch.load unpickles the file; only point dataset_cache
        # at files this code wrote itself.
        personachat = torch.load(dataset_cache)
    else:
        logger.info("Download PERSONACHAT dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            personachat = json.load(f)

        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            # Recursively replace each string with its token ids.
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return {n: tokenize(o) for n, o in obj.items()}
            # Anything else is iterated as a sequence; assumes the JSON holds
            # only strings, dicts and lists (numbers/null would raise here).
            return [tokenize(o) for o in obj]

        personachat = tokenize(personachat)
        if dataset_cache:
            torch.save(personachat, dataset_cache)

    logger.info("Filter personalities")
    # The top level maps split names to lists of dialogs.
    personalities = [dialog["personality"] for split in personachat.values() for dialog in split]

    logger.info("Gathered %d personalities", len(personalities))
    return personalities
class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes.

    ``d["key"]`` and ``d.key`` address the same storage, e.g.
    ``AttrDict(lr=0.1).lr == 0.1`` and ``d.x = 1`` sets ``d["x"]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Make the instance's attribute namespace the mapping itself, so
        # attribute access and item access share one set of keys.
        self.__dict__ = self