 
 class LogitsProcessor(nn.Module):
 
-    def __init__(self, model_config: ModelConfig):
+    def __init__(self):
         super().__init__()
-        self.model_config = model_config
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -30,49 +29,6 @@ def forward(self,
         else:
             hidden_states = hidden_states[-1]
 
-        # token_count = hidden_states.view(-1, hidden_states.shape[-1]).shape[0]
-
-        # # Add pre-lm gather logic
-        # if (self.model_config.mapping.enable_attention_dp and getattr(
-        #         self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
-        #     # ADP + LM TP mode: perform All-Gather before LM_head
-        #     from ..distributed import allgather
-        #     all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
-        #     pad_len = all_rank_max_num_tokens - token_count
-        #     if pad_len > 0:
-        #         padded_hidden_states = F.pad(hidden_states.view(
-        #             -1, hidden_states.shape[-1]), (0, 0, 0, pad_len),
-        #                                      mode="constant",
-        #                                      value=0)
-        #     else:
-        #         padded_hidden_states = hidden_states.view(
-        #             -1, hidden_states.shape[-1])
-        #     hidden_states = allgather(padded_hidden_states,
-        #                               self.model_config.mapping,
-        #                               dim=0)
-
-        # # Temporarily disable gather_output when not in ADP mode or (in ADP mode and LM TP is enabled)
-        # if (not self.model_config.mapping.enable_attention_dp) or (
-        #         self.model_config.mapping.enable_attention_dp and getattr(
-        #             self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
-        #     lm_head.gather_output = False
         logits = lm_head(hidden_states)
-        # if (not self.model_config.mapping.enable_attention_dp) or (
-        #         self.model_config.mapping.enable_attention_dp and getattr(
-        #             self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
-        #     lm_head.gather_output = True
-
-        # if (self.model_config.mapping.enable_attention_dp and getattr(
-        #         self.model_config.mapping, 'enable_lm_tp_in_adp', False)):
-        #     # print(f"In LogitsProcessor, lm_head.weight.data_ptr: {lm_head.weight.data_ptr()}")
-        #     # print(f"In LogitsProcessor, lm_head.weight.shape: {lm_head.weight.shape}")
-        #     # print(f"In LogitsProcessor, logits.shape: {logits.shape}")
-        #     logits = allgather(logits, self.model_config.mapping, dim=-1)
-        #     batch_size = logits.shape[0]
-        #     local_batch_size = batch_size // self.model_config.mapping.tp_size
-        #     logits = logits.view(self.model_config.mapping.tp_size,
-        #                          local_batch_size, -1)
-        #     logits = logits[self.model_config.mapping.tp_rank][:token_count]
-        #     print(f"In LogitsProcessor, final logits.shape: {logits.shape}")
         logits = logits.float()
         return logits
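For context, the commented-out block removed above sketched an ADP + LM-TP path: pad each rank's tokens up to all_rank_max_num_tokens, all-gather the hidden states along the token dimension before lm_head, then slice this rank's rows back out of the gathered logits and drop the padding. Below is a minimal standalone sketch of that pattern, using plain torch.distributed in place of the repo's internal allgather(tensor, mapping, dim) helper; the function and parameter names here are illustrative, not part of the codebase.

# Illustrative sketch only; assumes an initialized torch.distributed
# process group and mirrors the pattern described in the deleted comments.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def pad_and_gather_tokens(hidden_states: torch.Tensor,
                          all_rank_max_num_tokens: int) -> torch.Tensor:
    """Pad this rank's tokens to the max token count across ranks,
    then all-gather along the token dimension (dim 0)."""
    flat = hidden_states.view(-1, hidden_states.shape[-1])
    pad_len = all_rank_max_num_tokens - flat.shape[0]
    if pad_len > 0:
        # Zero-pad the trailing rows so every rank contributes the same shape.
        flat = F.pad(flat, (0, 0, 0, pad_len), mode="constant", value=0)
    gathered = [torch.empty_like(flat) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, flat)
    return torch.cat(gathered, dim=0)


def slice_local_logits(logits: torch.Tensor, token_count: int,
                       tp_size: int, tp_rank: int) -> torch.Tensor:
    """Given logits whose leading dim stacks tp_size padded blocks,
    recover this rank's block and trim the padding rows."""
    local_rows = logits.shape[0] // tp_size
    return logits.view(tp_size, local_rows, -1)[tp_rank][:token_count]

In that flow, lm_head would run on the output of pad_and_gather_tokens (leading dim tp_size * all_rank_max_num_tokens), and slice_local_logits would then recover the token_count rows belonging to the local rank.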