188 changes: 154 additions & 34 deletions verl/models/llama/megatron/modeling_llama_megatron.py
@@ -35,6 +35,8 @@
from verl.utils.megatron import tensor_parallel as tp_utils
from verl.utils.megatron_utils import TransformerConfig, convert_config
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
from verl.utils.kernel import linear_cross_entropy
from verl.models.transformers.common import FusedCausalLMOutputWithPast
"""
TODO:
1. Add weight initialization. Here we need to be careful on TP weight init.
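Note: the new return type FusedCausalLMOutputWithPast is imported from verl.models.transformers.common, whose definition is outside this diff. A minimal sketch of what it presumably adds on top of CausalLMOutputWithPast, assuming the field names simply mirror the constructor calls below:

# Hypothetical sketch; the real class lives in verl/models/transformers/common.py (not shown here).
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

@dataclass
class FusedCausalLMOutputWithPast(CausalLMOutputWithPast):
    log_probs: Optional[torch.Tensor] = None  # per-token log-probability of `labels`
    entropy: Optional[torch.Tensor] = None    # per-token entropy of the predicted distribution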
@@ -180,7 +182,10 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
labels: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
fuse_entropy_logprobs: bool = False,
) -> Union[Tuple, FusedCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -199,17 +204,45 @@
)

hidden_states = outputs
logits = self.lm_head(hidden_states)[0]

logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

logits = logits.float()
return CausalLMOutputWithPast(

log_probs = None
entropy = None
logits = None

if self.training and fuse_entropy_logprobs:
# TOCHECK: whether an explicit `labels is not None` check is needed here
"""
Unfused reference path that the fused `linear_cross_entropy` call below replaces:
responses = data['responses']
response_length = responses.size(1)

# labels is responses
labels = responses
label_length = labels.size(1)
label_mask = attention_mask[:, -label_length:]
logits = self._forward_head(hidden_states)
logits = self.lm_head(hidden_states)[0]
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
logits = logits.float()
logits = logits[:, -label_length - 1:-1].contiguous()
logits = logits.div_(temperature)
log_prob = vocab_parallel_log_probs_from_logits(logits, labels)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask)
"""
log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weight, labels, reduction="none")
else:
logits = self.lm_head(hidden_states)[0]
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
logits = logits.float()

return FusedCausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
log_probs=log_probs,
entropy=entropy
)
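A hedged usage sketch of the fused path added above (not part of this PR); the variable `responses` and the final assertion are illustrative assumptions:

# Illustrative only: the fused branch is gated on self.training, so the model must be in train mode.
model.train()
output = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    labels=responses,              # hypothetical: the sampled response tokens to score
    temperature=1.0,
    fuse_entropy_logprobs=True,    # skip materializing the full (batch, seq, vocab) logits
)
log_probs = output.log_probs       # per-token log-probabilities of `labels`
entropy = output.entropy           # per-token entropies
assert output.logits is None       # logits are not computed on the fused path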


@@ -315,7 +348,10 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
labels: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
fuse_entropy_logprobs: bool = False,
) -> Union[Tuple, FusedCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -346,25 +382,62 @@
max_seqlen_in_batch=max_seqlen_in_batch)

hidden_states = outputs

log_probs = None
entropy = None
logits = None

if self.training and fuse_entropy_logprobs:
# TOCHECK: whether an explicit `labels is not None` check is needed here
"""
Unfused reference path that the fused `linear_cross_entropy` call below replaces:
responses = data['responses']
response_length = responses.size(1)

# labels is responses
labels = responses
label_length = labels.size(1)
label_mask = attention_mask[:, -label_length:]
logits = self._forward_head(hidden_states)
# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)

logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
# add removed padding back, maybe done later
# move outside
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
logits = logits[:, -label_length - 1:-1].contiguous()
logits = logits.div_(temperature)
log_prob = vocab_parallel_log_probs_from_logits(logits, labels)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask)
"""
log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weight, labels, reduction="none")
log_probs = pad_input(log_probs, indices, batch_size, seqlen=sequence_length)
entropy = pad_input(entropy, indices, batch_size, seqlen=sequence_length)
else:
logits = self._forward_head(hidden_states)

logits = self._forward_head(hidden_states)

# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)

logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
# add removed padding back
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
# add removed padding back
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)

return CausalLMOutputWithPast(
return FusedCausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
log_probs=log_probs,
entropy=entropy
)
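For intuition, a plain-PyTorch reference of what the fused `linear_cross_entropy` call is expected to compute over the packed (unpadded) tokens. This is a sketch under assumptions (2D packed hidden states, per-token labels, temperature folded in as in the commented-out reference), not the actual kernel, whose point is to avoid storing the full logits tensor:

import torch
import torch.nn.functional as F

def reference_logprobs_entropy(hidden_states, weight, labels, temperature=1.0):
    # hidden_states: (total_nnz, hidden), weight: (vocab, hidden), labels: (total_nnz,)
    logits = (hidden_states @ weight.t()).float() / temperature    # the large tensor the fused kernel never materializes
    logp = F.log_softmax(logits, dim=-1)
    log_probs = logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # log-prob of each label token
    entropy = -(logp.exp() * logp).sum(-1)                         # entropy of each token's distribution
    return log_probs, entropy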


@@ -391,8 +464,11 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
fuse_entropy_logprobs: bool = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
output = super().forward(input_ids, attention_mask, position_ids)
output = super().forward(input_ids, attention_mask, position_ids, labels, temperature, fuse_entropy_logprobs)
output.logits = torch.squeeze(output.logits, dim=-1)
return output
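A brief hedged note on the value-model wrapper above: it forwards the new keyword arguments unchanged so actor and critic keep one interface, and it is expected to run with the default fuse_entropy_logprobs=False, since its head produces a single scalar rather than a vocabulary; the squeeze then turns logits of shape (batch, seq, 1) into per-token values:

# Hedged sketch of a critic call; `critic_model` is an assumed instance of this wrapper.
values = critic_model(input_ids, attention_mask, position_ids).logits  # (batch_size, sequence_length)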

@@ -578,6 +654,9 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
fuse_entropy_logprobs: bool = False
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -612,24 +691,62 @@

if self.post_process:
hidden_states = outputs
# print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096])
logits = self._forward_head(hidden_states)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16])

# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back. If input is already rmpad, we let the caller pad_input
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)

return CausalLMOutputWithPast(

log_probs = None
entropy = None
logits = None

if self.training and fuse_entropy_logprobs:
# TOCHECK: whether an explicit `labels is not None` check is needed here
"""
Unfused reference path that the fused `linear_cross_entropy` call below replaces:
responses = data['responses']
response_length = responses.size(1)

# labels is responses
labels = responses
label_length = labels.size(1)
label_mask = attention_mask[:, -label_length:]
logits = self._forward_head(hidden_states)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension

# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back
# move outside
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
logits = logits[:, -label_length - 1:-1].contiguous()
logits = logits.div_(temperature)
log_prob = vocab_parallel_log_probs_from_logits(logits, labels)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask)
"""
log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weight, labels, reduction="none")
log_probs = pad_input(log_probs, indices, batch_size, seqlen=sequence_length)
entropy = pad_input(entropy, indices, batch_size, seqlen=sequence_length)
else:
logits = self._forward_head(hidden_states)

logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension

# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)

return FusedCausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
log_probs=log_probs,
entropy=entropy
)
else:
return outputs
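With pipeline parallelism, only the last stage (self.post_process) builds the fused output; earlier stages return raw activations. A hedged sketch of how a caller might branch on the stage, using Megatron's parallel-state helper as an assumption:

from megatron.core import parallel_state

output = model(input_ids, attention_mask, position_ids,
               labels=labels, temperature=1.0, fuse_entropy_logprobs=True)
if parallel_state.is_pipeline_last_stage():
    log_probs, entropy = output.log_probs, output.entropy   # FusedCausalLMOutputWithPast
else:
    hidden_states = output                                   # activations passed to the next stage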
@@ -659,8 +776,11 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
fuse_entropy_logprobs: bool = False
) -> Union[Tuple, CausalLMOutputWithPast]:
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, temperature=temperature, fuse_entropy_logprobs=fuse_entropy_logprobs)
if self.post_process:
output.logits = torch.squeeze(output.logits, dim=-1)
return output