-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar100.py
38 lines (29 loc) · 1.19 KB
/
cifar100.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os.path
import torch
# Number of classes in CIFAR-100; tasks partition the label range [0, N_CLASSES).
N_CLASSES = 100


def flatten_and_scale(x):
    """Flatten every sample to one row and divide its values by 255.

    Args:
        x: tensor whose first dimension indexes samples (presumably uint8
            images with values in [0, 255] -- TODO confirm against the raw file).

    Returns:
        A float tensor of shape (x.size(0), -1) holding ``x / 255``.
    """
    # reshape (rather than view) also tolerates non-contiguous inputs.
    return x.float().reshape(x.size(0), -1) / 255.0


def split_tasks(x, y, n_tasks, n_classes=N_CLASSES):
    """Split a labelled dataset into ``n_tasks`` disjoint, contiguous class ranges.

    Task ``t`` holds every sample whose label lies in ``[t * cpt, (t + 1) * cpt)``,
    where ``cpt = n_classes // n_tasks``. Within a task, samples keep their
    input order.

    Args:
        x: tensor of inputs indexed by sample along dim 0.
        y: 1-D label tensor aligned with ``x`` along dim 0.
        n_tasks: number of tasks; must be a positive divisor of ``n_classes``.
        n_classes: size of the label range ``[0, n_classes)``.

    Returns:
        A list with one ``[(c1, c2), x_task, y_task]`` entry per task.

    Raises:
        ValueError: if ``n_tasks`` is not a positive divisor of ``n_classes``.
            Without this check, trailing classes would be silently dropped.
    """
    if n_tasks <= 0 or n_classes % n_tasks != 0:
        raise ValueError(
            'n_tasks={} must be a positive divisor of n_classes={}'.format(
                n_tasks, n_classes))
    cpt = n_classes // n_tasks
    tasks = []
    for t in range(n_tasks):
        c1, c2 = t * cpt, (t + 1) * cpt
        # Boolean-mask indexing returns fresh copies in the original order,
        # so no explicit nonzero()/clone() is needed.
        mask = (y >= c1) & (y < c2)
        tasks.append([(c1, c2), x[mask], y[mask]])
    return tasks


def main(argv=None):
    """Command-line entry point: split CIFAR-100 into tasks and save them.

    Loads ``(x_tr, y_tr, x_te, y_te)`` from ``--i`` and flattens and rescales
    the inputs. It then splits the train and test sets into ``--n_tasks``
    class-range tasks and saves ``[tasks_tr, tasks_te]`` to ``--o``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--i', default='raw/cifar100.pt', help='input file')
    parser.add_argument('--o', default='cifar100.pt', help='output file')
    parser.add_argument('--n_tasks', default=10, type=int, help='number of tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    args = parser.parse_args(argv)

    # Reject bad task counts before the (potentially slow) load.
    if args.n_tasks <= 0 or N_CLASSES % args.n_tasks != 0:
        parser.error('--n_tasks must be a positive divisor of {}'.format(N_CLASSES))

    # Nothing below draws random numbers. The seed is presumably kept so that
    # any randomized step added later stays reproducible.
    torch.manual_seed(args.seed)

    # torch.load unpickles its input; only point --i at trusted files.
    x_tr, y_tr, x_te, y_te = torch.load(args.i)
    x_tr = flatten_and_scale(x_tr)
    x_te = flatten_and_scale(x_te)

    tasks_tr = split_tasks(x_tr, y_tr, args.n_tasks)
    tasks_te = split_tasks(x_te, y_te, args.n_tasks)
    torch.save([tasks_tr, tasks_te], args.o)


if __name__ == '__main__':
    main()