diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index c03ee2d1..2a200688 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -145,32 +145,6 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None return _tensor_to_object(output_tensor) -def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup = None, src: int = 0) -> torch.Tensor: - """ - Broadcasts an optional tensor of size, shape, and dtype unknown in advance. - Returns the tensor on all ranks or None if no tensor was sent. - """ - assert group is not None - - if group.rank() == src: - has_tensor = tensor is not None - if has_tensor: - meta = (has_tensor, tensor.shape, tensor.dtype) - else: - meta = (has_tensor, None, None) - broadcast_object(meta, group, src) - if has_tensor: - broadcast(tensor.to(torch.cuda.current_device()), src, group) - return tensor - else: - has_tensor, shape, dtype = broadcast_object(None, group, src) - if not has_tensor: - return None - output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device()) - broadcast(output_tensor, src, group) - return output_tensor - - def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: assert group is not None work = group.send([tensor], dst, tag) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0d5c0178..5606eeb9 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -251,7 +251,7 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: ) # Get the image patches and associated data. image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( - self._config.image_patches.get_patches(images, self._data_type) + self._config.image_patches.get_patches_from_images(images, self._data_type) ) patch_count_cumsum = padded_cumsum(patch_counts).tolist() # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index 61e5dd7b..d6f5bf19 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -30,16 +30,17 @@ class ImagePatchConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + do_resize: bool = Field(default=True, desc="Whether to resize the image.") max_image_height: int = Field( default=1024, desc="Maximum height of the complete image, in pixels." - "If the original image is larger than this, it will be resized to this height.", + "If the original image is larger than this and resizing is enabled, it will be resized to this height.", hint=FieldHint.optional, ) max_image_width: int = Field( default=1024, desc="Maximum width of the complete image, in pixels." 
- "If the original image is larger than this, it will be resized to this width.", + "If the original image is larger than this and resizing is enabled, it will be resized to this width.", hint=FieldHint.optional, ) image_break_token: int | None = Field( @@ -72,14 +73,14 @@ def _validate(self): Assert.gt(self.max_patches_height, 0) Assert.gt(self.max_patches_width, 0) - def get_patches( - self, images: list[bytes], token_data_type: DataType = DataType.int64 + def get_patches_from_images( + self, images: list["torch.Tensor|bytes"], token_data_type: DataType = DataType.int64 ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["torch.Tensor"], list[int]]: import torch if len(images) > 0: image_patches, image_positions, image_token_maps, image_token_ids = zip( - *(self._get_patches(image, token_data_type) for image in images) + *(self._get_patches_from_image(image, token_data_type) for image in images) ) return ( torch.cat(image_patches), @@ -98,28 +99,37 @@ def get_patches( [0], ) - def _get_patches( - self, image_bytes: bytes, token_data_type: DataType = DataType.int64 + def _get_patches_from_image( + self, image: "torch.Tensor|bytes", token_data_type: DataType = DataType.int64 ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: - import numpy as np - import PIL.Image import torch - with PIL.Image.open(io.BytesIO(image_bytes)) as image: - if image.mode != "RGB": - # Convert all images to RGB - image = image.convert("RGB") - image = torch.tensor(np.array(image)).permute(2, 0, 1) # HWC to CHW - Assert.eq(image.dtype, torch.uint8) + if not torch.is_tensor(image): + import numpy as np + import PIL.Image + + with PIL.Image.open(io.BytesIO(image)) as image: + if image.mode != "RGB": + # Convert all images to RGB + image = image.convert("RGB") + image = torch.tensor(np.array(image)).permute(2, 0, 1) # HWC to CHW + Assert.eq(image.dtype, torch.uint8) + + if self.do_resize: + # Resize to a multiple of patch size smaller or equal to max size. + image = self._resize(image) + else: + # Crop the image to ensure its shape is a multiple of the patch size. + image = image[ + :, : image.size(1) - image.size(1) % self.height, : image.size(2) - image.size(2) % self.width + ] - # Resize to a multiple of patch size smaller or equal to max size. - image = self._resize(image) num_patches_height = div(image.size(1), self.height) num_patches_width = div(image.size(2), self.width) # Convert to patches. (`torch.nn.functional.unfold` not supported for uint8.) 
patches = ( - image.view(self.num_channels, num_patches_height, self.height, num_patches_width, self.width) - .permute(3, 1, 0, 2, 4) + image.reshape(self.num_channels, num_patches_height, self.height, num_patches_width, self.width) + .permute(1, 3, 0, 2, 4) .flatten(0, 1) ) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index b634f8a4..3ffed453 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -6,8 +6,9 @@ import torch import transformers.generation.utils import transformers.modeling_outputs +import transformers.utils.generic -from fast_llm.core.distributed import broadcast_object, broadcast_optional, safe_barrier +from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.config import HuggingfaceModelConfig @@ -20,7 +21,7 @@ logger = logging.getLogger(__name__) -class HuggingfacePreTrainedModel(transformers.PreTrainedModel): +class HuggingfacePreTrainedModel(transformers.PreTrainedModel, transformers.generation.utils.GenerationMixin): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner config: HuggingfaceModelConfig @@ -112,40 +113,14 @@ def from_pretrained( def _init_weights(self, module) -> None: raise NotImplementedError(module) - -class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin): - def inner_forward( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - past_key_values=None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: - # Meant to be overridden in derived classes - raise NotImplementedError() - def forward( self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - past_key_values=None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, + *args, coordinator_forward: bool = False, communication_timeout_sec: float = 600.0, continue_work: bool = True, - ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast | None: + **kwargs, + ) -> tuple | transformers.utils.generic.ModelOutput | None: """ Forward pass compatible with HuggingFace forward. @@ -170,44 +145,37 @@ def forward( if coordinator_forward and distributed.world_group and distributed.tensor_group: assert distributed.tensor_group.rank() == 0 - assert past_key_values is None and not use_cache # Some tasks may post-process too slowly, so waiting for the next batch or # the end of work can exceed the standard 60s timeout. 
safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec) - broadcast_optional(input_ids, distributed.tensor_group, 0) - broadcast_optional(attention_mask, distributed.tensor_group, 0) - broadcast_optional(position_ids, distributed.tensor_group, 0) - broadcast_optional(inputs_embeds, distributed.tensor_group, 0) - broadcast_optional(labels, distributed.tensor_group, 0) - + # Broadcast all input arguments, handling tensor and non-tensor arguments separately + # TODO: Support nested tensor in arguments (ex. past_key_values) + # TODO: Bypassed if passed as positional argument. + assert kwargs.get("past_key_values") is None and not kwargs.get("use_cache") + broadcast_kwargs = {**kwargs, **{i: arg for i, arg in enumerate(args)}, "continue_work": continue_work} + tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if torch.is_tensor(value)} + broadcast_object( + [(key, tensor.shape, tensor.dtype) for key, tensor in tensor_kwargs.items()], + distributed.tensor_group, + 0, + ) + for tensor in tensor_kwargs.values(): + broadcast(tensor.to(distributed.device), 0, distributed.tensor_group) + non_tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if key not in tensor_kwargs} broadcast_object( - (past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work), + non_tensor_kwargs, distributed.tensor_group, 0, ) if not coordinator_forward or continue_work: - return self.inner_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - labels, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - ) + return self.inner_forward(*args, **kwargs) return None - def worker_forward( - self, - communication_timeout_sec: float = 600.0, - ): + def worker_forward(self, communication_timeout_sec: float = 600.0): """ Run the forward loop on worker ranks in coordinated mode. @@ -239,30 +207,28 @@ def worker_forward( # the end of work can exceed the standard 60s timeout.
safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec) - input_ids = broadcast_optional(None, distributed.tensor_group, 0) - attention_mask = broadcast_optional(None, distributed.tensor_group, 0) - position_ids = broadcast_optional(None, distributed.tensor_group, 0) - inputs_embeds = broadcast_optional(None, distributed.tensor_group, 0) - labels = broadcast_optional(None, distributed.tensor_group, 0) - - past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work = ( - broadcast_object(None, distributed.tensor_group, 0) + broadcast_kwargs = {} + for key, shape, dtype in broadcast_object(None, distributed.tensor_group, 0): + tensor = torch.empty(shape, dtype=dtype, device=distributed.device) + broadcast(tensor, 0, distributed.tensor_group) + broadcast_kwargs[key] = tensor + + broadcast_kwargs.update( + broadcast_object( + None, + distributed.tensor_group, + 0, + ) ) - if not continue_work: + if not broadcast_kwargs.pop("continue_work"): break + arg_kwargs = {key: value for key, value in broadcast_kwargs.items() if isinstance(key, int)} + self.inner_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - labels, - use_cache, - output_attentions, - output_hidden_states, - return_dict, + *(arg_kwargs[i] for i in range(len(arg_kwargs))), + **{key: value for key, value in broadcast_kwargs.items() if key not in arg_kwargs}, ) safe_barrier(distributed.world_group, "forward_work_end") @@ -274,3 +240,7 @@ def stop_workers(self): return self.forward(coordinator_forward=True, continue_work=False) safe_barrier(distributed.world_group, "forward_work_end") + + def inner_forward(*args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput: + # Meant to be overridden in derived classes + raise NotImplementedError() diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 27c0e2b7..41736aed 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -30,9 +30,9 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel + from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM logger = logging.getLogger(__name__) @@ -242,7 +242,7 @@ def get_inference_runner_class(cls) -> type["InferenceRunner"]: raise NotImplementedError @classmethod - def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: raise NotImplementedError @classmethod diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 9f554359..25942b38 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -132,23 +132,6 @@ def forward( if output is not None: self._log_layer_forward(output, kwargs, i) - # TODO: very slow and memory consuming, only use for debugging for now - # TODO: decide if and how we want to return - # HF transformer style details from forward properly - if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: - # Last layer does not provide output - if output is not None: - meta = self._meta_outputs[i] - if output.shape == meta.shape: - output_global, _ = meta.local_to_global(output.detach()) 
- else: - # TODO: Handle variable shape. - output_global = output - - kwargs["hidden_states"][self._layers[i].module_name] = { - "layer_type": type(layer).__name__, - "tensor": output_global, - } return None if output is None else output.detach(), (input_, output) def backward( diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/autograd.py index 656d65eb..1428ed25 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/autograd.py @@ -27,10 +27,13 @@ def wrap_forward_backward[ and returns the `input_` gradient. """ + # We want to run the backward pass even if the input doesn't require grads. + in_ = torch.empty(0, requires_grad=True) + class Function(torch.autograd.Function): @staticmethod def forward(ctx, *args): - outputs = forward(*args) + outputs = forward(*args[:-1]) Assert.custom(isinstance, outputs, tuple) # No need to call `save_for_backward`, we don't want the safety checks anyway. ctx.context = outputs[-1] @@ -44,13 +47,13 @@ def forward(ctx, *args): def backward(ctx, *grad_outputs): grad_input = backward(*grad_outputs, ctx.context) if not isinstance(grad_input, tuple): - assert isinstance(grad_input, torch.Tensor) + assert isinstance(grad_input, torch.Tensor) or grad_input is None grad_input = (grad_input,) return *grad_input, *[None for _ in range(ctx.nargs - len(grad_input))] def call(*args, **kwargs): # TODO: Any way to validate kwargs without making function wrappers? - return Function.apply(*args, *kwargs.values()) + return Function.apply(*args, *kwargs.values(), in_) return call diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index dbc05184..38658ffc 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -70,8 +70,12 @@ def update_linear_gradients( def linear_forward( - input_: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, transposed_weight: bool = False -) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: + input_: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + transposed_weight: bool = False, + input_requires_grad: bool = True, +) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]]: # Matmul if TritonConfig.TRITON_LINEAR: assert bias is None @@ -81,17 +85,19 @@ def linear_forward( ).unflatten(0, input_.shape[:-1]) else: output = torch.nn.functional.linear(input_, maybe_transpose(weight, transposed_weight), bias) - return output, (input_, weight, bias, transposed_weight) + return output, (input_, weight, bias, transposed_weight, input_requires_grad) def linear_backward( - grad_output: torch.Tensor, context: tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool] -) -> torch.Tensor: - input_, weight, bias, transposed_weight = context + grad_output: torch.Tensor, context: tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool, bool] +) -> torch.Tensor | None: + input_, weight, bias, transposed_weight, input_requires_grad = context weight_t = maybe_transpose(weight, transposed_weight) # Input grad - if TritonConfig.TRITON_LINEAR: + if not input_requires_grad: + grad_input = None + elif TritonConfig.TRITON_LINEAR: grad_input = dense_matmul(grad_output.flatten(0, -2), weight_t).view_as(input_) else: grad_input = grad_output.matmul(weight_t) @@ -108,6 +114,7 @@ def output_parallel_linear_forward( group: ProcessGroup | None, sequence_parallel: bool, transposed_weight: bool = False, + input_requires_grad: bool = True, sparse_map: SparseMap | None = None, ) -> tuple[torch.Tensor, tuple[typing.Any, 
...]]: # Gather sequence-parallel slices (non-overlapped) @@ -133,12 +140,13 @@ def output_parallel_linear_forward( group, sequence_parallel, transposed_weight, + input_requires_grad, sparse_map, ) -def output_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: - input_, weight, bias, group, sequence_parallel, transposed_weight, sparse_map = context +def output_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor | None: + input_, weight, bias, group, sequence_parallel, transposed_weight, input_requires_grad, sparse_map = context weight_t = maybe_transpose(weight, transposed_weight) # Gather sequence-parallel slices (overlapped) @@ -148,15 +156,20 @@ def output_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[ty gather_handle = None # Input grad - if TritonConfig.TRITON_LINEAR or sparse_map is not None: - grad_input = input_inner_sparse_matmul(grad_output.flatten(0, -2), weight_t, sparse_map).view_as(input_) + if input_requires_grad: + if TritonConfig.TRITON_LINEAR or sparse_map is not None: + grad_input = input_inner_sparse_matmul(grad_output.flatten(0, -2), weight_t, sparse_map).view_as(input_) + else: + grad_input = grad_output.matmul(weight_t) + + # Reduce input grad (overlapped) + grad_input, reduce_handle = (reduce_scatter_op if sequence_parallel else reduce_op)( + grad_input, group=group, async_op=True + ) else: - grad_input = grad_output.matmul(weight_t) + grad_input = None + reduce_handle = None - # Reduce input grad (overlapped) - grad_input, reduce_handle = (reduce_scatter_op if sequence_parallel else reduce_op)( - grad_input, group=group, async_op=True - ) if sequence_parallel: gather_handle.wait() @@ -175,6 +188,7 @@ def input_parallel_linear_forward( group: ProcessGroup | None, sequence_parallel: bool, transposed_weight: bool = False, + input_requires_grad: bool = True, sparse_map: SparseMap | None = None, ) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: # Matmul @@ -197,12 +211,13 @@ def input_parallel_linear_forward( group, sequence_parallel, transposed_weight, + input_requires_grad, sparse_map, ) -def input_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: - input_, weight, bias, group, sequence_parallel, transposed_weight, sparse_map = context +def input_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor | None: + input_, weight, bias, group, sequence_parallel, transposed_weight, input_requires_grad, sparse_map = context weight_t = maybe_transpose(weight, transposed_weight) # Gather sequence-parallel slices (non-overlapped) @@ -210,7 +225,9 @@ def input_parallel_linear_backward(grad_output: torch.Tensor, context: tuple[typ grad_output = gather_op(grad_output, group, dim=0) # Input grad - if TritonConfig.TRITON_LINEAR or sparse_map is not None: + if not input_requires_grad: + grad_input = None + elif TritonConfig.TRITON_LINEAR or sparse_map is not None: grad_input = output_sparse_matmul(grad_output.flatten(0, -2), weight_t, sparse_map).view_as(input_) else: grad_input = grad_output.matmul(weight_t) @@ -237,6 +254,7 @@ def input_parallel_linear_autograd( group: ProcessGroup | None, sequence_parallel: bool, transposed_weight: bool = False, + input_requires_grad: bool = True, sparse_map: SparseMap | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # Autograd goes nuts it this goes in the function. 
@@ -248,6 +266,7 @@ def input_parallel_linear_autograd( group, sequence_parallel, transposed_weight, + input_requires_grad, sparse_map, ), bias if group else None, diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ab408368..1d2d0b3d 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -226,7 +226,7 @@ def mlp_forward( # Layer 1 intermediate_1, _ = output_parallel_linear_forward( - intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map + intermediate_0, weight_1, bias_1, group, sequence_parallel, False, True, sparse_map ) if recompute_level.recompute_sparse_input: @@ -254,6 +254,7 @@ def mlp_forward( group, sequence_parallel, transposed_layer_2_weight, + True, sparse_map, ) @@ -340,7 +341,7 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ # Layer 1 recomputation if intermediate_1 is None: intermediate_1 = output_parallel_linear_forward( - intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map + intermediate_0, weight_1, bias_1, group, sequence_parallel, False, True, sparse_map )[0] # Activation recomputation and/or backward @@ -374,7 +375,7 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ # Layer 1 backward grad_input = output_parallel_linear_backward( grad_intermediate_1, - (intermediate_0, weight_1, bias_1, group, sequence_parallel, False, sparse_map), + (intermediate_0, weight_1, bias_1, group, sequence_parallel, False, True, sparse_map), ) # Sparse copy diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 94382b25..059469d9 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -365,9 +365,8 @@ def _forward( key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size) - if self._debug.enabled: - self._debug(query, "query_rotary_input", self._query_dims, kwargs) - self._debug(key, "key_rotary_input", self._kv_dims, kwargs) + self._debug(query, "query_rotary_input", self._query_dims, kwargs) + self._debug(key, "key_rotary_input", self._kv_dims, kwargs) query, key = self._rotary(query, key, kwargs) with set_generator(self._distributed.tp_generator): @@ -379,16 +378,17 @@ def _forward( else: raise NotImplementedError(self._implementation) - if self._debug.enabled: - self._debug(query, "query", self._query_dims, kwargs) - self._debug(key, "key", self._kv_dims, kwargs) - self._debug(value, "value", self._kv_dims, kwargs) - self._debug(input_, "context", self._context_dims, kwargs) + self._debug(query, "query", self._query_dims, kwargs) + self._debug(key, "key", self._kv_dims, kwargs) + self._debug(value, "value", self._kv_dims, kwargs) + self._debug(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
input_ = input_.transpose(0, 1).contiguous() - return self.dense(input_) + out, bias = self.dense(input_) + self._debug(out, None, kwargs.get(AttentionKwargs.hidden_dims), kwargs) + return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: batch_dim: TensorDim = kwargs[AttentionKwargs.hidden_dims][1 if kwargs[AttentionKwargs.sequence_first] else 0] diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index d65c924e..6f589eeb 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -86,7 +86,7 @@ class AttentionConfig(MixerConfig): causal: bool = Field( default=True, desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", - hint=FieldHint.feature, + hint=FieldHint.architecture, ) dropout: float = Field( default=0.0, diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 67ce5eea..3a0f7cc5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -27,9 +27,62 @@ class DebugLayer: def __init__(self, module: torch.nn.Module): self._module = module + def __call__( + self, + tensor: torch.Tensor | None, + suffix: str | None, + dims: tuple[TensorDim | str, ...] | None, + kwargs: dict[str, typing.Any], + bias: torch.Tensor | None = None, + **logging_kwargs, + ): + name = self._name if suffix is None else f"{self._name}.{suffix}" + output_hidden_state = ( + BlockKwargs.output_hidden_states in kwargs + and any(pattern.match(name) for pattern in kwargs[BlockKwargs.output_hidden_states]) + and tensor is not None + ) + if (level := get_model_debug_level()) == 0 and not output_hidden_state: + return + if bias is not None: + assert tensor is not None + tensor = tensor + bias + meta = self._get_meta(tensor, name, dims, kwargs) + + if output_hidden_state: + kwargs[BlockKwargs.hidden_states][name] = (meta, tensor) + + if level > 1: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(name, str)) + + if level > 0 and tensor is not None: + log_distributed_tensor( + "", + tensor, + level=level, + meta=meta, + **logging_kwargs, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=level, + meta=self._get_meta(tensor, f"{name}.grad", dims, kwargs), + **logging_kwargs, + ) + def _get_meta( - self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: + self, + tensor: torch.Tensor | None, + name: str, + dims: tuple[TensorDim | str, ...]
| None, + kwargs: dict[str, typing.Any], + ) -> TensorMeta | None: + if tensor is None: + return None + if dims is None: + dims = tuple(f"dim_{i}" for i in range(tensor.ndim)) hidden_dims = { dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) } @@ -42,7 +95,7 @@ def _get_meta( ) for i, dim in enumerate(dims) ), - tensor_name=f"{self._name} {name}", + tensor_name=name, dtype=tensor.dtype, ) @@ -51,47 +104,6 @@ def _name(self): # Should be called after `module_name` is set in `BaseModel` return getattr(self._module, "module_name", "unknown") - @property - def enabled(self) -> bool: - return get_model_debug_level() > 0 - - def __call__[ - T - ]( - self, - tensor: torch.Tensor | None, - name: str, - dims: tuple[TensorDim | str, ...], - kwargs: dict[str, typing.Any], - scale: float = 1.0, - global_: bool = True, - log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, - ) -> None: - if (level := get_model_debug_level()) == 0: - return - if level > 1: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if tensor is not None: - log_distributed_tensor( - "", - tensor, - level=level, - meta=self._get_meta(tensor, name, dims, kwargs), - global_=global_, - log_fn=log_fn, - scale=scale, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=level, - meta=self._get_meta(tensor, name + " grad", dims, kwargs), - global_=global_, - log_fn=log_fn, - scale=scale, - ) - class BlockBase[ConfigType: ModuleConfig](Configurable[ConfigType], LayerBase): """ diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index c6e43ee7..d9a27c45 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -39,6 +39,8 @@ class BlockKwargs: grad_output = "grad_output" iteration = "iteration" device = "device" + hidden_states = "hidden_states" + output_hidden_states = "output_hidden_states" @config_class(registry=True) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e2c586bb..e7c6d9e9 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,12 +1,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.initialization import ( - Initialization, - init_normal_, - init_uniform_centered_, - init_zeros_, -) +from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.config import ActivationType @@ -14,7 +9,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear.convolution import CausalConv1d, Convolution2D + from fast_llm.layers.common.linear.convolution import CausalConv1d from fast_llm.layers.common.linear.linear import LinearBase @@ -222,44 +217,3 @@ def get_layer( return CausalConv1d( weight, bias, activation=default_activation if self.activation is None else self.activation ) - - -@config_class() -class Convolution2DConfig(AffineLinearBaseConfig): - def get_layer( - self, - in_dim: TensorDim, - out_dim: TensorDim, - kernel_dim_1: TensorDim, - kernel_dim_2: TensorDim, - *, - stride: tuple[int, int], - default_weight_initialization: Initialization | None = None, - 
default_bias_initialization: Initialization | None = None, - default_add_bias: bool = True, - lr_scale: float | None, - peft: PeftConfig | None, - ) -> "Convolution2D": - from fast_llm.layers.common.linear.convolution import Convolution2D - - if default_weight_initialization is None: - default_weight_initialization = init_normal_() - if default_bias_initialization is None: - default_bias_initialization = init_normal_() - - lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),) - weight = self.weight.get_parameter( - (out_dim, in_dim, kernel_dim_1, kernel_dim_2), - default_initialization=default_weight_initialization, - lr_scale=lr_scale, - peft=peft, - ) - bias = self.bias.get_parameter( - (out_dim,), - default_initialization=default_bias_initialization, - lr_scale=lr_scale, - default_enabled=default_add_bias, - peft=peft, - ) - - return Convolution2D(weight, bias, stride=stride) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 6281348e..b88b7b2e 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -55,27 +55,3 @@ def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: raise NotImplementedError() - - -class Convolution2D(torch.nn.Module): - """ - TODO: Generalize to other convolutions? - """ - - def __init__( - self, - weight: ParameterMeta, - bias: ParameterMeta | None, - *, - stride: tuple[int, int], - ): - super().__init__() - self.weight = weight - self.bias = bias - self._stride = stride - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._stride) - - def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: - raise NotImplementedError() diff --git a/fast_llm/layers/common/linear/linear.py b/fast_llm/layers/common/linear/linear.py index 3028fd1e..d0ea7a68 100644 --- a/fast_llm/layers/common/linear/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -81,8 +81,14 @@ class Linear(LinearBase): def forward_only( self, input_: torch.Tensor - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: - return linear_forward(input_, weight=self.weight, bias=self.bias, transposed_weight=self._transposed_weight) + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]]: + return linear_forward( + input_, + weight=self.weight, + bias=self.bias, + transposed_weight=self._transposed_weight, + input_requires_grad=input_.requires_grad, + ) def backward(self, grad_output: torch.Tensor, context) -> torch.Tensor: # noqa return linear_backward(grad_output, context) @@ -114,6 +120,7 @@ def forward_only(self, input_) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, + input_requires_grad=input_.requires_grad, ) def backward(self, grad_output: torch.Tensor, context: tuple[typing.Any, ...]): # noqa @@ -147,6 +154,7 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | No group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, + input_requires_grad=input_.requires_grad, ) def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None, tuple[typing.Any, ...]]: @@ -158,6 +166,7 @@ def 
forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor group=group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, + input_requires_grad=input_.requires_grad, ) return output, self.bias if group else None, context diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 5713cbb6..a915b16d 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -129,39 +129,21 @@ def forward( dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator - if self._debug.enabled: - self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(None, "begin", kwargs.get(BlockKwargs.hidden_dims), kwargs) fw_input = input_ hidden_states = self.norm_1(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "mixer output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs) hidden_states = self.norm_2(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, "norm_2", kwargs.get(BlockKwargs.hidden_dims), kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "MLP output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 4171e66a..5cc351da 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -76,8 +76,7 @@ def __init__( dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - if self._debug.enabled: - self._top_expert_dim = TensorDim("top_experts", self._config.experts_per_token) + self._top_expert_dim = TensorDim("top_experts", self._config.experts_per_token) def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() @@ -94,10 +93,12 @@ def _forward( return TensorMeta.from_dims(input_.dims[:-1] + (self._output_dim,), "MLP output"), None hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug.enabled: - self._debug( - logits, "Router logits", 
kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs - ) + logit_dims = ( + kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,) + if BlockKwargs.hidden_dims in kwargs + else None + ) + self._debug(logits, "Router logits", logit_dims, kwargs) # Apply z_loss if applicable if self._config.z_loss_coefficient > 0.0: @@ -125,19 +126,12 @@ def _forward( else: raise NotImplementedError(self._config.routing) - if self._debug.enabled: - # To log all ranks set `global_=False` - self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs - ) - self._debug( - top_experts, - "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), - kwargs, - ) + self._debug(scores, "router_scores", logit_dims, kwargs) + self._debug(top_experts, "router_top_experts", logit_dims, kwargs) - return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa + out = self._mlp_forward(hidden_states, scores, top_experts).view_as(input_) # noqa + self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + return out, None def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 7a52539d..b4da15b4 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -10,6 +10,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.decoder.mlp.config import MLPConfig @@ -120,21 +121,21 @@ def _forward( ), None, ) - return ( - mlp_autograd( - input_, - None, - self.layer_1.weight, - self.layer_1.bias, - self.layer_2.weight, - None if self._parallel_dim.group else self.layer_2.bias, - gated=self._config.gated, - activation_type=self._config.activation, - group=self._parallel_dim.group, - sequence_parallel=self._sequence_parallel, - training=self.training, - recompute_level=self._config.recompute_level, - transposed_layer_2_weight=self.layer_2.transposed_weight, - ), - self.layer_2.bias if self._parallel_dim.group else None, + out = mlp_autograd( + input_, + None, + self.layer_1.weight, + self.layer_1.bias, + self.layer_2.weight, + None if self._parallel_dim.group else self.layer_2.bias, + gated=self._config.gated, + activation_type=self._config.activation, + group=self._parallel_dim.group, + sequence_parallel=self._sequence_parallel, + training=self.training, + recompute_level=self._config.recompute_level, + transposed_layer_2_weight=self.layer_2.transposed_weight, ) + bias = self.layer_2.bias if self._parallel_dim.group else None + self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs, bias=bias) + return out, bias diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 3b2ed26d..32633f21 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -14,6 +14,7 @@ StochasticMixerKwargs, StochasticMixerSamplingStrategy, ) +from fast_llm.logging import get_model_debug_level from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -116,7 +117,7 
@@ def _forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: mixer_name = self._sample_mixer_name(kwargs) - if self._debug.enabled: + if get_model_debug_level() > 0: logger.debug(f"StochasticMixer selecting mixer {mixer_name}: {type(self.mixers[mixer_name]).__name__}") return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 321400ac..fc8794b5 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -167,7 +167,7 @@ def forward( # Drop the placeholder batch dimension, remove patch padding. input_ = input_.squeeze(int(kwargs[LanguageModelKwargs.sequence_first])) - return self._forward( + out = self._forward( input_, token_ids, kwargs.get(LanguageModelKwargs.position_ids), @@ -175,6 +175,8 @@ def forward( kwargs.get(LanguageModelKwargs.mask_inputs), embedding_map, ) + self._debug(out, None, kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) + return out def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (embeddings) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4b0e3d10..180785af 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -166,24 +166,9 @@ def _forward_backward( input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) - - if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: - # The last hidden layer output is returned normalized in the HF Transformers-style output, at least for LLama style models. - # So, if needed, we gather the data after normalization and set it as the output of the previous layer. - dims = list(kwargs[LanguageModelKwargs.hidden_dims]) - sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - BlockDimNames.sequence_q_tp, - dims[sequence_index].global_size, - self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), - ) - if self._sequence_parallel_logits - else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) - ) - meta = TensorMeta.from_dims(tuple(dims), tensor_name="hidden_state", dtype=ln_output.dtype) - hidden_state, _ = meta.local_to_global(ln_output.detach()) - kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state + # Transformers expects normalized outputs for the last transformer layer, + # so we add the norm output to the hidden states.
+ self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) grad_output = kwargs[LanguageModelKwargs.grad_output] / ( self._parallel_dim.size if self._sequence_parallel_logits else 1 @@ -344,15 +329,18 @@ def _logits_cross_entropy_forward_backward( self._z_loss_name, logits_scale_factor=self._config.logits_scale_factor, ) - if self._debug.enabled and self._config.cross_entropy_splits is None: - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q + + sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q + if LanguageModelKwargs.hidden_dims in kwargs: batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] dims = ( (sequence_dim, batch_dim, self._vocab_dim) if kwargs[LanguageModelKwargs.sequence_first] else (batch_dim, sequence_dim, self._vocab_dim) ) - self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) + else: + dims = None + self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) if targets is None: return logits * self._config.logits_scale_factor, None @@ -502,6 +490,11 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return loss_defs + @property + def heads(self): + # For compatibility with MTP. + return [self] + def _format_name(name: str) -> str: return name.replace("_", " ") diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c9fc609b..5f222866 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -207,7 +207,9 @@ def _forward( y = y.transpose(0, 1).contiguous() # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) - return self.out_proj(y) + out, bias = self.out_proj(y) + self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + return out, bias @torch.compile def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 081aabe6..3e875a64 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -146,6 +146,7 @@ def _forward( ) if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) + self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) return out, None def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 4b0bd436..616d1152 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -208,12 +208,11 @@ def _forward( # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) dt = dt.transpose(1, 2) - if self._debug.enabled: - self._debug(z, "z", self._xz_dims, kwargs) - self._debug(x, "x", self._xz_dims, kwargs) - self._debug(b, "b", self._bc_dims, kwargs) - self._debug(c, "c", self._bc_dims, kwargs) - self._debug(dt, "dt", self._xz_dims, kwargs) + self._debug(z, "z", self._xz_dims, kwargs) + self._debug(x, "x", self._xz_dims, kwargs) + self._debug(b, "b", self._bc_dims, kwargs) + self._debug(c, "c", self._bc_dims, kwargs) + self._debug(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, @@ -227,8 +226,7 @@ def _forward( delta_softplus=True, ) - if self._debug.enabled: - self._debug(y, "y", 
self._xz_dims, kwargs) + self._debug(y, "y", self._xz_dims, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] @@ -237,7 +235,9 @@ def _forward( y = y.transpose(0, 1).contiguous() # (batch/sequence, sequence/batch, local_heads * state) # -> (batch/local_sequence, local_sequence/batch, hidden) - return self.out_proj(y) + out, bias = self.out_proj(y) + self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Implement. diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index bd1c6916..2e0389e8 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -3,14 +3,14 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig -from fast_llm.layers.common.linear.config import Convolution2DConfig +from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MLPBaseConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.vision.patch_convolution import PatchConvolution + from fast_llm.layers.vision.embeddings import PatchEmbeddings from fast_llm.layers.vision.vision_encoder import VisionEncoder, VisionMultiModalModel @@ -58,10 +58,10 @@ class ImageNormalizationConfig(Config): @config_class() -class PatchConvolutionConfig(BlockConfig): +class PatchEmbeddingsConfig(BlockConfig): _abstract = False - convolution: Convolution2DConfig = Field( - desc="Configuration for the 2d convolution.", + patch_embeddings: AffineLinearConfig = Field( + desc="Configuration for the patch embedding layer.", hint=FieldHint.architecture, ) normalization: NormalizationConfig = Field( @@ -90,17 +90,17 @@ def input_channels(self) -> int: return 3 @property - def layer_class(self) -> "type[PatchConvolution]": - from fast_llm.layers.vision.patch_convolution import PatchConvolution + def layer_class(self) -> "type[PatchEmbeddings]": + from fast_llm.layers.vision.embeddings import PatchEmbeddings - return PatchConvolution + return PatchEmbeddings @config_class(registry=True) class VisionEncoderConfig(BlockConfig): _abstract = False # TODO: ====== Rename to patch_embeddings? 
====== - patch_convolution: PatchConvolutionConfig = Field( + embeddings: PatchEmbeddingsConfig = Field( desc="Configuration for the patch convolution layer.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/vision/patch_convolution.py b/fast_llm/layers/vision/embeddings.py similarity index 73% rename from fast_llm/layers/vision/patch_convolution.py rename to fast_llm/layers/vision/embeddings.py index e744044c..2076f72e 100644 --- a/fast_llm/layers/vision/patch_convolution.py +++ b/fast_llm/layers/vision/embeddings.py @@ -3,16 +3,17 @@ import torch from fast_llm.core.ops import split +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.vision.config import PatchConvolutionConfig, VisionKwargs +from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionKwargs from fast_llm.tensor import TensorMeta -class PatchConvolution[ConfigType: PatchConvolutionConfig](Block[ConfigType]): +class PatchEmbeddings[ConfigType: PatchEmbeddingsConfig](Block[ConfigType]): _config: ConfigType def __init__( @@ -39,12 +40,11 @@ def __init__( ).torch self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - self.convolution = self._config.convolution.get_layer( - TensorDim("input_channels", self._config.input_channels), + self.patch_embeddings = self._config.patch_embeddings.get_layer( + TensorDim("patch", self._config.input_channels * self._config.patch_height * self._config.patch_width), self._hidden_dim, - TensorDim("patch_height", self._config.patch_height), - TensorDim("patch_width", self._config.patch_width), - stride=(self._config.patch_height, self._config.patch_width), + default_weight_initialization=init_normal_(), + default_bias_initialization=init_normal_(), default_add_bias=False, lr_scale=self._lr_scale, peft=self._peft, @@ -66,9 +66,11 @@ def forward( ) if self._sequence_parallel: input_ = split(input_, group=self._parallel_dim.group, dim=0) - patch_embeddings = ( - self.normalization(self.convolution(input_).flatten(1)) - .view(-1, self._hidden_dim.size) + + out = ( + self.normalization(self.patch_embeddings(input_.flatten(1))) .unsqueeze(int(kwargs[AttentionKwargs.sequence_first])) + .to(self._residual_dtype) ) - return patch_embeddings.to(self._residual_dtype) + self._debug(out, None, kwargs.get(VisionKwargs.hidden_dims), kwargs) + return out diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py index e6261600..03acfdde 100644 --- a/fast_llm/layers/vision/vision_encoder.py +++ b/fast_llm/layers/vision/vision_encoder.py @@ -28,7 +28,7 @@ def __init__( ): super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) vision_hidden_dim = TensorDim("hidden", self._config.hidden_size) - self.patch_convolution = self._config.patch_convolution.get_layer( + self.embeddings = self._config.embeddings.get_layer( distributed_config, vision_hidden_dim, lr_scale=self._lr_scale, @@ -49,18 +49,18 @@ def __init__( ) def get_layers(self) -> list["Layer"]: - return self.patch_convolution.get_layers() + self.encoder.get_layers() + self.adapter.get_layers() + return self.embeddings.get_layers() + self.encoder.get_layers() + 
self.adapter.get_layers() def preprocess(self, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder. TODO: Avoidable? - self.patch_convolution.preprocess(kwargs) + self.embeddings.preprocess(kwargs) self.encoder.preprocess(kwargs) self.adapter.preprocess(kwargs) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: # Needed because the base class uses `get_layers` which may bypass the decoder. TODO: Avoidable? return ( - self.patch_convolution.get_loss_definitions(count) + self.embeddings.get_loss_definitions(count) + self.encoder.get_loss_definitions(count) + self.adapter.get_loss_definitions(count) ) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 34e38469..a418c3fb 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -1,5 +1,6 @@ import logging import random +import re import typing import torch @@ -9,8 +10,9 @@ from fast_llm.data.sample.token import TokenBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig -from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -23,15 +25,11 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): fast_llm_config: GPTModelConfig -class HuggingfaceGPTModelForCausalLM(HuggingfaceBaseModelForCausalLM): +class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): config_class = HuggingfaceGPTModelConfig config: HuggingfaceGPTModelConfig runner_class: typing.ClassVar[type[GPTInferenceRunner]] = GPTInferenceRunner fast_llm_base_model: GPTBaseModel - # base_model_prefix = "" - # _no_split_modules = None - # _supports_cache_class = False - # _tied_weights_keys = [] def inner_forward( self, @@ -46,21 +44,23 @@ def inner_forward( output_hidden_states: bool | None = None, return_dict: bool | None = None, ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: - # TODO: Most of this is generalizable. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return self._inner_forward( + self._get_batch(input_ids, attention_mask, position_ids), + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if output_attentions: - raise NotImplementedError() - if inputs_embeds is not None: - raise NotImplementedError() - if labels is not None: - raise NotImplementedError() + def _get_batch( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + ): # NOTE: We are ignoring position_ids as we reconstruct them from attention_mask via sequence_lengths. 
if attention_mask is not None: # First non zero indexes or zero index if the row is all zeros (invalid row) @@ -77,15 +77,54 @@ def inner_forward( ] else: sequence_lenghts = None + return LanguageModelBatch(TokenBatch(input_ids, lengths=sequence_lenghts)) + + def _inner_forward( + self, + batch: LanguageModelBatch, + past_key_values=None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: list[str | re.Pattern] | bool | None = None, + return_dict: bool | None = None, + ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: + # TODO: Most of this is generalizable. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if output_attentions: + raise NotImplementedError() + if inputs_embeds is not None: + raise NotImplementedError() + if labels is not None: + raise NotImplementedError() # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self.fast_llm_base_model.preprocess_batch( - LanguageModelBatch(TokenBatch(input_ids, lengths=sequence_lenghts)), - phase=PhaseType.inference, - iteration=iteration, + + ((input_meta, kwargs_meta),) = self.fast_llm_base_model.preprocess_meta(batch, phase=PhaseType.inference) + + if output_hidden_states: + if isinstance(output_hidden_states, bool): + # Hugging Face expect the last layer to include the final norm. + # Note: We can't index `decoder` with slice because it tries to create a new block sequence instance. + output_hidden_states = [layer.module_name + "$" for layer in self.fast_llm_base_model.decoder][:-1] + [ + self.fast_llm_base_model.head.heads[0].final_norm.module_name + "$" + ] + + # This needs to be set before preprocessing so it propagates to layers with namespace. + # kwargs is shallow-copied so changes will propagate back to the main namespace. + kwargs_meta[BlockKwargs.output_hidden_states] = [re.compile(pattern) for pattern in output_hidden_states] + + ((input_, kwargs),) = self.fast_llm_base_model.preprocess_batch( + batch, [(input_meta, kwargs_meta)], phase=PhaseType.inference, iteration=iteration ) - ((input_, kwargs),) = batch if past_key_values is not None: # The transformers will use the past keys and values to this list. @@ -96,12 +135,6 @@ def inner_forward( # The transformers will save the present keys and values to this list. 
kwargs[AttentionKwargs.presents] = [] - if output_hidden_states: - kwargs["output_hidden_states"] = True - kwargs["hidden_states"] = {} - else: - kwargs["output_hidden_states"] = False - kwargs["global_logits"] = True self._inference_runner.forward(input_, kwargs, iteration=iteration) @@ -112,10 +145,13 @@ def inner_forward( else: logits = kwargs["logits"] - # TODO: convert hidden state form dict to list to be the same as with HFs - hidden_states = None if output_hidden_states: - hidden_states = kwargs["hidden_states"] + hidden_states = { + key: tensor if meta is None else meta.local_to_global(tensor)[0] + for key, (meta, tensor) in kwargs["hidden_states"].items() + } + else: + hidden_states = None if not return_dict: # TODO: Then implementing cache, check hidden state goes before past in the tuple diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ee0fae25..0f26d14f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -201,6 +201,7 @@ def preprocess_batch( BlockKwargs.iteration: iteration, AttentionKwargs.sequence_lengths: cropped_tokens.lengths, AttentionKwargs.device: self._distributed.device, + BlockKwargs.hidden_states: {}, **reference_logits[i], } diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py index e07f596a..a485f2c7 100644 --- a/fast_llm/models/multimodal/config.py +++ b/fast_llm/models/multimodal/config.py @@ -17,7 +17,8 @@ from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat, LlavaHybridSSMCheckpointFormat if typing.TYPE_CHECKING: - from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalModel + from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM + from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalInferenceRunner, MultiModalModel from fast_llm.models.multimodal.trainer import MultiModalTrainer logger = logging.getLogger(__name__) @@ -54,10 +55,10 @@ def get_model_class(cls) -> type["MultiModalModel"]: return MultiModalModel @classmethod - def get_inference_runner_class(cls) -> type["MultiModalModelInferenceRunner"]: - from fast_llm.models.multimodal.model import MultiModalModelInferenceRunner + def get_inference_runner_class(cls) -> type["MultiModalInferenceRunner"]: + from fast_llm.models.multimodal.model import MultiModalInferenceRunner - return MultiModalModelInferenceRunner + return MultiModalInferenceRunner @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceMultiModalModelForCausalLM"]: diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 3342fe5e..9657d71b 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -1,5 +1,7 @@ import typing +import torch + from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import WeightConverter from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -10,7 +12,7 @@ from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import LanguageModelHeadConfig -from fast_llm.layers.vision.config import PatchConvolutionConfig, VisionEncoderConfig +from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from 
fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, LlamaBlockConverter, @@ -24,6 +26,7 @@ from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel +from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div, safe_merge_dicts @@ -52,6 +55,8 @@ def import_config(cls, config: dict) -> dict: config["attention_bias"] = False out = super().import_config(config) out["rotary"]["type"] = "default_2d" + out["causal"] = False + out["cross_document_attention"] = False return out @classmethod @@ -60,6 +65,8 @@ def export_config(cls, config: AttentionConfig) -> dict: Assert.eq(config.softmax_scale_power, 0.5) Assert.is_(type(config.rotary), Rotary2DConfig) assert not config.add_linear_biases + assert not config.causal + assert not config.cross_document_attention Assert.eq(config.head_groups, config.heads) return { "num_attention_heads": config.heads, @@ -85,7 +92,35 @@ class PixtralEncoderConverter(LlamaDecoderConverter): block_converter_class: typing.ClassVar[type[PixtralBlockConverter]] = PixtralBlockConverter -class PixtralPatchConvolutionConverter: +class PatchEmbeddingWeightConverter(WeightConverter): + _config: PatchEmbeddingsConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + return tuple( + weight_[:].view( + *weight_[:].shape[:-1], + self._config.input_channels, + self._config.patch_height, + self._config.patch_width, + ) + for weight_ in weight + ) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + return tuple( + weight_[:].view( + *weight_[:].shape[:-3], + self._config.input_channels * self._config.patch_height * self._config.patch_width, + ) + for weight_ in weight + ) + + +class PixtralEmbeddingsConverter: normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter @classmethod @@ -98,10 +133,10 @@ def import_config(cls, config: dict) -> dict: } @classmethod - def export_config(cls, config: PatchConvolutionConfig) -> dict: - Assert.custom(isinstance, config, PatchConvolutionConfig) + def export_config(cls, config: PatchEmbeddingsConfig) -> dict: + Assert.custom(isinstance, config, PatchEmbeddingsConfig) Assert.eq(config.patch_height, config.patch_width) - Assert.incl(config.convolution.bias.enabled, (None, False)) + Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) return safe_merge_dicts( { @@ -113,14 +148,15 @@ def export_config(cls, config: PatchConvolutionConfig) -> dict: @classmethod def get_converters( - cls, config: PatchConvolutionConfig, fast_llm_prefix: str, hf_prefix: str + cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str ) -> list[WeightConverter]: return [ *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", + f"{fast_llm_prefix}.patch_embeddings", f"{hf_prefix}.patch_conv", False, - WeightConverter, + PatchEmbeddingWeightConverter, + config, ), *cls.normalization_converter_class.get_converters( config, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.ln_pre" @@ -172,9 +208,7 @@ def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) class LlavaVisionModelConverter: vision_adapter_converter_class: typing.ClassVar[type[LlavaVisionAdapterConverter]] = 
LlavaVisionAdapterConverter - patch_convolution_converter_class: typing.ClassVar[type[PixtralPatchConvolutionConverter]] = ( - PixtralPatchConvolutionConverter - ) + embeddings_converter_class: typing.ClassVar[type[PixtralEmbeddingsConverter]] = PixtralEmbeddingsConverter encoder_converter_class: typing.ClassVar[type[PixtralEncoderConverter]] = PixtralEncoderConverter model_type: typing.ClassVar[str] = "pixtral" @@ -182,7 +216,7 @@ class LlavaVisionModelConverter: def import_config(cls, config: dict) -> dict: Assert.eq(config["vision_config"]["model_type"], cls.model_type) return { - "patch_convolution": cls.patch_convolution_converter_class.import_config(config["vision_config"]), + "embeddings": cls.embeddings_converter_class.import_config(config["vision_config"]), "encoder": cls.encoder_converter_class.import_config(config["vision_config"]), "adapter": cls.vision_adapter_converter_class.import_config(config), "hidden_size": config["vision_config"]["hidden_size"], @@ -193,7 +227,7 @@ def export_config(cls, config: VisionEncoderConfig) -> dict: Assert.custom(isinstance, config, VisionEncoderConfig) # TODO: ====== image_size? ====== vision_config = safe_merge_dicts( - cls.patch_convolution_converter_class.export_config(config.patch_convolution), + cls.embeddings_converter_class.export_config(config.embeddings), cls.encoder_converter_class.export_config(config.encoder), {"hidden_size": config.hidden_size, "model_type": cls.model_type}, ) @@ -210,8 +244,8 @@ def export_config(cls, config: VisionEncoderConfig) -> dict: @classmethod def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: return [ - *cls.patch_convolution_converter_class.get_converters( - config.patch_convolution, "vision_encoder.patch_convolution", "model.vision_tower" + *cls.embeddings_converter_class.get_converters( + config.embeddings, "vision_encoder.embeddings", "model.vision_tower" ), *cls.encoder_converter_class.get_converters( config.encoder, "vision_encoder.encoder", "model.vision_tower.transformer.layers" diff --git a/fast_llm/models/multimodal/conversion/llava_hybrid.py b/fast_llm/models/multimodal/conversion/llava_hybrid.py index da84455a..92571797 100644 --- a/fast_llm/models/multimodal/conversion/llava_hybrid.py +++ b/fast_llm/models/multimodal/conversion/llava_hybrid.py @@ -41,6 +41,7 @@ def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForImageTextToText": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", }, }, ) diff --git a/fast_llm/models/multimodal/huggingface.py b/fast_llm/models/multimodal/huggingface.py new file mode 100644 index 00000000..8b085999 --- /dev/null +++ b/fast_llm/models/multimodal/huggingface.py @@ -0,0 +1,114 @@ +import logging +import typing + +import torch +import transformers.modeling_outputs + +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from fast_llm.data.sample.patch import PatchBatch +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM +from fast_llm.models.multimodal.config import MultiModalModelConfig +from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalInferenceRunner, MultiModalModel +from fast_llm.utils import Assert + 
+logger = logging.getLogger(__name__)
+
+
+class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig):
+    model_type = "fast_llm_multi_modal"
+    model_config_class = MultiModalModelConfig
+    fast_llm_config: MultiModalModelConfig
+
+
+class HuggingfaceMultiModalModelForCausalLM(HuggingfaceGPTModelForCausalLM):
+    config_class = HuggingfaceMultiModalModelConfig
+    config: HuggingfaceMultiModalModelConfig
+    runner_class: typing.ClassVar[type[MultiModalInferenceRunner]] = MultiModalInferenceRunner
+    fast_llm_base_model: MultiModalBaseModel
+
+    def __init__(
+        self,
+        fast_llm_model: MultiModalModel,
+        config: HuggingfaceMultiModalModelConfig | None = None,
+        runner: ScheduleRunner | None = None,
+        **kwargs,
+    ):
+        super().__init__(fast_llm_model, config, runner, **kwargs)
+        embedding_config = self.config.fast_llm_config.base_model.vision_encoder.embeddings
+        self._patch_config = ImagePatchConfig(
+            height=embedding_config.patch_height,
+            width=embedding_config.patch_width,
+            do_resize=False,
+        )
+        self._image_token_index = self.config.fast_llm_config.base_model.image_token_index
+        assert self._image_token_index is not None
+
+    def inner_forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        pixel_values: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        image_sizes: torch.Tensor | None = None,
+        past_key_values=None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        labels: torch.LongTensor | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+    ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
+        return self._inner_forward(
+            self._get_batch(input_ids, pixel_values, attention_mask, position_ids, image_sizes),
+            past_key_values,
+            inputs_embeds,
+            labels,
+            use_cache,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+        )
+
+    def _get_batch(
+        self,
+        input_ids: torch.Tensor | None = None,
+        pixel_values: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        image_sizes: torch.Tensor | None = None,
+    ):
+        batch = super()._get_batch(input_ids, attention_mask, position_ids)
+        num_samples, sample_size = batch.tokens.tokens.shape
+
+        if pixel_values is None:
+            images = []
+        elif image_sizes is None:
+            images = pixel_values.unbind()
+        else:
+            # Hugging Face uses a batch of padded images with shape (num_images, channels, max_height, max_width).
+            # We need to remove padding before further processing.
+            images = [image[:, :height, :width] for image, (height, width) in zip(pixel_values, image_sizes)]
+
+        # Convert to patches. TODO: Creating token map and image token ids unnecessarily.
+        image_patches, image_position_ids, _, _, patch_counts = self._patch_config.get_patches_from_images(images)
+
+        # Hugging Face encodes token positions through an image token, from which we extract the patch mapping.
+        image_mask = batch.tokens.tokens == self._image_token_index
+
+        sample_map, token_map = torch.nonzero(image_mask, as_tuple=True)
+        Assert.eq(len(sample_map), len(image_patches))
+        # Fast-LLM uses negative token ids as placeholders for image tokens.
+ batch.tokens.tokens = torch.where(image_mask, -100, batch.tokens.tokens) + + batch.image_patches = PatchBatch( + image_patches, + sample_map, + token_map, + image_position_ids, + num_samples, + sample_size, + patch_counts, + ) + + return batch diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index c30a5d27..f8251e21 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel @@ -129,9 +129,9 @@ def preprocess_meta( # Gives the same result, assuming we disable cross-image attention (TODO: Enforce) batch_and_sequence_q_dim, # TODO: Relate to tensor dims in patch convolution. - TensorDim("input_channels", self._config.vision_encoder.patch_convolution.input_channels), - TensorDim("patch_height", self._config.vision_encoder.patch_convolution.patch_height), - TensorDim("patch_width", self._config.vision_encoder.patch_convolution.patch_width), + TensorDim("input_channels", self._config.vision_encoder.embeddings.input_channels), + TensorDim("patch_height", self._config.vision_encoder.embeddings.patch_height), + TensorDim("patch_width", self._config.vision_encoder.embeddings.patch_width), ) ) hidden_dims = ( @@ -193,6 +193,8 @@ def preprocess_batch( VisionKwargs.sequence_lengths: [cropped_image_patches.lengths + [pad_size]], VisionKwargs.sequence_length: sequence_length, VisionKwargs.device: self._distributed.device, + BlockKwargs.output_hidden_states: kwargs.get(BlockKwargs.output_hidden_states, []), + BlockKwargs.hidden_states: kwargs[BlockKwargs.hidden_states], } # We need to modify `local_unpadded_size` directly in `preprocessed_meta` since it's the one used by the engine. # Unsafe, but only needed for testing. @@ -200,8 +202,6 @@ def preprocess_batch( hidden_batch_and_sequence_q_dim = kwargs[self._vision_encoder_namespace][VisionKwargs.hidden_dims][ 0 if kwargs[self._vision_encoder_namespace][VisionKwargs.sequence_first] else 1 ] - print(kwargs[self._vision_encoder_namespace][VisionKwargs.hidden_dims]) - print(hidden_batch_and_sequence_q_dim) assert isinstance(hidden_batch_and_sequence_q_dim, PatchSequenceTensorDim) PatchSequenceTensorDim.local_unpadded_size = cropped_image_patches.patches.size(0) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f4469df9..c614793b 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -151,7 +151,7 @@ def verify_shape(self, tensor: torch.Tensor, global_: bool = False): else: Assert.eq(tensor.shape, self.global_shape if global_ else self.shape, msg=self) - def local_to_global(self, tensor: torch.Tensor) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]: """ Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. Returns a view of the input tensor (or the input tensor itself) when possible. 
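For reference, the placeholder bookkeeping done by the multimodal `_get_batch` above can be reproduced in isolation. This is a minimal sketch, not part of the patch: the token values, image sizes, and the image token id (32000) are made up for illustration, and only plain `torch` calls are used, mirroring the cropping, the `torch.nonzero` mapping, and the negative-id substitution.

import torch

# Illustrative values only: 2 samples of 8 tokens, hypothetical image token id 32000.
IMAGE_TOKEN_INDEX = 32000
tokens = torch.tensor(
    [
        [5, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 7, 9, 2, 4, 6],
        [1, 3, 8, IMAGE_TOKEN_INDEX, 2, 2, 2, 2],
    ]
)

# Hugging Face passes padded images plus their true sizes; cropping recovers the originals.
pixel_values = torch.rand(3, 3, 8, 8)  # (num_images, channels, max_height, max_width)
image_sizes = torch.tensor([[8, 8], [4, 8], [4, 4]])
images = [image[:, :height, :width] for image, (height, width) in zip(pixel_values, image_sizes)]

# Each occurrence of the image token marks where one patch belongs: row -> sample, column -> token.
# In the real method the number of placeholders must equal the total patch count (the Assert.eq above);
# the counts here are arbitrary.
image_mask = tokens == IMAGE_TOKEN_INDEX
sample_map, token_map = torch.nonzero(image_mask, as_tuple=True)
print(sample_map.tolist(), token_map.tolist())  # [0, 0, 1] [1, 2, 3]

# The placeholders are then replaced with a negative id so they are not treated as text tokens.
tokens = torch.where(image_mask, -100, tokens)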
diff --git a/tests/data/test_image_patch.py b/tests/data/test_image_patch.py index 197d1db2..86fe9c70 100644 --- a/tests/data/test_image_patch.py +++ b/tests/data/test_image_patch.py @@ -98,9 +98,9 @@ def _position_ids(height_patches: int, width_patches: int): DATASET_WITH_IMAGE_PATCHES_PATCHES_MD5 = { 27: "d41d8cd98f00b204e9800998ecf8427e", 30: "f9e5a216990b1a3646677195532dddec", - 31: "c56ce50e02154d52e82d320547e3973f", + 31: "bd469b52ddd4f8f2bea4af5c7d843da9", 77: "d41d8cd98f00b204e9800998ecf8427e", - 87: "90ab851ceb87678b4c151edee2049702", + 87: "946d6363c3440c4d3d7b5c684c6efcee", } diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 0dc2421a..c431bb26 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -256,7 +256,7 @@ def test_lm_head( else: logit_weight = None - for prediction_distance, head in enumerate((model.head,) if prediction_heads == 1 else model.head.heads): + for prediction_distance, head in enumerate(model.head.heads): # Prepare the LM head Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index f75ad5eb..f141d04c 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -5,7 +5,6 @@ import pytest import safetensors.torch import torch -import transformers import yaml from fast_llm.engine.checkpoint.config import ( @@ -328,12 +327,39 @@ def test_huggingface_model(model_testing_config, get_convert_path): ).eval() test_input = torch.randint( 0, - model_ref.config.fast_llm_config.base_model.embeddings.vocab_size, + 384, size=(4, 100), dtype=torch.int64, device="cuda", ) - output_ref = model_ref(test_input) + kwargs = {} + if model_testing_config.model_type == "multimodal": + kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).cuda() + kwargs["image_sizes"] = torch.tensor( + [ + [20, 20], # Full image, 25 patches + [12, 12], # Smaller, 9 patches + [9, 15], # Cropped to patch size, 6 patches + [5, 20], # Cropped in one dim, 5 patches + [7, 5], # Single patch + [2, 3], # Cropped out (0 patch) + ] + ) + image_token_index = model_ref.fast_llm_base_model.config.image_token_index + # First sample has one image at the beginning. + test_input[0, :25] = image_token_index + # Second sample has one image in the middle + test_input[1, 30:39] = image_token_index + # Third sample has no image. + # Fourth sample has four images. + # First one has discontinuous embedding (ex. image break token) + test_input[3, :3] = image_token_index + test_input[3, 7:10] = image_token_index + # Second and third one next to each other. + test_input[3, 28:34] = image_token_index + # Last one cropped out. 
+ + output_ref = model_ref(test_input, **kwargs) model_from_fast_llm = hf_class.from_pretrained(fast_llm_path).eval() model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( @@ -343,18 +369,13 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ).eval() errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream") - else transformers.AutoModelForCausalLM - ) - model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda().eval() + model_as_hf = model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True).cuda().eval() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), ): print(name) - output = model(test_input) + output = model(test_input, **kwargs) # TODO: Make a generic comparison util. CompareConfig().compare_tensors( {"samples": output_ref.logits, "shape": output_ref.logits.shape, "step": 0}, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f7797e3c..eb0a91dd 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -7,6 +7,7 @@ import typing import pytest +import transformers from fast_llm.config import set_nested_dict_value from fast_llm.engine.checkpoint.config import CheckpointFormat @@ -32,6 +33,9 @@ EvaluatorsConfig, ) +if typing.TYPE_CHECKING: + import transformers.models.auto.auto_factory + _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) @@ -85,6 +89,9 @@ class ModelTestingConfig: get_dataset: typing.Callable[[bool], tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path]] = ( get_model_test_dataset ) + auto_model_class: type["transformers.models.auto.auto_factory._BaseAutoModelClass"] = ( + transformers.AutoModelForCausalLM + ) def __post_init__(self): _, config, _ = self.get_dataset(config_only=True) @@ -428,6 +435,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, + auto_model_class=transformers.AutoModel, ) @@ -505,6 +513,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, + auto_model_class=transformers.AutoModel, ) _update_and_add_testing_config( @@ -691,7 +700,7 @@ def _update_and_add_testing_config( model_type="multimodal", updates={ ("model", "base_model", "vision_encoder"): { - "patch_convolution": {"patch_height": 4, "patch_width": 4, "normalization": {"type": "rms_norm"}}, + "embeddings": {"patch_height": 4, "patch_width": 4, "normalization": {"type": "rms_norm"}}, "encoder": copy.deepcopy(MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]), "adapter": {"intermediate_size": 256}, "hidden_size": 256, @@ -723,6 +732,7 @@ def _update_and_add_testing_config( # Micro-sequence split and sequence-first not supported. # TODO: Gradient accumulation works but comparison is broken. skip_tests=("sdp", "ms", "bf4", "df"), + auto_model_class=transformers.AutoModelForImageTextToText, ) @@ -811,9 +821,10 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=2.0, + compare_factor=10.0, # Micro-sequence split not supported for Mamba. 
- skip_tests=("sdp", "ms"), + # Pipeline-parallel gives a different mixer selection. + skip_tests=("sdp", "ms", "pp"), )
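A note on the `PatchConvolution` to `PatchEmbeddings` rename and the accompanying `PatchEmbeddingWeightConverter`: a convolution whose kernel and stride both equal the patch size is the same linear map as a matrix multiply on the flattened patch, which is why the checkpoint weight only needs a `view` between the Conv2d-style layout and the flattened layout. The sketch below checks this numerically with made-up sizes (hidden size 8, 3 channels, 4x4 patches); it illustrates the equivalence and is not code from the patch.

import torch

# Hypothetical sizes: hidden=8, 3 input channels, 4x4 patches.
hidden, channels, patch = 8, 3, 4
conv_weight = torch.randn(hidden, channels, patch, patch)  # Conv2d-style patch_conv weight (HF layout).

# One image patch, run through the convolution with kernel == stride == patch size.
patch_pixels = torch.randn(1, channels, patch, patch)
conv_out = torch.nn.functional.conv2d(patch_pixels, conv_weight, stride=patch).flatten()

# The same weight viewed as (hidden, channels * patch * patch), applied to the flattened patch,
# i.e. the layout produced by the import direction of the converter.
linear_weight = conv_weight.view(hidden, channels * patch * patch)
linear_out = linear_weight @ patch_pixels.flatten()

# Both paths give the same embedding up to floating-point rounding.
torch.testing.assert_close(conv_out, linear_out)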