Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/transformers/models/musicgen/modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import LogitsProcessorList
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_outputs import (
BaseModelOutput,
Expand Down Expand Up @@ -1354,7 +1354,12 @@ def generate(
and generation_config.do_sample is True
)

# 8. prepare distribution pre_processing samplers
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None

# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
Expand All @@ -1363,7 +1368,7 @@ def generate(
logits_processor=logits_processor,
)

# 9. prepare stopping criteria
# 10. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
Expand All @@ -1375,7 +1380,7 @@ def generate(
f"but is {generation_config.num_return_sequences}."
)

# 8. run greedy search
# 11. run greedy search
outputs = self.greedy_search(
input_ids,
logits_processor=logits_processor,
Expand All @@ -1389,7 +1394,7 @@ def generate(
)

elif is_sample_gen_mode:
# 9. prepare logits warper
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)

# expand input_ids with `num_return_sequences` additional sequences per batch
Expand All @@ -1399,7 +1404,7 @@ def generate(
**model_kwargs,
)

# 10. run sample
# 12. run sample
outputs = self.sample(
input_ids,
logits_processor=logits_processor,
Expand Down Expand Up @@ -2378,7 +2383,12 @@ def generate(
and generation_config.do_sample is True
)

# 8. prepare distribution pre_processing samplers
# 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
generation_config.guidance_scale = None

# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
Expand All @@ -2387,7 +2397,7 @@ def generate(
logits_processor=logits_processor,
)

# 9. prepare stopping criteria
# 10. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
Expand All @@ -2399,7 +2409,7 @@ def generate(
f"but is {generation_config.num_return_sequences}."
)

# 10. run greedy search
# 11. run greedy search
outputs = self.greedy_search(
input_ids,
logits_processor=logits_processor,
Expand Down