diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 11a0c4083cb7..49278ec03691 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -27,7 +27,7 @@
 
 from ...activations import ACT2FN
 from ...generation.configuration_utils import GenerationConfig
-from ...generation.logits_process import LogitsProcessorList
+from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
 from ...generation.stopping_criteria import StoppingCriteriaList
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1354,7 +1354,12 @@ def generate(
             and generation_config.do_sample is True
         )
 
-        # 8. prepare distribution pre_processing samplers
+        # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
@@ -1363,7 +1368,7 @@
             logits_processor=logits_processor,
         )
 
-        # 9. prepare stopping criteria
+        # 10. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
@@ -1375,7 +1380,7 @@
                     f"but is {generation_config.num_return_sequences}."
                 )
 
-            # 8. run greedy search
+            # 11. run greedy search
             outputs = self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
@@ -1389,7 +1394,7 @@
             )
 
         elif is_sample_gen_mode:
-            # 9. prepare logits warper
+            # 11. prepare logits warper
             logits_warper = self._get_logits_warper(generation_config)
 
             # expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1399,7 +1404,7 @@
                 **model_kwargs,
             )
 
-            # 10. run sample
+            # 12. run sample
             outputs = self.sample(
                 input_ids,
                 logits_processor=logits_processor,
@@ -2378,7 +2383,12 @@ def generate(
             and generation_config.do_sample is True
         )
 
-        # 8. prepare distribution pre_processing samplers
+        # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
@@ -2387,7 +2397,7 @@
             logits_processor=logits_processor,
         )
 
-        # 9. prepare stopping criteria
+        # 10. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
@@ -2399,7 +2409,7 @@
                     f"but is {generation_config.num_return_sequences}."
                 )
 
-            # 10. run greedy search
+            # 11. run greedy search
             outputs = self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
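
Context for reviewers: the patch appends the batched `ClassifierFreeGuidanceLogitsProcessor` to the user-supplied `logits_processor` list, then clears `generation_config.guidance_scale` so that the later `_get_logits_processor` call does not attach a second (unbatched) CFG processor on top. Below is a minimal sketch of the score combination the batched processor performs, assuming the logits arrive as a doubled batch with the conditional rows stacked on top of the unconditional (null-prompt) rows; the helper name is illustrative, not part of the patch:

```python
import torch


def apply_batched_cfg(scores: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # `scores` has shape (2 * batch_size, vocab_size): the first half holds the
    # conditional logits, the second half the unconditional (null-prompt) logits.
    cond_logits, uncond_logits = scores.split(scores.shape[0] // 2, dim=0)
    # Push the conditional distribution away from the unconditional one.
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)
```

Nulling `guidance_scale` afterwards is what keeps the two CFG code paths from both firing on the same generation.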
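For reference, a small end-to-end sketch of the user-facing path this change affects; the checkpoint and prompt are illustrative, and `guidance_scale=3.0` is just an example value greater than 1 so the new branch is taken:

```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

inputs = processor(text=["80s pop track with bassy drums and synth"], return_tensors="pt")
# With guidance_scale > 1, generate() now appends the batched-CFG processor to any
# user-provided logits_processor list instead of relying on _get_logits_processor.
audio_values = model.generate(**inputs, guidance_scale=3.0, max_new_tokens=256)
```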