200 changes: 156 additions & 44 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -20,25 +20,22 @@
See benchmark_pipeline_utils.py for step-by-step instructions.
"""

import importlib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torch import nn
from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
BaseModelConfig,
create_model_config,
DeepFMConfig,
DLRMConfig,
generate_data,
generate_pipeline,
TestSparseNNConfig,
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
)
from torchrec.distributed.benchmark.benchmark_utils import (
    benchmark_module,
BenchmarkResult,
cmd_conf,
CPUMemoryStats,
@@ -62,6 +59,18 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig


@dataclass
class UnifiedBenchmarkConfig:
"""Unified configuration for both pipeline and module benchmarking."""

benchmark_type: str = "pipeline" # "pipeline" or "module"

# Module benchmarking specific options
module_path: str = "" # e.g., "torchrec.models.deepfm"
module_class: str = "" # e.g., "SimpleDeepFMNNWrapper"
module_kwargs: Dict[str, Any] = field(default_factory=dict)
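
# Illustrative sketch (not part of this diff): a module-mode
# UnifiedBenchmarkConfig built directly in Python, mirroring the
# module_benchmark_config.yaml example added in this PR:
#
#     config = UnifiedBenchmarkConfig(
#         benchmark_type="module",
#         module_path="torchrec.models.deepfm",
#         module_class="SimpleDeepFMNNWrapper",
#         module_kwargs={"hidden_layer_size": 20, "deep_fm_dimension": 5},
#     )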


@dataclass
class RunOptions:
"""
@@ -201,57 +210,160 @@ class ModelSelectionConfig:
over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])


def dynamic_import_module(module_path: str, module_class: str) -> Type[nn.Module]:
"""Dynamically import a module class from a given path."""
try:
module = importlib.import_module(module_path)
return getattr(module, module_class)
    except (ImportError, AttributeError) as e:
        raise RuntimeError(
            f"Failed to import {module_class} from {module_path}: {e}"
        ) from e
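
# Illustrative usage (EmbeddingBagCollection is referenced by
# create_module_instance below and is a known nn.Module subclass):
#
#     EBCClass = dynamic_import_module(
#         "torchrec.modules.embedding_modules", "EmbeddingBagCollection"
#     )
#     assert issubclass(EBCClass, nn.Module)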


def create_module_instance(
unified_config: UnifiedBenchmarkConfig,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
table_config: EmbeddingTablesConfig,
) -> nn.Module:
"""Create a module instance based on the unified config."""
ModuleClass = dynamic_import_module(
unified_config.module_path, unified_config.module_class
)

    # Handle common module instantiation patterns. Defaults are merged beneath
    # module_kwargs so that entries in module_kwargs can override them.
    if unified_config.module_class == "SimpleDeepFMNNWrapper":
        from torchrec.modules.embedding_modules import EmbeddingBagCollection

        ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
        return ModuleClass(
            embedding_bag_collection=ebc,
            **{"num_dense_features": 10, **unified_config.module_kwargs},
        )
    elif unified_config.module_class == "DLRMWrapper":
        from torchrec.modules.embedding_modules import EmbeddingBagCollection

        ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
        return ModuleClass(
            embedding_bag_collection=ebc,
            **{
                "dense_in_features": 10,
                "dense_arch_layer_sizes": [20, 128],
                "over_arch_layer_sizes": [5, 1],
                **unified_config.module_kwargs,
            },
        )
elif unified_config.module_class == "EmbeddingBagCollection":
return ModuleClass(tables=tables, **unified_config.module_kwargs)
else:
# Generic instantiation - try with tables and weighted_tables
try:
return ModuleClass(
tables=tables,
weighted_tables=weighted_tables,
**unified_config.module_kwargs,
)
except TypeError:
# Fallback to just tables
try:
return ModuleClass(tables=tables, **unified_config.module_kwargs)
except TypeError:
# Fallback to no embedding tables
return ModuleClass(**unified_config.module_kwargs)
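
# Note on the generic fallback in create_module_instance: a hypothetical class
# whose __init__ is `def __init__(self, tables, hidden_dim=64)` would raise
# TypeError on the first attempt (unexpected `weighted_tables` kwarg) and be
# constructed by the second attempt instead.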


def run_module_benchmark(
unified_config: UnifiedBenchmarkConfig,
table_config: EmbeddingTablesConfig,
run_option: RunOptions,
) -> BenchmarkResult:
"""Run module-level benchmarking."""
tables, weighted_tables = generate_tables(
num_unweighted_features=table_config.num_unweighted_features,
num_weighted_features=table_config.num_weighted_features,
embedding_feature_dim=table_config.embedding_feature_dim,
)

module = create_module_instance(
unified_config, tables, weighted_tables, table_config
)

return benchmark_module(
module=module,
tables=tables,
weighted_tables=weighted_tables,
num_float_features=10, # Default value
sharding_type=run_option.sharding_type,
planner_type=run_option.planner_type,
world_size=run_option.world_size,
num_benchmarks=5, # Default value
batch_size=2048, # Default value
compute_kernel=run_option.compute_kernel,
device_type="cuda",
)


@cmd_conf
def main(
run_option: RunOptions,
table_config: EmbeddingTablesConfig,
model_selection: ModelSelectionConfig,
pipeline_config: PipelineConfig,
unified_config: UnifiedBenchmarkConfig,
model_config: Optional[BaseModelConfig] = None,
) -> None:
    # Route to the appropriate benchmark type based on the unified config
if unified_config.benchmark_type == "module":
print("Running module-level benchmark...")
result = run_module_benchmark(unified_config, table_config, run_option)
print(f"Module benchmark completed: {result}")
elif unified_config.benchmark_type == "pipeline":
print("Running pipeline-level benchmark...")
tables, weighted_tables = generate_tables(
num_unweighted_features=table_config.num_unweighted_features,
num_weighted_features=table_config.num_weighted_features,
embedding_feature_dim=table_config.embedding_feature_dim,
)

if model_config is None:
model_config = create_model_config(
model_name=model_selection.model_name,
batch_size=model_selection.batch_size,
batch_sizes=model_selection.batch_sizes,
num_float_features=model_selection.num_float_features,
feature_pooling_avg=model_selection.feature_pooling_avg,
use_offsets=model_selection.use_offsets,
dev_str=model_selection.dev_str,
long_kjt_indices=model_selection.long_kjt_indices,
long_kjt_offsets=model_selection.long_kjt_offsets,
long_kjt_lengths=model_selection.long_kjt_lengths,
pin_memory=model_selection.pin_memory,
embedding_groups=model_selection.embedding_groups,
feature_processor_modules=model_selection.feature_processor_modules,
max_feature_lengths=model_selection.max_feature_lengths,
over_arch_clazz=model_selection.over_arch_clazz,
postproc_module=model_selection.postproc_module,
zch=model_selection.zch,
hidden_layer_size=model_selection.hidden_layer_size,
deep_fm_dimension=model_selection.deep_fm_dimension,
dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
)

# launch trainers
run_multi_process_func(
func=runner,
world_size=run_option.world_size,
tables=tables,
weighted_tables=weighted_tables,
run_option=run_option,
model_config=model_config,
pipeline_config=pipeline_config,
)
else:
raise ValueError(
f"Unknown benchmark_type: {unified_config.benchmark_type}. Must be 'module' or 'pipeline'"
)


def run_pipeline(
run_option: RunOptions,
table_config: EmbeddingTablesConfig,
27 changes: 24 additions & 3 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -22,7 +22,7 @@
import resource
import time
import timeit
from dataclasses import dataclass, field, fields, is_dataclass, MISSING
from enum import Enum
from typing import (
Any,
@@ -140,6 +140,25 @@ def __str__(self) -> str:
return f"Rank {self.rank}: CPU Memory Peak RSS: {self.peak_rss_mbs/1000:.2f} GB"


@dataclass
class ModuleBenchmarkConfig:
"""Configuration for module-level benchmarking."""

module_path: str = "" # e.g., "torchrec.models.deepfm"
module_class: str = "" # e.g., "SimpleDeepFMNNWrapper"
module_kwargs: Dict[str, Any] = field(
default_factory=dict
) # Additional kwargs for module instantiation
num_float_features: int = 0
sharding_type: ShardingType = ShardingType.TABLE_WISE
planner_type: str = "embedding"
world_size: int = 2
num_benchmarks: int = 5
batch_size: int = 2048
compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED
device_type: str = "cuda"


@dataclass
class BenchmarkResult:
"Class for holding results of benchmark runs"
@@ -728,8 +747,9 @@ def _init_module_and_run_benchmark(
def _func_to_benchmark(
model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
) -> None:
        # Run the timed forward passes under inference_mode to avoid autograd
        # bookkeeping overhead in the measurement.
        with torch.inference_mode():
            for bench_input in bench_inputs:
                model(bench_input)

name = f"{sharding_type.value}-{planner_type}"

@@ -828,6 +848,7 @@ def benchmark_module(
if weighted_tables is None:
weighted_tables = []

# Use multiprocessing for distributed benchmarking (always assume train mode)
res = multi_process_benchmark(
callable=_init_module_and_run_benchmark,
module=module,
@@ -0,0 +1,27 @@
# Example YAML configuration for module-level benchmarking
# Usage: python -m torchrec.distributed.benchmark.unified_benchmark_runner --yaml_config=module_benchmark_config.yaml

UnifiedBenchmarkConfig:
benchmark_type: "module"
module_path: "torchrec.models.deepfm"
module_class: "SimpleDeepFMNNWrapper"
module_kwargs:
hidden_layer_size: 20
deep_fm_dimension: 5

RunOptions:
world_size: 2
num_batches: 10
sharding_type: "table_wise"
compute_kernel: "fused"
input_type: "kjt"
planner_type: "embedding"
dense_optimizer: "SGD"
dense_lr: 0.1
sparse_optimizer: "EXACT_ADAGRAD"
sparse_lr: 0.1

EmbeddingTablesConfig:
num_unweighted_features: 100
num_weighted_features: 100
embedding_feature_dim: 128
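
# Note: module_kwargs entries are forwarded to the module constructor by
# create_module_instance and take precedence over its built-in defaults.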
@@ -0,0 +1,38 @@
# Example YAML configuration for pipeline-level benchmarking
# Usage: python -m torchrec.distributed.benchmark.unified_benchmark_runner --yaml_config=pipeline_benchmark_config.yaml

UnifiedBenchmarkConfig:
benchmark_type: "pipeline"

PipelineConfig:
pipeline: "sparse"
emb_lookup_stream: "data_dist"
apply_jit: false

RunOptions:
world_size: 2
num_batches: 10
sharding_type: "table_wise"
compute_kernel: "fused"
input_type: "kjt"
planner_type: "embedding"
dense_optimizer: "SGD"
dense_lr: 0.1
sparse_optimizer: "EXACT_ADAGRAD"
sparse_lr: 0.1

ModelSelectionConfig:
model_name: "test_sparse_nn"
batch_size: 8192
num_float_features: 10
feature_pooling_avg: 10
use_offsets: false
long_kjt_indices: true
long_kjt_offsets: true
long_kjt_lengths: true
pin_memory: true

EmbeddingTablesConfig:
num_unweighted_features: 100
num_weighted_features: 100
embedding_feature_dim: 128