diff --git a/optimum/fx/parallelization/distributed/dist_ops.py b/optimum/fx/parallelization/distributed/dist_ops.py index 7991c7cee6..fd5358fa10 100644 --- a/optimum/fx/parallelization/distributed/dist_ops.py +++ b/optimum/fx/parallelization/distributed/dist_ops.py @@ -20,6 +20,16 @@ def all_reduce(group: dist.ProcessGroup, tensor: torch.Tensor) -> torch.Tensor: + """ + Performs an all-reduce operation on a tensor across all processes in the group. + + Args: + group: The process group to perform the all-reduce operation on + tensor: The input tensor to reduce across all processes + + Returns: + The tensor after all-reduce operation (sum across all processes) + """ world_size = dist.get_world_size(group) if world_size == 1: return tensor @@ -29,6 +39,17 @@ def all_reduce(group: dist.ProcessGroup, tensor: torch.Tensor) -> torch.Tensor: def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = -1) -> torch.Tensor: + """ + Gathers tensors from all processes in the group along the specified dimension. + + Args: + group: The process group to gather tensors from + tensor: The input tensor to gather from each process + gather_dim: The dimension along which to gather tensors (default: -1) + + Returns: + A tensor containing all gathered tensors concatenated along the gather dimension + """ world_size = dist.get_world_size(group) if world_size == 1: return tensor @@ -49,6 +70,17 @@ def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) -> torch.Tensor: + """ + Splits a tensor along the specified dimension and returns the chunk for the current process. + + Args: + group: The process group to determine the current rank + tensor: The input tensor to split + split_dim: The dimension along which to split the tensor (default: -1) + + Returns: + The tensor chunk corresponding to the current process rank + """ world_size = dist.get_world_size(group) if world_size == 1: return tensor @@ -65,6 +97,18 @@ def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) - def scatter( group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch.Tensor, scatter_dim: int = 0 ) -> torch.Tensor: + """ + Scatters a tensor from rank 0 to all processes in the group along the specified dimension. + + Args: + group: The process group to scatter the tensor to + tensor: The input tensor to scatter (only used on rank 0) + output_tensor: The output tensor to store the scattered chunk + scatter_dim: The dimension along which to scatter the tensor (default: 0) + + Returns: + The output tensor containing the scattered chunk for the current process + """ world_size = dist.get_world_size(group) if world_size == 1: output_tensor.copy_(tensor) @@ -84,63 +128,207 @@ def scatter( class DifferentiableIdentity(torch.autograd.Function): + """ + A differentiable identity function that performs all-reduce on gradients during backward pass. + """ + @staticmethod def forward(ctx, tensor, group: dist.ProcessGroup): + """ + Forward pass that returns the input tensor unchanged. + + Args: + ctx: Context object to save information for backward pass + tensor: Input tensor + group: Process group for gradient synchronization + + Returns: + The input tensor unchanged + """ ctx.group = group return tensor @staticmethod def backward(ctx, grad_output): + """ + Backward pass that performs all-reduce sum on the gradient. 
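The collectives above (`all_gather` and `split` in particular) are designed as inverses of one another along the chosen dimension. A minimal sketch of that round trip, assuming two CPU processes on the gloo backend and a free local port, might look like this:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # Assumption: port 29500 is free on the local machine.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    full = torch.arange(8, dtype=torch.float32).reshape(2, 4)

    # "split": each rank keeps only its chunk along the chosen dimension.
    local = full.chunk(world_size, dim=-1)[rank].contiguous()

    # "all_gather": collect every rank's chunk and concatenate them back.
    chunks = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(chunks, local)
    reassembled = torch.cat(chunks, dim=-1)

    assert torch.equal(reassembled, full)  # split followed by all_gather is a no-op
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```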
+ + Args: + ctx: Context object containing saved information + grad_output: Gradient from the next layer + + Returns: + Tuple of gradients for input arguments (tensor gradient, None for group) + """ group = ctx.group return DifferentiableAllReduceSum.apply(grad_output, group), None class DifferentiableAllReduceSum(torch.autograd.Function): + """ + A differentiable all-reduce sum operation that maintains gradients through the operation. + """ + @staticmethod def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + """ + Forward pass that performs all-reduce sum on the input tensor. + + Args: + ctx: Context object to save information for backward pass + tensor: Input tensor to reduce + group: Process group for the all-reduce operation + + Returns: + The tensor after all-reduce sum operation + """ ctx.group = group return all_reduce(group=group, tensor=tensor) @staticmethod def backward(ctx, grad_output: torch.Tensor): + """ + Backward pass that returns the gradient unchanged. + + Args: + ctx: Context object containing saved information + grad_output: Gradient from the next layer + + Returns: + Tuple of gradients for input arguments (gradient unchanged, None for group) + """ return grad_output, None class DifferentiableScatter(torch.autograd.Function): + """ + A differentiable scatter operation that performs all-gather on gradients during backward pass. + """ + @staticmethod def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor: + """ + Forward pass that splits the tensor and returns the chunk for the current process. + + Args: + ctx: Context object to save information for backward pass + tensor: Input tensor to scatter + group: Process group for the scatter operation + dim: Dimension along which to scatter (default: -1) + + Returns: + The tensor chunk for the current process + """ ctx.group = group ctx.dim = dim return split(group=group, tensor=tensor, split_dim=dim) @staticmethod def backward(ctx, grad_output: torch.Tensor): + """ + Backward pass that performs all-gather on the gradient. + + Args: + ctx: Context object containing saved information + grad_output: Gradient from the next layer + + Returns: + Tuple of gradients for input arguments (gathered gradient, None for group and dim) + """ return DifferentiableAllGather.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None class DifferentiableAllGather(torch.autograd.Function): + """ + A differentiable all-gather operation that performs scatter on gradients during backward pass. + """ + @staticmethod def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor: + """ + Forward pass that gathers tensors from all processes along the specified dimension. + + Args: + ctx: Context object to save information for backward pass + tensor: Input tensor to gather + group: Process group for the all-gather operation + dim: Dimension along which to gather (default: -1) + + Returns: + The gathered tensor containing all process chunks + """ ctx.group = group ctx.dim = dim return all_gather(group=group, tensor=tensor, gather_dim=dim) @staticmethod def backward(ctx, grad_output: torch.Tensor): + """ + Backward pass that scatters the gradient to the current process chunk. 
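The autograd pairing that `DifferentiableIdentity` and `DifferentiableAllReduceSum` implement (identity in forward, reduction in backward, and vice versa) can be illustrated without any process group at all. The sketch below substitutes a visible scaling for the all-reduce so the backward-only effect is easy to observe; it is an illustration of the pattern, not the library's implementation:

```python
import torch


class IdentityWithBackwardHook(torch.autograd.Function):
    """Identity in forward; a gradient-only transformation in backward."""

    @staticmethod
    def forward(ctx, tensor):
        # Returning a view keeps autograd happy when the output is the input unchanged.
        return tensor.view_as(tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # The real DifferentiableIdentity all-reduces grad_output across the
        # tensor-parallel group here; a scaling makes the effect visible locally.
        return grad_output * 2.0


x = torch.ones(3, requires_grad=True)
IdentityWithBackwardHook.apply(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.]) -- the "collective" only ran in backward
```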
+ + Args: + ctx: Context object containing saved information + grad_output: Gradient from the next layer + + Returns: + Tuple of gradients for input arguments (scattered gradient, None for group and dim) + """ return DifferentiableScatter.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + """ + Applies differentiable all-reduce sum operation on a tensor. + + Args: + tensor: Input tensor to reduce + group: Process group for the all-reduce operation + + Returns: + The tensor after all-reduce sum operation with gradient support + """ return DifferentiableAllReduceSum.apply(tensor, group) def differentiable_identity(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + """ + Applies differentiable identity operation that synchronizes gradients during backward pass. + + Args: + tensor: Input tensor + group: Process group for gradient synchronization + + Returns: + The input tensor unchanged with gradient synchronization support + """ return DifferentiableIdentity.apply(tensor, group) def differentiable_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor: + """ + Applies differentiable all-gather operation on a tensor. + + Args: + tensor: Input tensor to gather + group: Process group for the all-gather operation + dim: Dimension along which to gather (default: -1) + + Returns: + The gathered tensor with gradient support + """ return DifferentiableAllGather.apply(tensor, group, dim) def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor: + """ + Applies differentiable scatter operation on a tensor. + + Args: + tensor: Input tensor to scatter + group: Process group for the scatter operation + dim: Dimension along which to scatter (default: -1) + + Returns: + The scattered tensor chunk with gradient support + """ return DifferentiableScatter.apply(tensor, group, dim) diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py index 7d6d99d69d..ddec8275d9 100644 --- a/optimum/fx/parallelization/op_registry/op_handlers.py +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -29,9 +29,19 @@ class Registry: """ def __init__(self) -> None: + """Initialize the registry with an empty mapping.""" self.mapping = {} def register(self, op_types): + """ + Register handler classes for specified operation types. + + Args: + op_types: Single operation type or list/tuple of operation types to register + + Returns: + Decorator function that registers the handler class + """ def wrapper(cls): if isinstance(op_types, (list, tuple)): for op_type in op_types: @@ -43,6 +53,7 @@ def wrapper(cls): return wrapper def is_supported(self, op_type) -> bool: + """Check if an operation type is supported by the registry.""" return op_type in self.mapping @@ -50,18 +61,43 @@ def is_supported(self, op_type) -> bool: class OpParallelAxisPropagateHandler: + """Base class for handling parallel axis propagation in PyTorch operations.""" + def __init__(self, node: Node, meta_key: str, config: Config) -> None: + """ + Initialize the handler with node information and configuration. 
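For reference, the `Registry` decorator pattern used by the op handlers boils down to a dictionary populated at class-definition time. A self-contained sketch, with made-up op-type strings standing in for the real `torch.ops.aten` overloads:

```python
from typing import Any, Dict, Type


class Registry:
    def __init__(self) -> None:
        self.mapping: Dict[Any, Type] = {}

    def register(self, op_types):
        def wrapper(cls):
            types = op_types if isinstance(op_types, (list, tuple)) else [op_types]
            for op_type in types:
                self.mapping[op_type] = cls
            return cls

        return wrapper

    def is_supported(self, op_type) -> bool:
        return op_type in self.mapping


REGISTRY = Registry()


@REGISTRY.register(["aten::add", "aten::mul"])  # illustrative keys, not real overloads
class BinaryHandler:
    pass


print(REGISTRY.is_supported("aten::add"))   # True
print(REGISTRY.is_supported("aten::sort"))  # False
```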
+ + Args: + node: The FX graph node to process + meta_key: Key for accessing metadata + config: Configuration object containing parallelization settings + """ self.node = node self.meta_key = meta_key self.config = config def extract_axis(self, arg: Any) -> Optional[int]: + """ + Extract the parallel axis from a node argument. + + Args: + arg: The argument to extract axis from + + Returns: + The parallel axis if found, None otherwise + """ if not isinstance(arg, Node): return None return arg.meta[self.meta_key].get("parallel_axis", None) @abstractmethod def propagate(self) -> List[int]: + """ + Propagate parallel axis information through the operation. + + Returns: + List of possible parallel axes for the output + """ raise NotImplementedError @@ -129,7 +165,10 @@ def propagate(self) -> List[int]: ] ) class UnaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for unary operations that preserve parallel axis from input to output.""" + def propagate(self) -> List[int]: + """Propagate the parallel axis from the single input to the output.""" arg = self.node.all_input_nodes[0] axis = self.extract_axis(arg) return [axis] @@ -163,7 +202,15 @@ def propagate(self) -> List[int]: ] ) class BinaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for binary operations that combine parallel axes from two inputs.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axes from two inputs, handling broadcasting and axis alignment. + + Returns: + List containing the compatible parallel axis or empty list if incompatible + """ input_nodes = self.node.all_input_nodes # only one node if len(input_nodes) == 1: @@ -213,9 +260,17 @@ def propagate(self) -> List[int]: ] ) class ReductionOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for reduction operations that may eliminate or preserve parallel axes.""" + def extract_dims( self, ) -> List[int]: + """ + Extract the dimensions being reduced and keepdim flag from the operation. + + Returns: + Tuple of (dimensions being reduced, keepdim flag) + """ ndim = self.node.meta["val"].ndim dims = None if "dim" in self.node.kwargs: @@ -238,6 +293,12 @@ def extract_dims( return dims, keepdim def propagate(self) -> List[int]: + """ + Propagate parallel axis through reduction, adjusting for eliminated dimensions. + + Returns: + List containing the adjusted parallel axis or empty list if axis is reduced + """ dims, keepdim = self.extract_dims() arg = self.node.all_input_nodes[0] axis = self.extract_axis(arg) @@ -252,7 +313,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.view.default) class ViewLikeOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for view operations that reshape tensors while preserving data.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through view operation by matching tensor sizes. + + Returns: + List of possible parallel axes in the reshaped tensor + """ arg = self.node.args[0] axis = self.extract_axis(arg) if axis is None: @@ -274,7 +343,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.unsqueeze.default) class UnsqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for unsqueeze operations that add singleton dimensions.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through unsqueeze, adjusting for the new dimension. 
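The bookkeeping in `ReductionOpParallelAxisPropagateHandler` amounts to shifting the parallel axis left by the number of reduced dimensions that precede it, or giving up when the parallel axis itself is reduced. A hedged sketch of that rule, using a hypothetical helper name rather than the library API:

```python
from typing import List, Optional


def reduced_parallel_axis(axis: Optional[int], dims: List[int], keepdim: bool, ndim: int) -> Optional[int]:
    """Hypothetical helper mirroring the reduction handler's axis arithmetic."""
    if axis is None:
        return None
    dims = [(d + ndim) % ndim for d in dims]  # normalize negative dims
    if axis in dims:
        return None  # the parallel dimension itself would be reduced away
    if keepdim:
        return axis
    return axis - sum(1 for d in dims if d < axis)


# Rank-4 tensor parallelized on dim 2, with dims (0, 1) summed out and keepdim=False:
print(reduced_parallel_axis(axis=2, dims=[0, 1], keepdim=False, ndim=4))  # 0
```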
+ + Returns: + List containing the adjusted parallel axis + """ arg, dim = self.node.args[0], self.node.args[1] ndim = arg.meta["val"].ndim axis = self.extract_axis(arg) @@ -293,7 +370,15 @@ def propagate(self) -> List[int]: ] ) class SqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for squeeze operations that remove singleton dimensions.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through squeeze, adjusting for removed dimensions. + + Returns: + List containing the adjusted parallel axis or empty list if axis is squeezed + """ arg, dims = self.node.args[0], self.node.args[1] axis = self.extract_axis(arg) if axis is None: @@ -311,7 +396,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.permute.default) class PermuteParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for permute operations that reorder tensor dimensions.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through permutation by finding the new position. + + Returns: + List containing the new parallel axis position or empty list if not found + """ arg, dims = self.node.args[0], self.node.args[1] ndim = arg.meta["val"].ndim axis = self.extract_axis(arg) @@ -326,7 +419,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.slice.Tensor) class SliceParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for slice operations that extract tensor subsets.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through slice, checking if slicing affects the parallel dimension. + + Returns: + List containing the parallel axis or empty list if slicing conflicts with parallelization + """ arg, slice_dim = self.node.args[0], self.node.args[1] axis = self.extract_axis(arg) if axis is None: @@ -350,7 +451,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.expand.default) class ExpandParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for expand operations that broadcast tensors to larger sizes.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through expand, adjusting for added dimensions. + + Returns: + List containing the adjusted parallel axis + """ arg, size = self.node.args[0], self.node.args[1] axis = self.extract_axis(arg) if axis is None: @@ -361,7 +470,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.cat.default) class CatParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for concatenation operations that join tensors along a dimension.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through concatenation, ensuring all inputs have compatible axes. + + Returns: + List containing the parallel axis or empty list if concatenation conflicts with parallelization + """ nodes, cat_axis = self.node.all_input_nodes, self.node.args[1] axis, ndim = self.extract_axis(nodes[0]), nodes[0].meta["val"].ndim cat_axis = (cat_axis + ndim) % ndim @@ -375,7 +492,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.constant_pad_nd.default) class PadParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for padding operations that add values to tensor boundaries.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through padding, checking if padding affects the parallel dimension. 
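`PermuteParallelAxisPropagateHandler` relocates the parallel axis by finding where it lands in the permutation. A small sketch of that lookup, with an attention-style layout as the worked example (the helper name is illustrative):

```python
from typing import List, Optional


def permuted_parallel_axis(axis: Optional[int], dims: List[int], ndim: int) -> Optional[int]:
    """Hypothetical helper: output position i holds input dim dims[i]."""
    if axis is None:
        return None
    dims = [(d + ndim) % ndim for d in dims]
    return dims.index(axis)


# (batch, heads, seq, head_dim) parallelized on the head dimension (1),
# permuted to (batch, seq, heads, head_dim):
print(permuted_parallel_axis(axis=1, dims=[0, 2, 1, 3], ndim=4))  # 2
```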
+ + Returns: + List containing the parallel axis or empty list if padding conflicts with parallelization + """ pad, ndim = self.node.args[1], self.node.args[0].meta["val"].ndim axis = self.extract_axis(self.node.args[0]) if axis is None: @@ -387,7 +512,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.ops.aten.copy.default) class CopyParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for copy operations that transfer data between tensors.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through copy, ensuring source and destination have compatible axes. + + Returns: + List containing the parallel axis or empty list if axes are incompatible + """ dst, src = self.node.all_input_nodes axis_dst = self.extract_axis(dst) axis_src = self.extract_axis(src) @@ -398,7 +531,15 @@ def propagate(self) -> List[int]: @REGISTRY.register(torch.nn.functional.scaled_dot_product_attention) class SpdaAttnParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Handler for scaled dot-product attention operations.""" + def propagate(self) -> List[int]: + """ + Propagate parallel axis through attention, ensuring query, key, and value have compatible axes. + + Returns: + List containing the parallel axis or empty list if axes are incompatible + """ q, k, v = self.node.args[:3] q_axis = self.extract_axis(q) # parallel axis must be the head dimension if being parallelized @@ -408,8 +549,16 @@ def propagate(self) -> List[int]: class FallbackParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + """Fallback handler for operations not explicitly registered in the registry.""" + def propagate(self) -> List[int]: - # by default we don't parallelize inputs and constants(except parameters embedded in modules) + """ + Handle parallel axis propagation for unregistered operations using heuristics. + + Returns: + List of possible parallel axes based on operation type and input analysis + """ + # by default we don't parallelize inputs and constants(except parameters embeded in modules) if self.node.op in ["placeholder", "get_attr"]: return [None] elif self.node.op == "output": diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index 62d5894dac..3c152a456c 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -44,6 +44,17 @@ class ColumnParallelLinear(nn.Module): """ def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True) -> None: + """ + Initialize the column-parallel linear layer. + + Sets up parameter metadata for parallel execution and configures weight and bias + tensors for column-wise parallelization across the tensor parallel group. + + Args: + ctx (ParallelExecutionCtx): Parallel execution context containing runtime information. + linear (nn.Linear): The original linear module being replaced. + gather_output (bool, optional): Whether to gather output at the end of forward pass. Defaults to True. + """ super(ColumnParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) @@ -94,6 +105,15 @@ def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: self.register_parameter("bias", None) def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the column-parallel linear layer. + + Args: + input (torch.Tensor): Input tensor to be processed. 
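The column-parallel scheme that `ColumnParallelLinear` implements can be checked numerically in a single process: sharding the weight along the output-feature dimension and concatenating the partial results reproduces the dense linear, which is what `gather_output=True` does across ranks. A sketch under that framing:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)        # (batch, in_features)
weight = torch.randn(6, 8)   # (out_features, in_features)
bias = torch.randn(6)

full = F.linear(x, weight, bias)

# Each "rank" holds a slice of the output features (what ColumnParallelLinear shards).
w_shards = weight.chunk(world_size, dim=0)
b_shards = bias.chunk(world_size, dim=0)
partials = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]

# gather_output=True corresponds to concatenating the partial results along the last dim.
assert torch.allclose(torch.cat(partials, dim=-1), full)
```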
+ + Returns: + torch.Tensor: Output tensor after linear transformation, optionally gathered across processes. + """ input = differentiable_identity(input, self.process_group) output = F.linear(input, self.weight, self.bias) if self.gather_output: @@ -121,6 +141,17 @@ class RowParallelLinear(nn.Module): """ def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False) -> None: + """ + Initialize the row-parallel linear layer. + + Sets up parameter metadata for parallel execution and configures weight and bias + tensors for row-wise parallelization across the tensor parallel group. + + Args: + ctx (ParallelExecutionCtx): Parallel execution context containing runtime information. + linear (nn.Linear): The original linear module being replaced. + input_is_parallel (bool, optional): Whether the input tensor has already been parallelized. Defaults to False. + """ super(RowParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) @@ -166,6 +197,15 @@ def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parall self.register_parameter("bias", None) def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the row-parallel linear layer. + + Args: + input (torch.Tensor): Input tensor to be processed. + + Returns: + torch.Tensor: Output tensor after linear transformation and all-reduce operation, with bias added if present. + """ if not self.input_is_parallel: input = differentiable_scatter(input, self.process_group) diff --git a/optimum/fx/parallelization/parallel_layers/loss.py b/optimum/fx/parallelization/parallel_layers/loss.py index 9937089844..937842d267 100644 --- a/optimum/fx/parallelization/parallel_layers/loss.py +++ b/optimum/fx/parallelization/parallel_layers/loss.py @@ -24,6 +24,13 @@ # Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py class _ShardedCrossEntropy(torch.autograd.Function): + """ + Custom autograd function for computing cross-entropy loss on sharded logits across multiple processes. + + This function handles the forward and backward passes for cross-entropy computation when the vocabulary + is distributed across multiple GPUs/processes. + """ + @staticmethod def forward( ctx, @@ -31,6 +38,18 @@ def forward( target: torch.Tensor, # (batch_size, length) group: dist.ProcessGroup, ): + """ + Compute the forward pass of sharded cross-entropy loss. + + Args: + ctx: Context object for saving tensors needed in backward pass + sharded_logits: Logits tensor sharded across processes with shape (batch_size, length, sharded_hidden_size) + target: Target indices tensor with shape (batch_size, length) + group: Process group for distributed communication + + Returns: + torch.Tensor: Cross-entropy loss tensor + """ # Maximum value along last dimension across all GPUs. logits_max = torch.max(sharded_logits, dim=-1)[0] dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group) @@ -83,6 +102,16 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): + """ + Compute the backward pass of sharded cross-entropy loss. + + Args: + ctx: Context object containing saved tensors from forward pass + grad_output: Gradient of the loss with respect to the output + + Returns: + tuple: Gradients with respect to sharded_logits, target, and group (None for last two) + """ # Retrieve tensors from the forward path. 
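`RowParallelLinear` is the complementary scheme: the input features are sharded, each rank computes a partial matmul, and the all-reduce in `forward` sums the partials before the bias is added once. A single-process numerical sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)
weight = torch.randn(6, 8)
bias = torch.randn(6)

full = F.linear(x, weight, bias)

x_shards = x.chunk(world_size, dim=-1)       # each rank sees a slice of the input features
w_shards = weight.chunk(world_size, dim=-1)  # and the matching weight columns
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]

# The sum over ranks is what the all-reduce computes; the bias is added once afterwards.
assert torch.allclose(sum(partials) + bias, full, atol=1e-5)
```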
softmax, target_mask, masked_target_1d = ctx.saved_tensors @@ -103,10 +132,30 @@ def backward(ctx, grad_output: torch.Tensor): def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup): + """ + Apply sharded cross-entropy loss computation using the custom autograd function. + + Args: + sharded_logits: Logits tensor sharded across processes + target: Target indices tensor + process_group: Process group for distributed communication + + Returns: + torch.Tensor: Cross-entropy loss tensor + """ return _ShardedCrossEntropy.apply(sharded_logits, target, process_group) def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup): + """ + Create a wrapper function for sharded cross-entropy that mimics PyTorch's CrossEntropyLoss interface. + + Args: + process_group: Process group for distributed communication + + Returns: + function: Wrapper function that accepts standard cross-entropy loss parameters + """ @wraps(sharded_cross_entropy) def wrapper( sharded_logits: torch.Tensor, @@ -118,6 +167,25 @@ def wrapper( reduction: str = "mean", label_smoothing: float = 0.0, ): + """ + Compute sharded cross-entropy loss with standard PyTorch interface. + + Args: + sharded_logits: Logits tensor sharded across processes + target: Target indices tensor + weight: Manual rescaling weight given to each class (not supported) + size_average: Deprecated parameter for backward compatibility + ignore_index: Index to ignore in loss computation (not supported) + reduce: Deprecated parameter for backward compatibility + reduction: Reduction method ('mean', 'sum', or 'none') + label_smoothing: Label smoothing factor (not supported) + + Returns: + torch.Tensor: Cross-entropy loss tensor with specified reduction applied + + Raises: + ValueError: If unsupported parameters are provided + """ if weight is not None or ignore_index != -100 or label_smoothing != 0.0: raise ValueError( "Does not support weighted mode, index ignoring and label smoothing in current parallel cross entropy implementation." @@ -147,14 +215,34 @@ def wrapper( class VocabParallelCrossEntropyLoss(nn.Module): """ Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet. + + This module provides a PyTorch nn.Module interface for computing cross-entropy loss on vocabulary + that is distributed across multiple processes. """ def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None: + """ + Initialize the vocabulary parallel cross-entropy loss module. + + Args: + ctx: Parallel execution context containing process group information + reduction: Reduction method to apply to the loss ('mean', 'sum', or 'none') + """ super(VocabParallelCrossEntropyLoss, self).__init__() self.process_group = ctx.tp_group self.reduction = reduction def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor): + """ + Compute the forward pass of the parallel cross-entropy loss. 
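The forward math of `_ShardedCrossEntropy` (a global max via all-reduce MAX for stability, a global sum of exponentials via all-reduce SUM, then the log-sum-exp identity) can be reproduced in one process by treating vocabulary shards as stand-ins for ranks. A sketch of that equivalence:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
logits = torch.randn(5, 10)                # (tokens, vocab)
target = torch.randint(0, 10, (5,))

reference = F.cross_entropy(logits, target)

shards = logits.chunk(world_size, dim=-1)  # each "rank" holds a slice of the vocabulary
# Stand-in for the all-reduce MAX over per-shard maxima (numerical stability).
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
# Stand-in for the all-reduce SUM of per-shard exponential sums (softmax denominator).
sum_exp = sum((s - global_max.unsqueeze(-1)).exp().sum(dim=-1) for s in shards)
# The target logit lives on exactly one shard per token; read it from the full tensor here.
target_logit = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)
loss = (torch.log(sum_exp) + global_max - target_logit).mean()

assert torch.allclose(loss, reference, atol=1e-5)
```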
+ + Args: + sharded_logits: Logits tensor sharded across processes + target: Target indices tensor + + Returns: + torch.Tensor: Cross-entropy loss with the specified reduction applied + """ loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group) if self.reduction == "mean": return loss.mean() diff --git a/optimum/utils/preprocessing/text_classification.py b/optimum/utils/preprocessing/text_classification.py index aa0d78581d..4ed9044f92 100644 --- a/optimum/utils/preprocessing/text_classification.py +++ b/optimum/utils/preprocessing/text_classification.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Text classification processing.""" +"""Text classification processing utilities for preparing datasets and tokenizing text data.""" import copy import logging @@ -31,6 +31,20 @@ class TextClassificationProcessing(TaskProcessor): + """ + Processor for text classification tasks that handles tokenization and dataset preparation. + + This class provides functionality to preprocess text classification datasets by tokenizing + input text and preparing it for model training or inference. It supports both single text + and text pair classification tasks. + + Attributes: + ACCEPTED_PREPROCESSOR_CLASSES: Tuple of accepted tokenizer classes. + DEFAULT_DATASET_ARGS: Default arguments for loading datasets. + DEFAUL_DATASET_DATA_KEYS: Default keys for accessing dataset text data. + ALLOWED_DATA_KEY_NAMES: Set of allowed data key names. + DEFAULT_REF_KEYS: Default reference keys for labels. + """ ACCEPTED_PREPROCESSOR_CLASSES = (PreTrainedTokenizerBase,) DEFAULT_DATASET_ARGS = {"path": "glue", "name": "sst2"} DEFAULT_DATASET_DATA_KEYS = {"primary": "sentence"} @@ -40,6 +54,17 @@ class TextClassificationProcessing(TaskProcessor): def create_defaults_and_kwargs_from_preprocessor_kwargs( self, preprocessor_kwargs ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Extract default tokenization parameters and remaining kwargs from preprocessor arguments. + + Args: + preprocessor_kwargs: Dictionary of preprocessor arguments, can be None. + + Returns: + Tuple containing: + - defaults: Dictionary with default tokenization parameters (padding, truncation, max_length) + - kwargs: Dictionary with remaining preprocessor arguments + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} kwargs = copy.deepcopy(preprocessor_kwargs) @@ -52,6 +77,17 @@ def create_defaults_and_kwargs_from_preprocessor_kwargs( def dataset_processing_func( self, example: Dict[str, Any], data_keys: Dict[str, str], ref_keys: Optional[List[str]] = None ) -> Dict[str, Any]: + """ + Process a single dataset example by tokenizing the text input(s). + + Args: + example: Dictionary containing the dataset example data. + data_keys: Dictionary mapping data key names to column names in the dataset. + ref_keys: Optional list of reference keys for labels. + + Returns: + Dictionary containing tokenized inputs ready for model consumption. + """ tokenized_inputs = self.preprocessor( text=example[data_keys["primary"]], text_pair=example[data_keys["secondary"]] if "secondary" in data_keys else None, @@ -61,6 +97,16 @@ def dataset_processing_func( return tokenized_inputs def try_to_guess_data_keys(self, column_names: List[str]) -> Optional[Dict[str, str]]: + """ + Attempt to automatically identify primary and secondary text columns in the dataset. 
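For orientation, the defaults above (`{"path": "glue", "name": "sst2"}` with primary key `"sentence"`) roughly correspond to the following plain `datasets`/`transformers` usage. This is a hedged approximation of what `dataset_processing_func` ends up doing, assuming both packages are installed and with the checkpoint name chosen only as an example:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
dataset = load_dataset("glue", "sst2", split="validation")


def processing_func(example):
    # Single-text task: only the "primary" key is used; a pair task would also pass text_pair.
    return tokenizer(text=example["sentence"], truncation=True, max_length=tokenizer.model_max_length)


tokenized = dataset.map(processing_func)
print(tokenized[0].keys())
```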
+ + Args: + column_names: List of column names in the dataset. + + Returns: + Dictionary mapping 'primary' and optionally 'secondary' to column names, + or None if primary column cannot be identified. + """ primary_key_name = None primary_key_name_candidates = ["sentence", "text", "premise"] for name in column_names: @@ -86,6 +132,15 @@ def try_to_guess_data_keys(self, column_names: List[str]) -> Optional[Dict[str, return {"primary": primary_key_name, "secondary": secondary_key_name} def try_to_guess_ref_keys(self, column_names: List[str]) -> Optional[List[str]]: + """ + Attempt to automatically identify label columns in the dataset. + + Args: + column_names: List of column names in the dataset. + + Returns: + List containing the identified label column name, or None if not found. + """ for name in column_names: if "label" in name: return [name] @@ -101,6 +156,22 @@ def load_dataset( shuffle: bool = False, **load_dataset_kwargs, ) -> Union["DatasetDict", "Dataset"]: + """ + Load and prepare a text classification dataset. + + Args: + path: Path or name of the dataset to load. + data_keys: Optional dictionary mapping data key names to dataset column names. + ref_keys: Optional list of reference keys for labels. + only_keep_necessary_columns: Whether to keep only necessary columns. + load_smallest_split: Whether to load only the smallest dataset split. + num_samples: Optional number of samples to load. + shuffle: Whether to shuffle the dataset. + **load_dataset_kwargs: Additional arguments passed to the dataset loading function. + + Returns: + Loaded dataset as either a DatasetDict or Dataset object. + """ dataset = super().load_dataset( path, data_keys=data_keys,
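The column-guessing heuristics in `try_to_guess_data_keys` and `try_to_guess_ref_keys` can be summarised with a short sketch; the secondary-column candidate list below is an assumption for illustration, not taken from the source:

```python
from typing import Dict, List, Optional


def guess_data_keys(column_names: List[str]) -> Optional[Dict[str, str]]:
    primary_candidates = ["sentence", "text", "premise"]
    secondary_candidates = ["hypothesis", "sentence2"]  # assumed pair-task column names
    primary = next((c for c in primary_candidates if c in column_names), None)
    if primary is None:
        return None
    keys = {"primary": primary}
    secondary = next((c for c in secondary_candidates if c in column_names), None)
    if secondary is not None:
        keys["secondary"] = secondary
    return keys


def guess_ref_keys(column_names: List[str]) -> Optional[List[str]]:
    return next(([name] for name in column_names if "label" in name), None)


print(guess_data_keys(["premise", "hypothesis", "label"]))  # {'primary': 'premise', 'secondary': 'hypothesis'}
print(guess_ref_keys(["idx", "sentence", "label"]))         # ['label']
```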