26 changes: 0 additions & 26 deletions fast_llm/core/distributed.py
@@ -145,32 +145,6 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None
return _tensor_to_object(output_tensor)


def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup = None, src: int = 0) -> torch.Tensor:
"""
Broadcasts an optional tensor of size, shape, and dtype unknown in advance.
Returns the tensor on all ranks or None if no tensor was sent.
"""
assert group is not None

if group.rank() == src:
has_tensor = tensor is not None
if has_tensor:
meta = (has_tensor, tensor.shape, tensor.dtype)
else:
meta = (has_tensor, None, None)
broadcast_object(meta, group, src)
if has_tensor:
broadcast(tensor.to(torch.cuda.current_device()), src, group)
return tensor
else:
has_tensor, shape, dtype = broadcast_object(None, group, src)
if not has_tensor:
return None
output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
broadcast(output_tensor, src, group)
return output_tensor


def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
assert group is not None
work = group.send([tensor], dst, tag)
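Note: the behavior of the removed helper can still be recovered with the two primitives this module keeps, `broadcast_object` for the metadata and `broadcast` for the payload, which is the pattern the HuggingFace wrapper below now inlines. A minimal sketch; the function name and the explicit `device` argument are illustrative, not part of the diff:

```python
import torch

from fast_llm.core.distributed import broadcast, broadcast_object


def broadcast_optional_tensor(tensor, group, src: int = 0, device="cuda"):
    """Broadcast a tensor of shape/dtype unknown in advance, or None, from `src` to all ranks in `group`."""
    if group.rank() == src:
        # Send the metadata first so receivers can allocate a matching buffer.
        meta = (tensor.shape, tensor.dtype) if tensor is not None else None
        broadcast_object(meta, group, src)
        if tensor is not None:
            broadcast(tensor.to(device), src, group)
        return tensor
    meta = broadcast_object(None, group, src)
    if meta is None:
        return None
    shape, dtype = meta
    output = torch.empty(shape, dtype=dtype, device=device)
    broadcast(output, src, group)
    return output
```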
2 changes: 1 addition & 1 deletion fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -251,7 +251,7 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
)
# Get the image patches and associated data.
image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
self._config.image_patches.get_patches(images, self._data_type)
self._config.image_patches.get_patches_from_images(images, self._data_type)
)
patch_count_cumsum = padded_cumsum(patch_counts).tolist()
# Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
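For context on the surrounding bookkeeping: the per-image patch counts returned by the renamed call are turned into insertion offsets with a padded cumulative sum. A toy illustration, assuming `padded_cumsum` prepends a leading zero (the counts are made up):

```python
patch_counts = [4, 6, 2]             # patches produced for each of three images
patch_count_cumsum = [0, 4, 10, 12]  # assumed padded cumsum: start offset of each image's patch block
# Image i's patches then occupy patches[patch_count_cumsum[i] : patch_count_cumsum[i + 1]].
```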
48 changes: 29 additions & 19 deletions fast_llm/data/preprocessing/image_patch.py
@@ -30,16 +30,17 @@ class ImagePatchConfig(Config):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
do_resize: bool = Field(default=True, desc="Whether to resize the image.")
max_image_height: int = Field(
default=1024,
desc="Maximum height of the complete image, in pixels."
"If the original image is larger than this, it will be resized to this height.",
"If the original image is larger than this and resizing is enabled, it will be resized to this height.",
hint=FieldHint.optional,
)
max_image_width: int = Field(
default=1024,
desc="Maximum width of the complete image, in pixels."
"If the original image is larger than this, it will be resized to this width.",
"If the original image is larger than this and resizing is enabled, it will be resized to this width.",
hint=FieldHint.optional,
)
image_break_token: int | None = Field(
@@ -72,14 +73,14 @@ def _validate(self):
Assert.gt(self.max_patches_height, 0)
Assert.gt(self.max_patches_width, 0)

def get_patches(
self, images: list[bytes], token_data_type: DataType = DataType.int64
def get_patches_from_images(
self, images: list["torch.Tensor|bytes"], token_data_type: DataType = DataType.int64
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["torch.Tensor"], list[int]]:
import torch

if len(images) > 0:
image_patches, image_positions, image_token_maps, image_token_ids = zip(
*(self._get_patches(image, token_data_type) for image in images)
*(self._get_patches_from_image(image, token_data_type) for image in images)
)
return (
torch.cat(image_patches),
@@ -98,28 +99,37 @@ def get_patches(
[0],
)

def _get_patches(
self, image_bytes: bytes, token_data_type: DataType = DataType.int64
def _get_patches_from_image(
self, image: "torch.Tensor|bytes", token_data_type: DataType = DataType.int64
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
import numpy as np
import PIL.Image
import torch

with PIL.Image.open(io.BytesIO(image_bytes)) as image:
if image.mode != "RGB":
# Convert all images to RGB
image = image.convert("RGB")
image = torch.tensor(np.array(image)).permute(2, 0, 1) # HWC to CHW
Assert.eq(image.dtype, torch.uint8)
if not torch.is_tensor(image):
import numpy as np
import PIL.Image

with PIL.Image.open(io.BytesIO(image)) as image:
if image.mode != "RGB":
# Convert all images to RGB
image = image.convert("RGB")
image = torch.tensor(np.array(image)).permute(2, 0, 1) # HWC to CHW
Assert.eq(image.dtype, torch.uint8)

if self.do_resize:
# Resize to a multiple of patch size smaller or equal to max size.
image = self._resize(image)
else:
# Crop the image to ensure its shape is a multiple of the patch size.
image = image[
:, : image.size(1) - image.size(1) % self.height, : image.size(2) - image.size(2) % self.width
]

# Resize to a multiple of patch size smaller or equal to max size.
image = self._resize(image)
num_patches_height = div(image.size(1), self.height)
num_patches_width = div(image.size(2), self.width)
# Convert to patches. (`torch.nn.functional.unfold` not supported for uint8.)
patches = (
image.view(self.num_channels, num_patches_height, self.height, num_patches_width, self.width)
.permute(3, 1, 0, 2, 4)
image.reshape(self.num_channels, num_patches_height, self.height, num_patches_width, self.width)
.permute(1, 3, 0, 2, 4)
.flatten(0, 1)
)

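A self-contained sketch of the patch extraction above, covering the new crop fallback (`do_resize=False`) and the `reshape`/`permute` that stands in for `unfold`; the patch size and dummy image are illustrative only:

```python
import torch


def extract_patches(image: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
    # Crop so the height and width are exact multiples of the patch size (the do_resize=False branch).
    image = image[
        :, : image.size(1) - image.size(1) % patch_height, : image.size(2) - image.size(2) % patch_width
    ]
    num_h, num_w = image.size(1) // patch_height, image.size(2) // patch_width
    # (C, num_h * patch_height, num_w * patch_width) -> (num_h * num_w, C, patch_height, patch_width),
    # with patches ordered row by row, matching the new permute(1, 3, 0, 2, 4) order.
    return (
        image.reshape(image.size(0), num_h, patch_height, num_w, patch_width)
        .permute(1, 3, 0, 2, 4)
        .flatten(0, 1)
    )


patches = extract_patches(torch.zeros(3, 70, 50, dtype=torch.uint8), 16, 16)
print(patches.shape)  # torch.Size([12, 3, 16, 16]): a 4 x 3 grid of 16 x 16 patches
```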
118 changes: 44 additions & 74 deletions fast_llm/engine/inference/huggingface.py
@@ -6,8 +6,9 @@
import torch
import transformers.generation.utils
import transformers.modeling_outputs
import transformers.utils.generic

from fast_llm.core.distributed import broadcast_object, broadcast_optional, safe_barrier
from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.inference.config import HuggingfaceModelConfig
@@ -20,7 +21,7 @@
logger = logging.getLogger(__name__)


class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
class HuggingfacePreTrainedModel(transformers.PreTrainedModel, transformers.generation.utils.GenerationMixin):
config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig
runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner
config: HuggingfaceModelConfig
Expand Down Expand Up @@ -112,40 +113,14 @@ def from_pretrained(
def _init_weights(self, module) -> None:
raise NotImplementedError(module)


class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin):
def inner_forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
past_key_values=None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
# Meant to be overridden in derived classes
raise NotImplementedError()

def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
past_key_values=None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
*args,
coordinator_forward: bool = False,
communication_timeout_sec: float = 600.0,
continue_work: bool = True,
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast | None:
**kwargs,
) -> tuple | transformers.utils.generic.ModelOutput | None:
"""
Forward pass compatible with HuggingFace forward.

@@ -170,44 +145,37 @@ def forward(

if coordinator_forward and distributed.world_group and distributed.tensor_group:
assert distributed.tensor_group.rank() == 0
assert past_key_values is None and not use_cache

# Some tasks may post-process too slowly, so waiting for the next batch or
# the end of work can exceed the standard 60s timeout.
safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec)

broadcast_optional(input_ids, distributed.tensor_group, 0)
broadcast_optional(attention_mask, distributed.tensor_group, 0)
broadcast_optional(position_ids, distributed.tensor_group, 0)
broadcast_optional(inputs_embeds, distributed.tensor_group, 0)
broadcast_optional(labels, distributed.tensor_group, 0)

# Broadcast all input arguments, handling tensor and non-tensor arguments separately
# TODO: Support nested tensor in arguments (ex. past_key_values)
# TODO: Bypassed if passed as positional argument.
assert kwargs.get("past_key_values") is None and not kwargs.get("use_cache")
broadcast_kwargs = {**kwargs, **{i: arg for i, arg in enumerate(args)}, "continue_work": continue_work}
tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if torch.is_tensor(value)}
broadcast_object(
[(key, tensor.shape, tensor.dtype) for key, tensor in tensor_kwargs.items()],
distributed.tensor_group,
0,
)
for tensor in tensor_kwargs.values():
broadcast(tensor.to(distributed.device), 0, distributed.tensor_group)
non_tensor_kwargs = {key: value for key, value in broadcast_kwargs.items() if key not in tensor_kwargs}
broadcast_object(
(past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work),
non_tensor_kwargs,
distributed.tensor_group,
0,
)

if not coordinator_forward or continue_work:
return self.inner_forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
)
return self.inner_forward(*args, **kwargs)

return None

def worker_forward(
self,
communication_timeout_sec: float = 600.0,
):
def worker_forward(self, communication_timeout_sec: float = 600.0):
"""
Run the forward loop on worker ranks in coordinated mode.

@@ -239,30 +207,28 @@ def worker_forward(
# the end of work can exceed the standard 60s timeout.
safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec)

input_ids = broadcast_optional(None, distributed.tensor_group, 0)
attention_mask = broadcast_optional(None, distributed.tensor_group, 0)
position_ids = broadcast_optional(None, distributed.tensor_group, 0)
inputs_embeds = broadcast_optional(None, distributed.tensor_group, 0)
labels = broadcast_optional(None, distributed.tensor_group, 0)

past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work = (
broadcast_object(None, distributed.tensor_group, 0)
broadcast_kwargs = {}
for key, shape, dtype in broadcast_object(None, distributed.tensor_group, 0):
tensor = torch.empty(shape, dtype=dtype, device=distributed.device)
broadcast(tensor, 0, distributed.tensor_group)
broadcast_kwargs[key] = tensor

broadcast_kwargs.update(
broadcast_object(
None,
distributed.tensor_group,
0,
)
)

if not continue_work:
if not broadcast_kwargs.pop("continue_work"):
break

arg_kwargs = {key: value for key, value in broadcast_kwargs.items() if isinstance(key, int)}

self.inner_forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
*(arg_kwargs[i] for i in range(len(arg_kwargs))),
**{key: value for key, value in broadcast_kwargs.items() if key not in arg_kwargs},
)

safe_barrier(distributed.world_group, "forward_work_end")
Expand All @@ -274,3 +240,7 @@ def stop_workers(self):
return
self.forward(coordinator_forward=True, continue_work=False)
safe_barrier(distributed.world_group, "forward_work_end")

def inner_forward(self, *args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput:
# Meant to be overridden in derived classes
raise NotImplementedError()
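A usage sketch of the coordinated-forward protocol above: rank 0 of the tensor-parallel group drives `forward(..., coordinator_forward=True)` and finally `stop_workers()`, while the other ranks replay the broadcast arguments inside `worker_forward()`. How the `distributed` handle and the batches are obtained is assumed here, not prescribed by the diff:

```python
import torch


def run_coordinated_inference(model, distributed, batches: list[torch.Tensor]):
    if distributed.tensor_group is None or distributed.tensor_group.rank() == 0:
        # Coordinator: each forward call broadcasts its kwargs to the other tensor-parallel ranks.
        outputs = [model.forward(input_ids=batch, coordinator_forward=True) for batch in batches]
        model.stop_workers()  # broadcasts continue_work=False, releasing the workers from their loop
        return outputs
    # Workers: block in the broadcast loop until the coordinator signals the end of work.
    model.worker_forward()
    return None
```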
4 changes: 2 additions & 2 deletions fast_llm/engine/multi_stage/config.py
@@ -30,9 +30,9 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

logger = logging.getLogger(__name__)

@@ -242,7 +242,7 @@ def get_inference_runner_class(cls) -> type["InferenceRunner"]:
raise NotImplementedError

@classmethod
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]:
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]:
raise NotImplementedError

@classmethod
17 changes: 0 additions & 17 deletions fast_llm/engine/multi_stage/stage.py
@@ -132,23 +132,6 @@ def forward(
if output is not None:
self._log_layer_forward(output, kwargs, i)

# TODO: very slow and memory consuming, only use for debugging for now
# TODO: decide if and how we want to return
# HF transformer style details from forward properly
if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
# Last layer does not provide output
if output is not None:
meta = self._meta_outputs[i]
if output.shape == meta.shape:
output_global, _ = meta.local_to_global(output.detach())
else:
# TODO: Handle variable shape.
output_global = output

kwargs["hidden_states"][self._layers[i].module_name] = {
"layer_type": type(layer).__name__,
"tensor": output_global,
}
return None if output is None else output.detach(), (input_, output)

def backward(
9 changes: 6 additions & 3 deletions fast_llm/functional/autograd.py
@@ -27,10 +27,13 @@ def wrap_forward_backward[
and returns the `input_` gradient.
"""

# We want to run the backward pass even if the input doesn't require grads.
in_ = torch.empty(0, requires_grad=True)

class Function(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
outputs = forward(*args)
outputs = forward(*args[:-1])
Assert.custom(isinstance, outputs, tuple)
# No need to call `save_for_backward`, we don't want the safety checks anyway.
ctx.context = outputs[-1]
@@ -44,13 +47,13 @@
def backward(ctx, *grad_outputs):
grad_input = backward(*grad_outputs, ctx.context)
if not isinstance(grad_input, tuple):
assert isinstance(grad_input, torch.Tensor)
assert isinstance(grad_input, torch.Tensor) or grad_input is None
grad_input = (grad_input,)
return *grad_input, *[None for _ in range(ctx.nargs - len(grad_input))]

def call(*args, **kwargs):
# TODO: Any way to validate kwargs without making function wrappers?
return Function.apply(*args, *kwargs.values())
return Function.apply(*args, *kwargs.values(), in_)

return call

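A standalone demonstration of the dummy-input trick introduced above: appending a zero-sized tensor with `requires_grad=True` to `Function.apply` guarantees that the custom `backward` runs even when none of the real inputs require gradients. The toy function is illustrative, not the wrapped Fast-LLM kernel:

```python
import torch

_always_grad = torch.empty(0, requires_grad=True)


class _Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dummy):
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        print("backward ran")
        return 2 * grad_output, None  # gradient for x, nothing for the dummy


x = torch.ones(3)  # requires_grad is False
y = _Double.apply(x, _always_grad)
y.sum().backward()  # prints "backward ran"; without the dummy, y would not require grad and no backward graph would exist
```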