Add the strategy argument to MegatronGPTModel.generate() (NVIDIA#7264)

It is passed as an explicit argument rather than through
`**strategy_args` so as to ensure someone cannot accidentally pass other
arguments that would end up being ignored.

It is a keyword-only argument to ensure that if in the future we want to
update the signature to `**strategy_args`, we can do it without breaking
code.
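The pattern described above can be sketched in isolation. The names below (`low_level_generate`, the string-valued `strategy`) are hypothetical stand-ins, not the actual NeMo code: the point is that `strategy` is keyword-only, and is only forwarded when explicitly provided, so the callee keeps its own default when the caller omits it.

```python
from typing import Optional


def low_level_generate(text: str, strategy: str = "greedy") -> str:
    # Stand-in for megatron_gpt_generate: owns the default strategy.
    return f"{text} ({strategy})"


def generate(text: str, *, strategy: Optional[str] = None) -> str:
    # The bare `*` makes `strategy` keyword-only: callers must write
    # strategy=..., which leaves room to later change the signature to
    # **strategy_args without breaking any positional call sites.
    strategy_args = {} if strategy is None else {"strategy": strategy}
    return low_level_generate(text, **strategy_args)


print(generate("hello"))                   # uses low_level_generate's default
print(generate("hello", strategy="beam"))  # explicit strategy is forwarded
```

Note that forwarding via a conditionally-built dict (rather than always passing `strategy=None`) means the wrapper never overrides the callee's default, and `generate("hello", "beam")` raises `TypeError` because the argument is keyword-only.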

Signed-off-by: Olivier Delalleau <[email protected]>
odelalleau authored Sep 27, 2023
1 parent e2140ac commit 9c4fbe1
Showing 1 changed file with 8 additions and 1 deletion.
@@ -41,6 +41,7 @@
     get_ltor_masks_and_position_ids,
     get_params_for_weight_decay_optimization,
 )
+from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy
 from nemo.collections.nlp.modules.common.text_generation_utils import (
     generate,
     get_computeprob_response,
@@ -1176,6 +1177,8 @@ def generate(
         inputs: Union[List[str], torch.Tensor, List[dict]],
         length_params: LengthParam,
         sampling_params: SamplingParam = None,
+        *,
+        strategy: Optional[TextGenerationStrategy] = None,
     ) -> OutputType:

         # check whether the DDP is initialized
@@ -1201,7 +1204,11 @@ def dummy():
         if length_params is None:
             length_params = get_default_length_params()

-        return megatron_gpt_generate(self.cuda(), inputs, self.tokenizer, length_params, sampling_params)
+        strategy_args = {} if strategy is None else {"strategy": strategy}
+
+        return megatron_gpt_generate(
+            self.cuda(), inputs, self.tokenizer, length_params, sampling_params, **strategy_args
+        )

     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
         inference_config = self.get_inference_config()
