diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index 124153f14..f4e8fe0b8 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).
 
     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
 
@@ -474,9 +473,18 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
     ):
         """
@@ -496,7 +504,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index a713bec65..5b633243e 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -70,3 +70,10 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        return ({}, {})
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 2efd7931e..0fe1decff 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -429,6 +429,12 @@ def forward_backward_step(
                 extra_inputs=extra_inputs,
             )
 
+        # Get the order sensitive buffers
+        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+        extra_args.update(order_sensitive_buffers[0])
+
        # apply context parallelism if cp is enabled
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        optional_context_parallel_ctx = (
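
For reference, a minimal sketch of how the new get_order_sensitive_buffers() contract is meant to be consumed: the first returned dict holds per-batch buffers that the caller forwards into forward() as extra keyword arguments, and the second maps each buffer name to its sequence dimension so sequence-sharding (e.g. context parallelism) knows which dim to split. The ToyModel class and its dim/max_seqlen values below are hypothetical illustrations; only the method name, signature, and return convention come from this patch.

# Minimal sketch (not the torchtitan training loop): illustrates the
# get_order_sensitive_buffers() return convention introduced in this diff.
# ToyModel and its sizes are made up for the example.
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, dim: int = 8, max_seqlen: int = 32):
        super().__init__()
        # Precomputed per-position buffer, shape (max_seqlen, dim).
        self.register_buffer("freqs_cis", torch.randn(max_seqlen, dim))
        self.proj = nn.Linear(dim, dim)

    def get_order_sensitive_buffers(
        self,
        batch_size: int,
        seq_len: int,
    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
        # First dict: per-batch buffers to pass into forward().
        # Second dict: the sequence dimension of each buffer (dim 1 here).
        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
        return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})

    def forward(self, tokens: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # freqs_cis now arrives as an explicit argument instead of being read
        # from self.freqs_cis inside forward().
        return self.proj(tokens) + freqs_cis


model = ToyModel()
tokens = torch.randn(2, 16, 8)  # (batch_size, seq_len, dim)
buffers, seq_dims = model.get_order_sensitive_buffers(tokens.size(0), tokens.size(1))
out = model(tokens, **buffers)  # caller forwards the buffers as extra kwargs
print(out.shape, seq_dims)      # torch.Size([2, 16, 8]) {'freqs_cis': 1}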