diff --git a/generate.py b/generate.py index c2cef3b2..e34a68b3 100644 --- a/generate.py +++ b/generate.py @@ -70,6 +70,10 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) for i in range(steps): + # Skip non-informative steps + if num_transfer_tokens[:, i].sum() == 0: + warning('Detected unnecessary unmask steps w/o masked inputs, please lower the total step cnt.') + continue mask_index = (x == mask_id) if cfg_scale > 0.: un_x = x.clone()