Skip to content

Commit 0037599

Browse files
Yuanjing ShiLucien0
authored andcommitted
remove exception handling of autotvm xgboost extract functions (apache#10948)
1 parent c4c1d54 commit 0037599

File tree

1 file changed

+49
-69
lines changed

1 file changed

+49
-69
lines changed

python/tvm/autotvm/tuner/xgboost_cost_model.py

Lines changed: 49 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -360,98 +360,78 @@ def _extract_popen_initializer(space, target, task):
360360

361361
def _extract_itervar_feature_index(args):
362362
"""extract iteration var feature for an index in extract_space"""
363-
try:
364-
config = _extract_space.get(args)
365-
with _extract_target:
366-
sch, fargs = _extract_task.instantiate(config)
363+
config = _extract_space.get(args)
364+
with _extract_target:
365+
sch, fargs = _extract_task.instantiate(config)
367366

368-
fea = feature.get_itervar_feature_flatten(sch, fargs, take_log=True)
369-
fea = np.concatenate((fea, list(config.get_other_option().values())))
370-
return fea
371-
except Exception: # pylint: disable=broad-except
372-
return None
367+
fea = feature.get_itervar_feature_flatten(sch, fargs, take_log=True)
368+
fea = np.concatenate((fea, list(config.get_other_option().values())))
369+
return fea
373370

374371

375372
def _extract_itervar_feature_log(arg):
376373
"""extract iteration var feature for log items"""
377-
try:
378-
inp, res = arg
379-
config = inp.config
380-
with inp.target:
381-
sch, args = inp.task.instantiate(config)
382-
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
383-
x = np.concatenate((fea, list(config.get_other_option().values())))
384-
385-
if res.error_no == 0:
386-
y = inp.task.flop / np.mean(res.costs)
387-
else:
388-
y = 0.0
389-
return x, y
390-
except Exception: # pylint: disable=broad-except
391-
return None
374+
inp, res = arg
375+
config = inp.config
376+
with inp.target:
377+
sch, args = inp.task.instantiate(config)
378+
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
379+
x = np.concatenate((fea, list(config.get_other_option().values())))
380+
381+
if res.error_no == 0:
382+
y = inp.task.flop / np.mean(res.costs)
383+
else:
384+
y = 0.0
385+
return x, y
392386

393387

394388
def _extract_knob_feature_index(args):
395389
"""extract knob feature for an index in extract_space"""
396-
try:
397-
398-
config = _extract_space.get(args)
390+
config = _extract_space.get(args)
399391

400-
return config.get_flatten_feature()
401-
except Exception: # pylint: disable=broad-except
402-
return None
392+
return config.get_flatten_feature()
403393

404394

405395
def _extract_knob_feature_log(arg):
406396
"""extract knob feature for log items"""
407-
try:
408-
inp, res = arg
409-
config = inp.config
410-
x = config.get_flatten_feature()
411-
412-
if res.error_no == 0:
413-
with inp.target: # necessary, for calculating flops of this task
414-
inp.task.instantiate(config)
415-
y = inp.task.flop / np.mean(res.costs)
416-
else:
417-
y = 0.0
418-
return x, y
419-
except Exception: # pylint: disable=broad-except
420-
return None
397+
inp, res = arg
398+
config = inp.config
399+
x = config.get_flatten_feature()
400+
401+
if res.error_no == 0:
402+
with inp.target: # necessary, for calculating flops of this task
403+
inp.task.instantiate(config)
404+
y = inp.task.flop / np.mean(res.costs)
405+
else:
406+
y = 0.0
407+
return x, y
421408

422409

423410
def _extract_curve_feature_index(args):
424411
"""extract sampled curve feature for an index in extract_space"""
425-
try:
412+
config = _extract_space.get(args)
413+
with _extract_target:
414+
sch, fargs = _extract_task.instantiate(config)
426415

427-
config = _extract_space.get(args)
428-
with _extract_target:
429-
sch, fargs = _extract_task.instantiate(config)
430-
431-
fea = feature.get_buffer_curve_sample_flatten(sch, fargs, sample_n=20)
432-
fea = np.concatenate((fea, list(config.get_other_option().values())))
433-
return np.array(fea)
434-
except Exception: # pylint: disable=broad-except
435-
return None
416+
fea = feature.get_buffer_curve_sample_flatten(sch, fargs, sample_n=20)
417+
fea = np.concatenate((fea, list(config.get_other_option().values())))
418+
return np.array(fea)
436419

437420

438421
def _extract_curve_feature_log(arg):
439422
"""extract sampled curve feature for log items"""
440-
try:
441-
inp, res = arg
442-
config = inp.config
443-
with inp.target:
444-
sch, args = inp.task.instantiate(config)
445-
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
446-
x = np.concatenate((fea, list(config.get_other_option().values())))
447-
448-
if res.error_no == 0:
449-
y = inp.task.flop / np.mean(res.costs)
450-
else:
451-
y = 0.0
452-
return x, y
453-
except Exception: # pylint: disable=broad-except
454-
return None
423+
inp, res = arg
424+
config = inp.config
425+
with inp.target:
426+
sch, args = inp.task.instantiate(config)
427+
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
428+
x = np.concatenate((fea, list(config.get_other_option().values())))
429+
430+
if res.error_no == 0:
431+
y = inp.task.flop / np.mean(res.costs)
432+
else:
433+
y = 0.0
434+
return x, y
455435

456436

457437
def custom_callback(

0 commit comments

Comments
 (0)