2222class  ShardConfigTest (parameterized .TestCase ):
2323
2424  @parameterized .named_parameters ( 
25-       ('imagenet train, 137 GiB' , 137  <<  30 , 1281167 , True , 1024 ), 
26-       ('imagenet evaluation, 6.3 GiB' , 6300  *  (1  <<  20 ), 50000 , True , 64 ), 
27-       ('very large, but few examples, 52 GiB' , 52  <<  30 , 512 , True , 512 ), 
28-       ('xxl, 10 TiB' , 10  <<  40 , 10 ** 9 , True , 11264 ), 
29-       ('xxl, 10 PiB, 100B examples' , 10  <<  50 , 10 ** 11 , True , 10487808 ), 
30-       ('xs, 100 MiB, 100K records' , 10  <<  20 , 100  *  10 ** 3 , True , 1 ), 
31-       ('m, 499 MiB, 200K examples' , 400  <<  20 , 200  *  10 ** 3 , True , 4 ), 
25+       dict ( 
26+           testcase_name = 'imagenet train, 137 GiB' , 
27+           total_size = 137  <<  30 , 
28+           num_examples = 1281167 , 
29+           uses_precise_sharding = True , 
30+           max_size = None , 
31+           expected_num_shards = 1024 , 
32+       ), 
33+       dict ( 
34+           testcase_name = 'imagenet evaluation, 6.3 GiB' , 
35+           total_size = 6300  *  (1  <<  20 ), 
36+           num_examples = 50000 , 
37+           uses_precise_sharding = True , 
38+           max_size = None , 
39+           expected_num_shards = 64 , 
40+       ), 
41+       dict ( 
42+           testcase_name = 'very large, but few examples, 52 GiB' , 
43+           total_size = 52  <<  30 , 
44+           num_examples = 512 , 
45+           uses_precise_sharding = True , 
46+           max_size = None , 
47+           expected_num_shards = 512 , 
48+       ), 
49+       dict ( 
50+           testcase_name = 'xxl, 10 TiB' , 
51+           total_size = 10  <<  40 , 
52+           num_examples = 10 ** 9 , 
53+           uses_precise_sharding = True , 
54+           max_size = None , 
55+           expected_num_shards = 11264 , 
56+       ), 
57+       dict ( 
58+           testcase_name = 'xxl, 10 PiB, 100B examples' , 
59+           total_size = 10  <<  50 , 
60+           num_examples = 10 ** 11 , 
61+           uses_precise_sharding = True , 
62+           max_size = None , 
63+           expected_num_shards = 10487808 , 
64+       ), 
65+       dict ( 
66+           testcase_name = 'xs, 100 MiB, 100K records' , 
67+           total_size = 10  <<  20 , 
68+           num_examples = 100  *  10 ** 3 , 
69+           uses_precise_sharding = True , 
70+           max_size = None , 
71+           expected_num_shards = 1 , 
72+       ), 
73+       dict ( 
74+           testcase_name = 'm, 499 MiB, 200K examples' , 
75+           total_size = 400  <<  20 , 
76+           num_examples = 200  *  10 ** 3 , 
77+           uses_precise_sharding = True , 
78+           max_size = None , 
79+           expected_num_shards = 4 , 
80+       ), 
81+       dict ( 
82+           testcase_name = '100GiB, even example sizes' , 
83+           num_examples = 1e9 ,  # 1B examples  
84+           total_size = 1e9  *  1000 ,  # On average 1000 bytes per example  
85+           max_size = 1000 ,  # Max example size is 4000 bytes  
86+           uses_precise_sharding = True , 
87+           expected_num_shards = 1024 , 
88+       ), 
89+       dict ( 
90+           testcase_name = '100GiB, uneven example sizes' , 
91+           num_examples = 1e9 ,  # 1B examples  
92+           total_size = 1e9  *  1000 ,  # On average 1000 bytes per example  
93+           max_size = 4  *  1000 ,  # Max example size is 4000 bytes  
94+           uses_precise_sharding = True , 
95+           expected_num_shards = 4096 , 
96+       ), 
97+       dict ( 
98+           testcase_name = '100GiB, very uneven example sizes' , 
99+           num_examples = 1e9 ,  # 1B examples  
100+           total_size = 1e9  *  1000 ,  # On average 1000 bytes per example  
101+           max_size = 16  *  1000 ,  # Max example size is 16x the average bytes  
102+           uses_precise_sharding = True , 
103+           expected_num_shards = 15360 , 
104+       ), 
32105  ) 
33106  def  test_get_number_shards_default_config (
34-       self , total_size , num_examples , uses_precise_sharding , expected_num_shards 
107+       self ,
108+       total_size : int ,
109+       num_examples : int ,
110+       uses_precise_sharding : bool ,
111+       max_size : int ,
112+       expected_num_shards : int ,
35113  ):
36114    shard_config  =  shard_utils .ShardConfig ()
37115    self .assertEqual (
38116        expected_num_shards ,
39117        shard_config .get_number_shards (
40118            total_size = total_size ,
41119            num_examples = num_examples ,
120+             max_example_size = max_size ,  # max(1, total_size // num_examples), 
42121            uses_precise_sharding = uses_precise_sharding ,
43122        ),
44123    )
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
48127    self .assertEqual (
49128        42 ,
50129        shard_config .get_number_shards (
51-             total_size = 100 , num_examples = 1 , uses_precise_sharding = True 
130+             total_size = 100 ,
131+             max_example_size = 100 ,
132+             num_examples = 1 ,
133+             uses_precise_sharding = True ,
52134        ),
53135    )
54136
0 commit comments