This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Numpy] Fix SQuAD + Fix GLUE downloading #1280

Merged: 3 commits merged on Jul 29, 2020
96 changes: 62 additions & 34 deletions scripts/datasets/general_nlp_benchmark/prepare_glue.py
@@ -68,7 +68,23 @@ def read_tsv_glue(tsv_file, num_skip=1, keep_column_names=False):
nrows = len(elements)
else:
assert nrows == len(elements)
return pd.DataFrame(out, columns=column_names)
df = pd.DataFrame(out, columns=column_names)
series_l = []
for col_name in df.columns:
idx = df[col_name].first_valid_index()
val = df[col_name][idx]
if isinstance(val, str):
try:
dat = pd.to_numeric(df[col_name])
series_l.append(dat)
continue
except ValueError:
pass
finally:
pass
series_l.append(df[col_name])
new_df = pd.DataFrame({name: series for name, series in zip(df.columns, series_l)})
return new_df
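
The added block tries to recover numeric dtypes for columns that pandas read in as strings: pd.to_numeric raises ValueError as soon as one entry fails to parse, so genuinely textual columns fall through unchanged. A minimal standalone sketch of the same idea, with illustrative column names:

import pandas as pd

frame = pd.DataFrame({'score': ['1.5', '2.0'], 'genre': ['main-news', 'main-forums']})
converted = {}
for col in frame.columns:
    try:
        # Keep the numeric dtype when every entry in the column parses cleanly.
        converted[col] = pd.to_numeric(frame[col])
    except ValueError:
        # Leave textual columns as they are.
        converted[col] = frame[col]
print(pd.DataFrame(converted).dtypes)  # score -> float64, genre -> object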


def read_jsonl_superglue(jsonl_file):
@@ -157,6 +173,13 @@ def read_sts(dir_path):
else:
df = df[[7, 8, 1, 9]]
df.columns = ['sentence1', 'sentence2', 'genre', 'score']
genre_l = []
for ele in df['genre'].tolist():
if ele == 'main-forum':
genre_l.append('main-forums')
else:
genre_l.append(ele)
df['genre'] = pd.Series(genre_l)
df_dict[fold] = df
return df_dict, None

@@ -320,8 +343,8 @@ def read_rte_superglue(dir_path):
def read_wic(dir_path):
df_dict = dict()
meta_data = dict()
meta_data['entities1'] = {'type': 'entity', 'parent': 'sentence1'}
meta_data['entities2'] = {'type': 'entity', 'parent': 'sentence2'}
meta_data['entities1'] = {'type': 'entity', 'attrs': {'parent': 'sentence1'}}
meta_data['entities2'] = {'type': 'entity', 'attrs': {'parent': 'sentence2'}}

for fold in ['train', 'val', 'test']:
if fold != 'test':
@@ -340,13 +363,13 @@
end2 = row['end2']
if fold == 'test':
out.append([sentence1, sentence2,
(start1, end1),
(start2, end2)])
{'start': start1, 'end': end1},
{'start': start2, 'end': end2}])
else:
label = row['label']
out.append([sentence1, sentence2,
(start1, end1),
(start2, end2),
{'start': start1, 'end': end1},
{'start': start2, 'end': end2},
label])
df = pd.DataFrame(out, columns=columns)
df_dict[fold] = df
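
Spans are now stored as {'start': ..., 'end': ...} dicts instead of plain tuples. One plausible benefit, assuming the pyarrow Parquet engine used later in this script, is that dict-valued cells round-trip as a named struct column rather than an anonymous sequence; a small sketch under that assumption (file name illustrative):

import pandas as pd

sample = pd.DataFrame({'sentence1': ['room with a view'],
                       'entities1': [{'start': 0, 'end': 4}]})
sample.to_parquet('wic_sample.parquet')                       # dict cells become a struct column
print(pd.read_parquet('wic_sample.parquet')['entities1'][0])  # {'start': 0, 'end': 4}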
@@ -357,8 +380,8 @@ def read_wsc(dir_path):
df_dict = dict()
tokenizer = WhitespaceTokenizer()
meta_data = dict()
meta_data['noun'] = {'type': 'entity', 'parent': 'text'}
meta_data['pronoun'] = {'type': 'entity', 'parent': 'text'}
meta_data['noun'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
meta_data['pronoun'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
for fold in ['train', 'val', 'test']:
jsonl_path = os.path.join(dir_path, '{}.jsonl'.format(fold))
df = read_jsonl_superglue(jsonl_path)
@@ -374,20 +397,20 @@
span2_text = target['span2_text']
# Build entity
# list of entities
# 'entity': {'start': 0, 'end': 100}
# 'entities': {'start': 0, 'end': 100}
tokens, offsets = tokenizer.encode_with_offsets(text, str)
pos_start1 = offsets[span1_index][0]
pos_end1 = pos_start1 + len(span1_text)
pos_start2 = offsets[span2_index][0]
pos_end2 = pos_start2 + len(span2_text)
if fold == 'test':
samples.append({'text': text,
'noun': (pos_start1, pos_end1),
'pronoun': (pos_start2, pos_end2)})
'noun': {'start': pos_start1, 'end': pos_end1},
'pronoun': {'start': pos_start2, 'end': pos_end2}})
else:
samples.append({'text': text,
'noun': (pos_start1, pos_end1),
'pronoun': (pos_start2, pos_end2),
'noun': {'start': pos_start1, 'end': pos_end1},
'pronoun': {'start': pos_start2, 'end': pos_end2},
'label': label})
df = pd.DataFrame(samples)
df_dict[fold] = df
@@ -406,8 +429,8 @@ def read_boolq(dir_path):
def read_record(dir_path):
df_dict = dict()
meta_data = dict()
meta_data['entities'] = {'type': 'entity', 'parent': 'text'}
meta_data['answers'] = {'type': 'entity', 'parent': 'text'}
meta_data['entities'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
meta_data['answers'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
for fold in ['train', 'val', 'test']:
if fold != 'test':
columns = ['source', 'text', 'entities', 'query', 'answers']
@@ -422,15 +445,11 @@ def read_record(dir_path):
passage = row['passage']
text = passage['text']
entities = passage['entities']
entities = [(ele['start'], ele['end']) for ele in entities]
entities = [{'start': ele['start'], 'end': ele['end']} for ele in entities]
for qas in row['qas']:
query = qas['query']
if fold != 'test':
answer_entities = []
for answer in qas['answers']:
start = answer['start']
end = answer['end']
answer_entities.append((start, end))
answer_entities = qas['answers']
out.append((source, text, entities, query, answer_entities))
else:
out.append((source, text, entities, query))
@@ -518,11 +537,15 @@ def format_mrpc(data_dir):
os.makedirs(mrpc_dir, exist_ok=True)
mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
download(GLUE_TASK2PATH["mrpc"]['train'], mrpc_train_file)
download(GLUE_TASK2PATH["mrpc"]['test'], mrpc_test_file)
download(GLUE_TASK2PATH["mrpc"]['train'], mrpc_train_file,
sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['train']])
download(GLUE_TASK2PATH["mrpc"]['test'], mrpc_test_file,
sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['test']])
assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
download(GLUE_TASK2PATH["mrpc"]['dev'], os.path.join(mrpc_dir, "dev_ids.tsv"))
download(GLUE_TASK2PATH["mrpc"]['dev'],
os.path.join(mrpc_dir, "dev_ids.tsv"),
sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['dev']])

dev_ids = []
with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
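
The download calls above now pass the expected SHA-1 digest from _URL_FILE_STATS, so a truncated or corrupted MRPC file fails fast instead of silently producing a bad split. A rough sketch of what such a checksum-verified fetch does; download_checked is a hypothetical stand-in, not the project's actual download utility:

import hashlib
import os
from urllib.request import urlretrieve

def download_checked(url, path, sha1_hash=None):
    # Hypothetical helper: fetch the file if it is not cached, then verify its SHA-1 digest.
    if not os.path.exists(path):
        urlretrieve(url, path)
    if sha1_hash is not None:
        with open(path, 'rb') as f:
            digest = hashlib.sha1(f.read()).hexdigest()
        if digest != sha1_hash:
            raise ValueError('Checksum mismatch for {}'.format(path))
    return path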
@@ -575,7 +598,7 @@ def get_tasks(benchmark, task_names):
@DATA_PARSER_REGISTRY.register('prepare_glue')
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark", choices=['glue', 'superglue', 'sts'],
parser.add_argument("--benchmark", choices=['glue', 'superglue'],
default='glue', type=str)
parser.add_argument("-d", "--data_dir", help="directory to save data to", type=str,
default=None)
@@ -618,39 +641,44 @@ def main(args):
base_dir = os.path.join(args.data_dir, 'rte_diagnostic')
os.makedirs(base_dir, exist_ok=True)
download(TASK2PATH['diagnostic'][0],
path=os.path.join(base_dir, 'diagnostic.tsv'))
path=os.path.join(base_dir, 'diagnostic.tsv'),
sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][0]])
download(TASK2PATH['diagnostic'][1],
path=os.path.join(base_dir, 'diagnostic-full.tsv'))
path=os.path.join(base_dir, 'diagnostic-full.tsv'),
sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][1]])
df = reader(base_dir)
df.to_pickle(os.path.join(base_dir, 'diagnostic-full.pd.pkl'))
df.to_parquet(os.path.join(base_dir, 'diagnostic-full.parquet'))
else:
for key, name in [('broadcoverage-diagnostic', 'AX-b'),
('winogender-diagnostic', 'AX-g')]:
data_file = os.path.join(args.cache_path, "{}.zip".format(key))
url = TASK2PATH[key]
reader = TASK2READER[key]
download(url, data_file)
download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
with zipfile.ZipFile(data_file) as zipdata:
zipdata.extractall(args.data_dir)
df = reader(os.path.join(args.data_dir, name))
df.to_pickle(os.path.join(args.data_dir, name, '{}.pd.pkl'.format(name)))
df.to_parquet(os.path.join(args.data_dir, name, '{}.parquet'.format(name)))
elif task == 'mrpc':
reader = TASK2READER[task]
format_mrpc(args.data_dir)
df_dict, meta_data = reader(os.path.join(args.data_dir, 'mrpc'))
for key, df in df_dict.items():
if key == 'val':
key = 'dev'
df.to_pickle(os.path.join(args.data_dir, 'mrpc', '{}.pd.pkl'.format(key)))
df.to_parquet(os.path.join(args.data_dir, 'mrpc', '{}.parquet'.format(key)))
with open(os.path.join(args.data_dir, 'mrpc', 'metadata.json'), 'w') as f:
json.dump(meta_data, f)
else:
# Download data
data_file = os.path.join(args.cache_path, "{}.zip".format(task))
url = TASK2PATH[task]
reader = TASK2READER[task]
download(url, data_file)
download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
base_dir = os.path.join(args.data_dir, task)
if os.path.exists(base_dir):
print('Found!')
continue
zip_dir_name = None
with zipfile.ZipFile(data_file) as zipdata:
if zip_dir_name is None:
@@ -662,7 +690,7 @@ def main(args):
for key, df in df_dict.items():
if key == 'val':
key = 'dev'
df.to_pickle(os.path.join(base_dir, '{}.pd.pkl'.format(key)))
df.to_parquet(os.path.join(base_dir, '{}.parquet'.format(key)))
if meta_data is not None:
with open(os.path.join(base_dir, 'metadata.json'), 'w') as f:
json.dump(meta_data, f)
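
Saving the splits with to_parquet instead of to_pickle decouples the files from the exact Python and pandas versions that wrote them; reading them back only needs pyarrow or fastparquet installed. A quick round-trip sketch with an illustrative frame and file name:

import pandas as pd

split = pd.DataFrame({'sentence1': ['a b c', 'd e'], 'label': [1, 0]})
split.to_parquet('train.parquet')            # engine: pyarrow or fastparquet
restored = pd.read_parquet('train.parquet')
assert restored.equals(split)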
4 changes: 2 additions & 2 deletions scripts/question_answering/run_squad.py
@@ -563,8 +563,8 @@ def train(args):
segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
p_mask = sample.masks.as_in_ctx(ctx)
gt_start = sample.gt_start.as_in_ctx(ctx)
gt_end = sample.gt_end.as_in_ctx(ctx)
gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(np.int32)
batch_idx = mx.np.arange(tokens.shape[0], dtype=np.int32, ctx=ctx)
p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask
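
Casting gt_start and gt_end to int32 matters because, together with batch_idx, the ground-truth positions are presumably used to index into the predicted start/end scores, and integer-array indexing rejects float labels. A plain NumPy illustration of the failure mode (shapes and values are made up):

import numpy as np

scores = np.arange(12).reshape(3, 4)   # per-position start scores for a batch of 3
gt_start = np.array([0., 2., 3.])      # ground-truth positions arriving as floats
# scores[np.arange(3), gt_start]       # IndexError: arrays used as indices must be of integer type
print(scores[np.arange(3), gt_start.astype(np.int32)])  # [ 0  6 11]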