|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 |
| -import itertools |
17 |
| -import unittest |
18 |
| - |
19 |
| -import oneflow.utils.data as flowdata |
20 |
| - |
21 |
| -from libai.data.samplers import CyclicSampler, SingleRoundSampler |
22 |
| - |
23 |
| - |
24 |
| -class TestCyclicSampler(unittest.TestCase): |
25 |
| - def test_cyclic_sampler_iterable(self): |
26 |
| - sampler = CyclicSampler( |
27 |
| - list(range(100)), |
28 |
| - micro_batch_size=4, |
29 |
| - shuffle=True, |
30 |
| - consumed_samples=0, |
31 |
| - seed=123, |
32 |
| - ) |
33 |
| - output_iter = itertools.islice(sampler, 25) # iteration=100/4=25 |
34 |
| - sample_output = list() |
35 |
| - for batch in output_iter: |
36 |
| - sample_output.extend(batch) |
37 |
| - self.assertEqual(set(sample_output), set(range(100))) |
38 |
| - |
39 |
| - data_sampler = CyclicSampler( |
40 |
| - list(range(100)), |
41 |
| - micro_batch_size=4, |
42 |
| - shuffle=True, |
43 |
| - consumed_samples=0, |
44 |
| - seed=123, |
45 |
| - ) |
46 |
| - |
47 |
| - data_loader = flowdata.DataLoader( |
48 |
| - list(range(100)), batch_sampler=data_sampler, num_workers=0, collate_fn=lambda x: x |
49 |
| - ) |
50 |
| - |
51 |
| - data_loader_iter = itertools.islice(data_loader, 25) |
52 |
| - output = list() |
53 |
| - for data in data_loader_iter: |
54 |
| - output.extend(data) |
55 |
| - self.assertEqual(output, sample_output) |
56 |
| - |
57 |
| - def test_cyclic_sampler_seed(self): |
58 |
| - sampler = CyclicSampler( |
59 |
| - list(range(100)), |
60 |
| - micro_batch_size=4, |
61 |
| - shuffle=True, |
62 |
| - seed=123, |
63 |
| - ) |
64 |
| - |
65 |
| - data = list(itertools.islice(sampler, 65)) |
66 |
| - |
67 |
| - sampler = CyclicSampler( |
68 |
| - list(range(100)), |
69 |
| - micro_batch_size=4, |
70 |
| - shuffle=True, |
71 |
| - seed=123, |
72 |
| - ) |
73 |
| - |
74 |
| - data2 = list(itertools.islice(sampler, 65)) |
75 |
| - self.assertEqual(data, data2) |
76 |
| - |
77 |
| - def test_cyclic_sampler_resume(self): |
78 |
| - # Single rank |
79 |
| - sampler = CyclicSampler( |
80 |
| - list(range(10)), |
81 |
| - micro_batch_size=4, |
82 |
| - shuffle=True, |
83 |
| - seed=123, |
84 |
| - ) |
85 |
| - |
86 |
| - all_output = list(itertools.islice(sampler, 50)) # iteration 50 times |
87 |
| - |
88 |
| - sampler = CyclicSampler( |
89 |
| - list(range(10)), |
90 |
| - micro_batch_size=4, |
91 |
| - shuffle=True, |
92 |
| - seed=123, |
93 |
| - consumed_samples=4 * 11, # consumed 11 iters |
94 |
| - ) |
95 |
| - |
96 |
| - resume_output = list(itertools.islice(sampler, 39)) |
97 |
| - self.assertEqual(all_output[11:], resume_output) |
98 |
| - |
99 |
| - def test_cyclic_sampler_resume_multi_rank(self): |
100 |
| - # Multiple ranks |
101 |
| - sampler_rank0 = CyclicSampler( |
102 |
| - list(range(10)), |
103 |
| - micro_batch_size=4, |
104 |
| - shuffle=True, |
105 |
| - seed=123, |
106 |
| - data_parallel_rank=0, |
107 |
| - data_parallel_size=2, |
108 |
| - ) |
109 |
| - sampler_rank1 = CyclicSampler( |
110 |
| - list(range(10)), |
111 |
| - micro_batch_size=4, |
112 |
| - shuffle=True, |
113 |
| - seed=123, |
114 |
| - data_parallel_rank=1, |
115 |
| - data_parallel_size=2, |
116 |
| - ) |
117 |
| - |
118 |
| - all_output_rank0 = list(itertools.islice(sampler_rank0, 50)) # iteration 50 times |
119 |
| - all_output_rank1 = list(itertools.islice(sampler_rank1, 50)) # iteration 50 times |
120 |
| - |
121 |
| - sampler_rank0 = CyclicSampler( |
122 |
| - list(range(10)), |
123 |
| - micro_batch_size=4, |
124 |
| - shuffle=True, |
125 |
| - seed=123, |
126 |
| - data_parallel_rank=0, |
127 |
| - data_parallel_size=2, |
128 |
| - consumed_samples=4 * 11, # consumed 11 iters |
129 |
| - ) |
130 |
| - sampler_rank1 = CyclicSampler( |
131 |
| - list(range(10)), |
132 |
| - micro_batch_size=4, |
133 |
| - shuffle=True, |
134 |
| - seed=123, |
135 |
| - data_parallel_rank=1, |
136 |
| - data_parallel_size=2, |
137 |
| - consumed_samples=4 * 11, # consumed 11 iters |
138 |
| - ) |
139 |
| - |
140 |
| - resume_output_rank0 = list(itertools.islice(sampler_rank0, 39)) |
141 |
| - resume_output_rank1 = list(itertools.islice(sampler_rank1, 39)) |
142 |
| - |
143 |
| - self.assertEqual(all_output_rank0[11:], resume_output_rank0) |
144 |
| - self.assertEqual(all_output_rank1[11:], resume_output_rank1) |
145 |
| - |
146 |
| - |
147 |
| -class TestSingleRoundSampler(unittest.TestCase): |
148 |
| - def test_single_sampler_iterable(self): |
149 |
| - sampler = SingleRoundSampler( |
150 |
| - list(range(100)), |
151 |
| - micro_batch_size=4, |
152 |
| - shuffle=False, |
153 |
| - ) |
154 |
| - output_iter = itertools.islice(sampler, 30) # exceed iteration number |
155 |
| - sample_output = list() |
156 |
| - for batch in output_iter: |
157 |
| - sample_output.extend(batch) |
158 |
| - self.assertEqual(sample_output, list(range(100))) |
159 |
| - |
160 |
| - def test_single_sampler_multi_rank(self): |
161 |
| - sampler_rank0 = SingleRoundSampler( |
162 |
| - list(range(101)), |
163 |
| - micro_batch_size=4, |
164 |
| - shuffle=False, |
165 |
| - data_parallel_rank=0, |
166 |
| - data_parallel_size=2, |
167 |
| - ) |
168 |
| - sampler_rank1 = SingleRoundSampler( |
169 |
| - list(range(101)), |
170 |
| - micro_batch_size=4, |
171 |
| - shuffle=False, |
172 |
| - data_parallel_rank=1, |
173 |
| - data_parallel_size=2, |
174 |
| - ) |
175 |
| - |
176 |
| - output_iter_rank0 = itertools.islice(sampler_rank0, 30) |
177 |
| - sample_output_rank0 = list() |
178 |
| - for batch in output_iter_rank0: |
179 |
| - sample_output_rank0.extend(batch) |
180 |
| - |
181 |
| - output_iter_rank1 = itertools.islice(sampler_rank1, 30) |
182 |
| - sample_output_rank1 = list() |
183 |
| - for batch in output_iter_rank1: |
184 |
| - sample_output_rank1.extend(batch) |
185 |
| - |
186 |
| - # Padding 0 if it's not enough for a batch, otherwise `to_global` |
187 |
| - # will raise errors for imbalanced data shape in different ranks |
188 |
| - self.assertEqual(sample_output_rank0, list(range(51))) |
189 |
| - self.assertEqual(sample_output_rank1, list(range(51, 101)) + [0]) |
190 |
| - |
191 |
| - |
192 |
| -if __name__ == "__main__": |
193 |
| - unittest.main() |
| 16 | +# import itertools |
| 17 | +# import unittest |
| 18 | + |
| 19 | +# import oneflow.utils.data as flowdata |
| 20 | + |
| 21 | +# from libai.data.samplers import CyclicSampler, SingleRoundSampler |
| 22 | + |
| 23 | + |
| 24 | +# class TestCyclicSampler(unittest.TestCase): |
| 25 | +# def test_cyclic_sampler_iterable(self): |
| 26 | +# sampler = CyclicSampler( |
| 27 | +# list(range(100)), |
| 28 | +# micro_batch_size=4, |
| 29 | +# shuffle=True, |
| 30 | +# consumed_samples=0, |
| 31 | +# seed=123, |
| 32 | +# ) |
| 33 | +# output_iter = itertools.islice(sampler, 25) # iteration=100/4=25 |
| 34 | +# sample_output = list() |
| 35 | +# for batch in output_iter: |
| 36 | +# sample_output.extend(batch) |
| 37 | +# self.assertEqual(set(sample_output), set(range(100))) |
| 38 | + |
| 39 | +# data_sampler = CyclicSampler( |
| 40 | +# list(range(100)), |
| 41 | +# micro_batch_size=4, |
| 42 | +# shuffle=True, |
| 43 | +# consumed_samples=0, |
| 44 | +# seed=123, |
| 45 | +# ) |
| 46 | + |
| 47 | +# data_loader = flowdata.DataLoader( |
| 48 | +# list(range(100)), batch_sampler=data_sampler, num_workers=0, collate_fn=lambda x: x |
| 49 | +# ) |
| 50 | + |
| 51 | +# data_loader_iter = itertools.islice(data_loader, 25) |
| 52 | +# output = list() |
| 53 | +# for data in data_loader_iter: |
| 54 | +# output.extend(data) |
| 55 | +# self.assertEqual(output, sample_output) |
| 56 | + |
| 57 | +# def test_cyclic_sampler_seed(self): |
| 58 | +# sampler = CyclicSampler( |
| 59 | +# list(range(100)), |
| 60 | +# micro_batch_size=4, |
| 61 | +# shuffle=True, |
| 62 | +# seed=123, |
| 63 | +# ) |
| 64 | + |
| 65 | +# data = list(itertools.islice(sampler, 65)) |
| 66 | + |
| 67 | +# sampler = CyclicSampler( |
| 68 | +# list(range(100)), |
| 69 | +# micro_batch_size=4, |
| 70 | +# shuffle=True, |
| 71 | +# seed=123, |
| 72 | +# ) |
| 73 | + |
| 74 | +# data2 = list(itertools.islice(sampler, 65)) |
| 75 | +# self.assertEqual(data, data2) |
| 76 | + |
| 77 | +# def test_cyclic_sampler_resume(self): |
| 78 | +# # Single rank |
| 79 | +# sampler = CyclicSampler( |
| 80 | +# list(range(10)), |
| 81 | +# micro_batch_size=4, |
| 82 | +# shuffle=True, |
| 83 | +# seed=123, |
| 84 | +# ) |
| 85 | + |
| 86 | +# all_output = list(itertools.islice(sampler, 50)) # iteration 50 times |
| 87 | + |
| 88 | +# sampler = CyclicSampler( |
| 89 | +# list(range(10)), |
| 90 | +# micro_batch_size=4, |
| 91 | +# shuffle=True, |
| 92 | +# seed=123, |
| 93 | +# consumed_samples=4 * 11, # consumed 11 iters |
| 94 | +# ) |
| 95 | + |
| 96 | +# resume_output = list(itertools.islice(sampler, 39)) |
| 97 | +# self.assertEqual(all_output[11:], resume_output) |
| 98 | + |
| 99 | +# def test_cyclic_sampler_resume_multi_rank(self): |
| 100 | +# # Multiple ranks |
| 101 | +# sampler_rank0 = CyclicSampler( |
| 102 | +# list(range(10)), |
| 103 | +# micro_batch_size=4, |
| 104 | +# shuffle=True, |
| 105 | +# seed=123, |
| 106 | +# data_parallel_rank=0, |
| 107 | +# data_parallel_size=2, |
| 108 | +# ) |
| 109 | +# sampler_rank1 = CyclicSampler( |
| 110 | +# list(range(10)), |
| 111 | +# micro_batch_size=4, |
| 112 | +# shuffle=True, |
| 113 | +# seed=123, |
| 114 | +# data_parallel_rank=1, |
| 115 | +# data_parallel_size=2, |
| 116 | +# ) |
| 117 | + |
| 118 | +# all_output_rank0 = list(itertools.islice(sampler_rank0, 50)) # iteration 50 times |
| 119 | +# all_output_rank1 = list(itertools.islice(sampler_rank1, 50)) # iteration 50 times |
| 120 | + |
| 121 | +# sampler_rank0 = CyclicSampler( |
| 122 | +# list(range(10)), |
| 123 | +# micro_batch_size=4, |
| 124 | +# shuffle=True, |
| 125 | +# seed=123, |
| 126 | +# data_parallel_rank=0, |
| 127 | +# data_parallel_size=2, |
| 128 | +# consumed_samples=4 * 11, # consumed 11 iters |
| 129 | +# ) |
| 130 | +# sampler_rank1 = CyclicSampler( |
| 131 | +# list(range(10)), |
| 132 | +# micro_batch_size=4, |
| 133 | +# shuffle=True, |
| 134 | +# seed=123, |
| 135 | +# data_parallel_rank=1, |
| 136 | +# data_parallel_size=2, |
| 137 | +# consumed_samples=4 * 11, # consumed 11 iters |
| 138 | +# ) |
| 139 | + |
| 140 | +# resume_output_rank0 = list(itertools.islice(sampler_rank0, 39)) |
| 141 | +# resume_output_rank1 = list(itertools.islice(sampler_rank1, 39)) |
| 142 | + |
| 143 | +# self.assertEqual(all_output_rank0[11:], resume_output_rank0) |
| 144 | +# self.assertEqual(all_output_rank1[11:], resume_output_rank1) |
| 145 | + |
| 146 | + |
| 147 | +# class TestSingleRoundSampler(unittest.TestCase): |
| 148 | +# def test_single_sampler_iterable(self): |
| 149 | +# sampler = SingleRoundSampler( |
| 150 | +# list(range(100)), |
| 151 | +# micro_batch_size=4, |
| 152 | +# shuffle=False, |
| 153 | +# ) |
| 154 | +# output_iter = itertools.islice(sampler, 30) # exceed iteration number |
| 155 | +# sample_output = list() |
| 156 | +# for batch in output_iter: |
| 157 | +# sample_output.extend(batch) |
| 158 | +# self.assertEqual(sample_output, list(range(100))) |
| 159 | + |
| 160 | +# def test_single_sampler_multi_rank(self): |
| 161 | +# sampler_rank0 = SingleRoundSampler( |
| 162 | +# list(range(101)), |
| 163 | +# micro_batch_size=4, |
| 164 | +# shuffle=False, |
| 165 | +# data_parallel_rank=0, |
| 166 | +# data_parallel_size=2, |
| 167 | +# ) |
| 168 | +# sampler_rank1 = SingleRoundSampler( |
| 169 | +# list(range(101)), |
| 170 | +# micro_batch_size=4, |
| 171 | +# shuffle=False, |
| 172 | +# data_parallel_rank=1, |
| 173 | +# data_parallel_size=2, |
| 174 | +# ) |
| 175 | + |
| 176 | +# output_iter_rank0 = itertools.islice(sampler_rank0, 30) |
| 177 | +# sample_output_rank0 = list() |
| 178 | +# for batch in output_iter_rank0: |
| 179 | +# sample_output_rank0.extend(batch) |
| 180 | + |
| 181 | +# output_iter_rank1 = itertools.islice(sampler_rank1, 30) |
| 182 | +# sample_output_rank1 = list() |
| 183 | +# for batch in output_iter_rank1: |
| 184 | +# sample_output_rank1.extend(batch) |
| 185 | + |
| 186 | +# # Padding 0 if it's not enough for a batch, otherwise `to_global` |
| 187 | +# # will raise errors for imbalanced data shape in different ranks |
| 188 | +# self.assertEqual(sample_output_rank0, list(range(51))) |
| 189 | +# self.assertEqual(sample_output_rank1, list(range(51, 101)) + [0]) |
| 190 | + |
| 191 | + |
| 192 | +# if __name__ == "__main__": |
| 193 | +# unittest.main() |
0 commit comments