check for first or last stage #6708

Merged (6 commits) on May 26, 2023
Changes from 2 commits
69 changes: 39 additions & 30 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -135,36 +135,45 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para


def get_computeprob_response(tokenizer, response, inputs):
-    compute_prob_response = {}
-    new_token_ids = []
-    new_tokens = []
-    new_texts = []
-    log_probs = []
-    full_logprobs = []
-    offsets = []
-    for batch_id in range(len(response['tokens'])):
-        if isinstance(inputs, (list, tuple)):
-            if isinstance(inputs[0], str):
-                new_token_id = tokenizer.text_to_ids(inputs[batch_id])
-                new_text = inputs[batch_id]
-                token_len = len(new_token_id)
-            elif isinstance(inputs[0], torch.Tensor):
-                token_len = int(inputs[1][batch_id].item())
-                new_token_id = inputs[0][batch_id][:token_len].tolist()
-                new_text = tokenizer.ids_to_text(new_token_id)
-        new_token_ids.append(new_token_id)
-        new_tokens.append(response['tokens'][batch_id][:token_len])
-        new_texts.append(new_text)
-        log_probs.append(response['logprob'][batch_id][:token_len])
-        full_logprobs.append(response['full_logprob'][batch_id][:token_len])
-        offsets.append(response['offsets'][batch_id][:-1])
-    compute_prob_response['sentences'] = new_texts
-    compute_prob_response['tokens'] = new_tokens
-    compute_prob_response['token_ids'] = new_token_ids
-    compute_prob_response['logprob'] = log_probs
-    compute_prob_response['full_logprob'] = full_logprobs
-    compute_prob_response['offsets'] = offsets
-    return compute_prob_response
+    if (
+        not parallel_state.model_parallel_is_initialized()
Collaborator:
Why do we need the condition not parallel_state.model_parallel_is_initialized()? If model parallel is not initialized, it will throw an exception at the very beginning, right?

Collaborator Author:
I was just trying to make it robust. This method will only ever be called after parallel_state is initialized, right?

Collaborator (@yidong72, May 23, 2023):
This parallel_state.model_parallel_is_initialized() method will always return True after parallel init:

def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
        _DATA_PARALLEL_GROUP is None:
        return False
    return True

Collaborator Author:

I mean, will this method ever be used without model parallel being initialized?

Collaborator:

No, it is only used with model parallel.
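
Put differently: since model_parallel_is_initialized() is always True by the time this function runs, the merged guard behaves like the stage check alone. A minimal sketch of that reduced check; the helper name below is hypothetical and not part of this PR:

    from megatron.core import parallel_state  # assumed import; taken to be the same parallel_state used in this file

    def response_is_produced_on_this_rank():
        # Hypothetical helper: with parallel state initialized, the first clause
        # of the merged condition is always False, so only the stage checks matter.
        return parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage()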

+        or parallel_state.is_pipeline_first_stage()
+        or parallel_state.is_pipeline_last_stage()
+    ):
+        # we only have a response on the first and last pipeline stages
+        compute_prob_response = {}
+        new_token_ids = []
+        new_tokens = []
+        new_texts = []
+        log_probs = []
+        full_logprobs = []
+        offsets = []
+        for batch_id in range(len(response['tokens'])):
+            if isinstance(inputs, (list, tuple)):
+                if isinstance(inputs[0], str):
+                    new_token_id = tokenizer.text_to_ids(inputs[batch_id])
+                    new_text = inputs[batch_id]
+                    token_len = len(new_token_id)
+                elif isinstance(inputs[0], torch.Tensor):
+                    token_len = int(inputs[1][batch_id].item())
+                    new_token_id = inputs[0][batch_id][:token_len].tolist()
+                    new_text = tokenizer.ids_to_text(new_token_id)
+            new_token_ids.append(new_token_id)
+            new_tokens.append(response['tokens'][batch_id][:token_len])
+            new_texts.append(new_text)
+            log_probs.append(response['logprob'][batch_id][:token_len])
+            full_logprobs.append(response['full_logprob'][batch_id][:token_len])
+            offsets.append(response['offsets'][batch_id][:-1])
+        compute_prob_response['sentences'] = new_texts
+        compute_prob_response['tokens'] = new_tokens
+        compute_prob_response['token_ids'] = new_token_ids
+        compute_prob_response['logprob'] = log_probs
+        compute_prob_response['full_logprob'] = full_logprobs
+        compute_prob_response['offsets'] = offsets
+        return compute_prob_response
+    else:
+        # intermediate stages
+        return None


def get_batch(model, tokenizer, context_tokens):
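
For completeness, a short sketch of how a caller might consume the new return value of get_computeprob_response; the call site is illustrative only and not taken from this PR, while the dictionary keys come from the function above:

    # Illustrative call site: only the first and last pipeline stages get a populated dict.
    compute_prob_response = get_computeprob_response(tokenizer, response, inputs)
    if compute_prob_response is None:
        # intermediate pipeline stages: nothing to report on this rank
        pass
    else:
        sentences = compute_prob_response['sentences']   # input texts, detokenized if tensors were passed
        log_probs = compute_prob_response['logprob']     # per-token log-probabilities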