diff --git a/flagscale/runner/auto_tuner/simulator/README.md b/flagscale/runner/auto_tuner/simulator/README.md
new file mode 100644
index 000000000..6dff02de1
--- /dev/null
+++ b/flagscale/runner/auto_tuner/simulator/README.md
@@ -0,0 +1,22 @@
+# Environment
+Begin at the root path of the `FlagScale` repository:
+```
+conda activate flagscale
+cd flagscale/flagscale/runner/auto_tuner/simulator/custom_backend/
+python setup.py develop
+```
+
+# Setup
+Set the necessary parameters in `config_gen.py`. For example:
+```
+device_type_list = ["A", "B"]
+device_num_list = [4, 4]
+global_batch_size = 32
+num_micro_batches = 8
+num_layers = 4
+```
+# Run a Task
+Start the auto-tuning:
+```
+python config_gen.py
+```
\ No newline at end of file
diff --git a/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py
new file mode 100644
index 000000000..2cf11ce23
--- /dev/null
+++ b/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py
@@ -0,0 +1,181 @@
+import os
+import subprocess
+import re
+import time
+# from megatron.training import get_args
+
+def kill_other_python_processes():
+    current_pid = os.getpid()
+    clear_cmd = f"pkill -f python -o --signal TERM --ignore \"{current_pid}\""
+    subprocess.run(clear_cmd, text=True, shell=True)
+
+def compute_pipeline_parallelism_cost(
+    scheme: str='1F1B',
+    # num_stages: int=1,
+    num_micro_batches: int=1,
+    process_mesh: list=None,
+    pp_layers_split: list=None,
+    fwd_time_per_stage_chunk: list=None,
+    bwd_time_per_stage_chunk: list=None,
+    comm_time_between_stages: list=None,
+    # TODO: add fine-grained recomputation
+):
+    print(f"--- Compute Pipeline Cost ---")
+
+    # process_mesh: [tp0,cp0,ep0,dp0,pp0,(tp1,cp1,...)]
+    # comm_time_between_stages[i] means the comm time between stage i-1 and stage i
+    num_pp_stages = sum(process_mesh[4::5])
+    assert len(pp_layers_split) == num_pp_stages, \
+        f"length of pp_layers_split should match num_pp_stages ({num_pp_stages})"
+    assert len(fwd_time_per_stage_chunk) == num_pp_stages, \
+        f"length of fwd_time_per_stage_chunk should match num_pp_stages ({num_pp_stages})"
+    assert len(bwd_time_per_stage_chunk) == num_pp_stages, \
+        f"length of bwd_time_per_stage_chunk should match num_pp_stages ({num_pp_stages})"
+    assert len(comm_time_between_stages) == num_pp_stages, \
+        f"length of comm_time_between_stages should match num_pp_stages ({num_pp_stages})"
+
+    pp_last_stage_time = num_micro_batches * (fwd_time_per_stage_chunk[num_pp_stages-1] + bwd_time_per_stage_chunk[num_pp_stages-1])
+    if num_pp_stages == 1:
+        return pp_last_stage_time
+
+    pipeline_cost = 0
+    # TODO: consider when comm time > comp time
+    # each stage only depends on its next stage
+    if scheme == '1F1B' or scheme == 'AFAB':
+        pipeline_cost = pp_last_stage_time
+        for stage_from_last in range(2, num_pp_stages):
+            pp_this_stage_overlapped_time = (num_micro_batches-1) * (fwd_time_per_stage_chunk[num_pp_stages-1] + bwd_time_per_stage_chunk[num_pp_stages-1])
+            pp_this_stage_compute_time = fwd_time_per_stage_chunk[num_pp_stages-stage_from_last] + bwd_time_per_stage_chunk[num_pp_stages-stage_from_last]
+            pp_last_stage_overall_time = pipeline_cost + 2 * comm_time_between_stages[num_pp_stages-stage_from_last+1]
+            # does not model the case where communication stalls computation,
+            # i.e., the comm time is assumed to be no larger than the compute (fwd) time
+            pipeline_cost = pp_this_stage_compute_time + max(pp_last_stage_overall_time, 
pp_this_stage_overlapped_time) + else: + raise(ValueError("Scheme must be '1F1B' or 'AFAB'.")) + + return pipeline_cost + +import random + +def simulator( + process_mesh: list=None, + stage: int=0, + num_layers: int=None, + simulated_rank: int=None, + pp_layers_split: list=None +): + + os.environ["PYTHONPATH"] = "/share/project/heyongzhe/FlagScale/megatron:/share/project/heyongzhe/FlagScale" + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + os.environ["RANK"] = str(simulated_rank) + os.environ["LOCAL_RANK"] = str(simulated_rank) + # os.environ["WORLD_SIZE"] = args.world_size + # os.environ["WORLD_SIZE"] = "8" + os.environ["WORLD_SIZE"] = "32" + rdav_endpoint = random.randint(0, 40000) + os.environ["RDZV_ENDPOINT"]="localhost:" + str(rdav_endpoint) + # os.environ["RZDV_ENDPOINT"]="localhost:37832" + os.environ["RDZV_BACKEND"]="c10d" + os.environ["MASTER_ADDR"]="localhost" + + program_entry = " ./flagscale/train/train_aquila.py " + simulation_arguments = " --enable-hetero --enable-simulator --distributed-backend dummy " + # fine_grained_recomputation_args = "--recompute-granularity-per-stage-micro-batch '[1, 1, 1]' --recompute-method-per-stage-micro-batch '[1, 1, 1]' --recompute-num-layers-per-stage-micro-batch '[1, 1, 1]'" + fine_grained_recomputation_args = "" + # print(stage) + + pp_layer_split_args = " --hetero-pipeline-layer-split " + for layers in pp_layers_split: + pp_layer_split_args = pp_layer_split_args + str(layers) + " " + + process_mesh_str = " --hetero-process-meshes " + for dim in process_mesh: + process_mesh_str = process_mesh_str + str(dim) + " " + + num_pp_stages = sum(process_mesh[4::5]) + pp_size_args = " --pipeline-model-parallel-size " + str(num_pp_stages) + " " + + # TODO: too ugly to show this command in the code, re-organize these parameters in another way later + train_command = "python " + program_entry + "--tensor-model-parallel-size 1 --disable-bias-linear --use-flash-attn --sequence-parallel --use-distributed-optimizer --use-mcore-models --transformer-impl transformer_engine --hetero-device-types A800 BI150 --hetero-current-device-type A800 --recompute-granularity full --recompute-method uniform --recompute-num-layers 1 --bf16 --attention-softmax-in-fp32 --accumulate-allreduce-grads-in-fp32 --log-interval 1 --log-throughput --tensorboard-log-interval 1 --wandb-project aquila2 --wandb-exp-name test --tensorboard-dir /share/project/heyongzhe/FlagScale/outputs/tensorboard --wandb-save-dir /share/project/heyongzhe/FlagScale/outputs/wandb --num-layers 32 --hidden-size 4096 --num-attention-heads 32 --seq-length 2048 --max-position-embeddings 2048 --norm-epsilon 1e-05 --use-rotary-position-embeddings --no-position-embedding --swiglu --multiple-of 256 --normalization RMSNorm --rotary-interleaved-patch --untie-embeddings-and-output-weights --init-method-std 0.0165 --attention-dropout 0.0 --hidden-dropout 0.0 --weight-decay 0.1 --clip-grad 1.0 --train-samples 128 --global-batch-size 64 --micro-batch-size 1 --seed 42 --lr 0.0002 --weight-decay 0.01 --adam-beta1 0.9 --adam-beta2 0.95 --lr 0.00015 --min-lr 1.5e-05 --lr-warmup-samples 0 --lr-decay-style cosine --data-path /share/project/caozhou/adaptive_flash_ckpt/FlagScale/data/pile_wikipedia_demo --split 1 --tokenizer-type AquilaTokenizerFS --vocab-file ./examples/aquila/tokenizer/vocab.json --merge-file ./examples/aquila/tokenizer/merges.txt --special-tokens-file ./examples/aquila/tokenizer/special_tokens.txt --vocab-size 100008 " + process_mesh_str + simulation_arguments + 
pp_layer_split_args + fine_grained_recomputation_args + pp_size_args + + # enough sleeping time is needed to really kill the survival megatron process + # as least 5 sec before & after killing can not succeed every time + print("sleeping...") + # print(train_command) + # time.sleep(10) + kill_other_python_processes() + # time.sleep(10) + print("start...") + + result = subprocess.run(train_command, capture_output=True, text=True, shell=True) + output = result.stdout.strip() + print(train_command) + print(output) + # example output: "[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12" + match = re.search(r"forward:\s*([\d.]+),\s*backward:\s*([\d.]+),\s*communication:\s*([\d.]+)", output) + + if match: + fwd_time = float(match.group(1)) + bwd_time = float(match.group(2)) + comm_time = float(match.group(3)) + print("forward:", fwd_time) + print("backward:", bwd_time) + print("communication:", comm_time) + else: + raise(ValueError("Results not found. Example output: \"[simulatior output] forward: 12.34, backward: 56.78, communication: 90.12\"")) + return fwd_time, bwd_time, comm_time + + +# call simulator to obtain the execution of each stage +def simulate_pipeline_parallelism_per_stage_time( + process_mesh: list=None, + pp_layers_split: list=None, + fwd_time_per_stage_chunk: list=None, + bwd_time_per_stage_chunk: list=None, + comm_time_between_stages: list=None, +): + print(f"--- Simulation Begin ---") + print(f"Process Mesh: {process_mesh}") + print(f"PP Layer Split: {pp_layers_split}") + for stage, num_layers in enumerate(pp_layers_split): + # TODO: confirm simulated_rank for different stage + print(f"Stage: {stage}; Num Layers: {num_layers}") + simulated_rank = 0 + fwd_time, bwd_time, comm_time = simulator(process_mesh, stage, num_layers, simulated_rank, pp_layers_split) + fwd_time_per_stage_chunk.append(fwd_time) + bwd_time_per_stage_chunk.append(bwd_time) + comm_time_between_stages.append(comm_time) + print(f"--- Simulation End ---") + + + +def analyze_pp_time( + scheme: str='1F1B', + num_micro_batches: int=1, + process_mesh: list=None, + pp_layers_split: list=None + ): + fwd_time_per_stage_chunk = [] + bwd_time_per_stage_chunk = [] + comm_time_between_stages = [] + + simulate_pipeline_parallelism_per_stage_time( + process_mesh=process_mesh, + pp_layers_split=pp_layers_split, + fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, + bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, + comm_time_between_stages=comm_time_between_stages + ) + + pipeline_cost = compute_pipeline_parallelism_cost( + scheme=scheme, + num_micro_batches=num_micro_batches, + process_mesh=process_mesh, + pp_layers_split=pp_layers_split, + fwd_time_per_stage_chunk=fwd_time_per_stage_chunk, + bwd_time_per_stage_chunk=bwd_time_per_stage_chunk, + comm_time_between_stages=comm_time_between_stages + ) + + return pipeline_cost diff --git a/flagscale/runner/auto_tuner/simulator/config_gen.py b/flagscale/runner/auto_tuner/simulator/config_gen.py new file mode 100644 index 000000000..ab8cd9144 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/config_gen.py @@ -0,0 +1,348 @@ +from itertools import product +from itertools import combinations + +import json +import ast + +import os +# from itertools import product +import flagscale.train.theoretical_memory_usage as mem_usg +import analylize_pipeline_time + +from functools import reduce + +BYTES_OF_GB = 10**9 + +# device_type_list = ["A800", "A800", "BI150", "BI150"] +# device_num_list = [8, 8, 8, 8] +# memory_capacity_of_devices = [80, 80, 32, 32] # 
GB
+
+device_type_list = ["A800", "BI150"]
+device_num_list = [4, 4]
+memory_capacity_of_devices = [80, 32]  # GB
+
+global_batch_size = 512
+num_micro_batches = 8
+num_layers = 32
+
+num_gpus = sum(device_num_list)
+
+
+class DevicesInfo:
+    def __init__(self, device_type_list: list, device_num_list: list):
+        assert len(device_type_list) == len(device_num_list), \
+            f"length of device_type_list should match device_num_list ({len(device_num_list)})"
+        self.device_type_list = device_type_list
+        self.device_num_list = device_num_list
+        self.device_types_count = len(device_type_list)
+        self.possible_parallelisms = []
+
+class HeteroConfig:
+    def __init__(self,
+                 mesh,
+                 device_types,
+                 pp_layer_split,
+                 recompute_granularity = None,
+                 recompute_method = "uniform",
+                 recompute_num_layers = 1,
+                 theory_peak_memory = 0.0,
+                 oom_error=False):
+        self.mesh = mesh
+        self.device_types = device_types
+        self.pp_layer_split = pp_layer_split
+        # self.micro_batch_size = 1
+        self.recompute_granularity = recompute_granularity
+        self.recompute_method = recompute_method
+        self.recompute_num_layers = recompute_num_layers
+
+        self.simulated_time = 0.0
+        self.theory_peak_memory = theory_peak_memory
+        self.oom_error = oom_error
+
+def generate_hetero_meshes(
+    devices_info: DevicesInfo,
+    global_batch_size: int = None,
+    num_layers: int = None,
+    output_file: str = "results.json"
+):
+    def enumerate_parallelism(device_num: int = None):
+        possible_parallelisms = []
+        for tp in range(1, device_num + 1):
+            for dp in range(1, device_num // tp + 1):
+                if device_num % (dp * tp) == 0:
+                    pp = device_num // (dp * tp)
+                    # mesh: [tp, cp, ep, dp, pp]
+                    possible_parallelisms.append([tp, 1, 1, dp, pp])
+        return possible_parallelisms
+
+    def is_legal_combination(comb: list):
+        pp = sum(comb[4::5])
+        # check that every dp size is legal
+        max_dp = global_batch_size // pp
+        for dp in comb[3::5]:
+            if max_dp % dp != 0:
+                return False
+        return True
+
+    def is_extreme_strategy(comb: list):
+        for mesh_index in range(len(comb)//5):
+            # num_devices_in_mesh = sum(
+            #     comb[
+            #         mesh_index * 5 : mesh_index * 5 + 4
+            #     ]
+            # )
+            num_devices_in_mesh = reduce(lambda x, y: x * y, comb[mesh_index * 5 : mesh_index * 5 + 5])
+            dp_size_in_mesh = comb[
+                mesh_index * 5 + 3
+            ]
+            tp_size_in_mesh = comb[
+                mesh_index * 5 + 0
+            ]
+            pp_size_in_mesh = comb[
+                mesh_index * 5 + 4
+            ]
+            print(mesh_index, comb[mesh_index * 5 : mesh_index * 5 + 5], num_devices_in_mesh, dp_size_in_mesh, tp_size_in_mesh, pp_size_in_mesh)
+            if pp_size_in_mesh > num_devices_in_mesh // 2 or tp_size_in_mesh > 8 or dp_size_in_mesh > num_devices_in_mesh / 4:
+                return True
+        # only a strategy that is extreme in none of its meshes is kept
+        return False
+
+    def combine_possible_parallelisms(possible_parallelisms, output_file):
+        ''' Combine and filter results, writing them to a file to avoid OOM. 
''' + all_combinations = product(*possible_parallelisms) + with open(output_file, "w") as f: + for comb in all_combinations: + result = sum(comb, []) + if is_legal_combination(result): + if not is_extreme_strategy(result): + f.write(",".join(map(str, result)) + "\n") + + # Ensure output file does not exist initially + if os.path.exists(output_file): + os.remove(output_file) + + # Enumerate all possible meshes for each kind of device + for i in range(devices_info.device_types_count): + device_num = devices_info.device_num_list[i] + devices_info.possible_parallelisms.append(enumerate_parallelism(device_num)) + + # Combine possibilities and write results to file + combine_possible_parallelisms(devices_info.possible_parallelisms, output_file) + print(f"Results written to {output_file}") + + +def split_layers(num_layers, pp_stages): + results = [] + # print(pp_stages) + for split_points in combinations(range(1, num_layers), pp_stages - 1): + # print(split_points) + if len(split_points) == 0: + continue + splits = [split_points[0]] + [split_points[i] - split_points[i - 1] for i in range(1, len(split_points))] + [num_layers - split_points[-1]] + # to prune some extreme splits + if max(splits) / min(splits) > 2: + continue + # print(splits) + results.append(splits) + return results + + +class MeshArguments: + def __init__(self, + mesh_config: HeteroConfig): + # [tp, cp, ep, dp, pp] + self.data_parallel_size = mesh_config.mesh[3] + # TODO: pp size not correct when computing memory, because former method divides the layers evenly + # no embed and dropout for stages except the 1st, and make the layers changable + + # if args.pipeline_model_parallel_size > 1: + # activation_memory = ( + # perlayer_activation + # * args.num_layers + # / args.pipeline_model_parallel_size + # * in_flight_microbatches + # + embedding_activation_memory + # + dropout_activation_memory + # ) + # else: + # activation_memory = ( + # perlayer_activation * args.num_layers + # + embedding_activation_memory + # + dropout_activation_memory + # + output_layer_and_loss_activation_memory + # ) + self.pipeline_model_parallel_size = sum(mesh_config.mesh[4::5]) + self.tensor_model_parallel_size = mesh_config.mesh[0] + self.virtual_pipeline_model_parallel_size = None + self.num_experts = 1 + + self.swiglu = True + self.micro_batch_size = global_batch_size / num_micro_batches / self.data_parallel_size + self.num_layers = num_layers + self.num_attention_heads = 32 + self.group_query_attention = None # TODO + self.num_query_groups = 1 # TODO + + self.seq_length = 2048 + self.padded_vocab_size = 4096 # TODO + self.hidden_size = 4096 + # self.ffn_hidden_size + self.multiple_of = 256 + hidden_dim = int(4 * self.hidden_size * 2 / 3) + self.ffn_hidden_size = self.multiple_of * ( + (hidden_dim + self.multiple_of - 1) // self.multiple_of + ) + # self.kv_channels + self.kv_channels = self.hidden_size // self.num_attention_heads + + self.recompute_granularity = mesh_config.recompute_granularity + self.recompute_method = mesh_config.recompute_method + self.recompute_num_layers = mesh_config.recompute_num_layers + + self.use_flash_attn = True + self.sequence_parallel = True + self.use_distributed_optimizer =True + self.untie_embeddings_and_output_weights = False # TODO + + self.enable_hetero = True + + + +def report_oom_error( + memory_capacity_of_devices: list, + meshes_config: list, + peak_memory_usage_per_stage: list +): + stage_index = 0 + for mesh_index, num_stages_in_current_mesh in enumerate(meshes_config[4::5]): + for i in 
range(num_stages_in_current_mesh):
+            if peak_memory_usage_per_stage[stage_index+i] >= memory_capacity_of_devices[mesh_index]:
+                return True
+        stage_index = stage_index + num_stages_in_current_mesh
+    return False
+
+def calculate_peak_memory_per_stage(mesh_config):
+    peak_memory_usage_per_stage = []
+    model_parallel_training_args = MeshArguments(mesh_config)
+    stage_index = 0
+    mesh_index = 0
+    for pp_stage_num_per_mesh in mesh_config.mesh[4::5]:
+        model_parallel_training_args.data_parallel_size = mesh_config.mesh[3 + 5 * mesh_index]
+        model_parallel_training_args.tensor_model_parallel_size = mesh_config.mesh[0 + 5 * mesh_index]
+        for stage in range(pp_stage_num_per_mesh):
+            model_parallel_training_args.num_layers = mesh_config.pp_layer_split[stage_index]
+
+            peak_activation_memory_usage = mem_usg.compute_activation_memory(args=model_parallel_training_args, num_microbatches=num_micro_batches)
+            peak_weight_optimizer_usage = mem_usg.compute_weight_and_optimizer_memory(args=model_parallel_training_args)
+            peak_memory_usage = peak_activation_memory_usage + peak_weight_optimizer_usage
+
+            peak_memory_usage_per_stage.append(peak_memory_usage/BYTES_OF_GB)
+            stage_index = stage_index + 1
+
+        mesh_index = mesh_index + 1
+
+    return peak_memory_usage_per_stage
+
+
+def gen_hetero_configs(
+    device_type_list,
+    device_num_list,
+    global_batch_size,
+    num_layers,
+    # num_micro_batches,
+    # hetero_configs: list,
+    output_config_file: str = "hetero_configs.json"  # output file for the generated hetero_config entries
+):
+    devices_info = DevicesInfo(device_type_list=device_type_list, device_num_list=device_num_list)
+
+    # call generate_hetero_meshes to generate the meshes and write them to a file
+    generate_hetero_meshes(
+        devices_info=devices_info,
+        global_batch_size=global_batch_size,
+        num_layers=num_layers,
+        output_file="results.json"  # intermediate file holding the hetero_meshes
+    )
+
+    # read the hetero_meshes back from results.json
+    hetero_meshes = []
+    with open("results.json", "r") as f:
+        for line in f:
+            hetero_meshes.append(list(map(int, line.strip().split(","))))
+    # print(hetero_meshes)
+    # assert False
+    # iterate over the hetero_meshes and build a hetero_config for each layer split
+    with open(output_config_file, "w") as config_file:  # open the output file
+        for mesh in hetero_meshes:
+            pp_stages = sum(mesh[4::5])
+            # restrict the number of layers in each stage to an even number
+            pp_layer_splits = split_layers(num_layers=num_layers//2, pp_stages=pp_stages)
+            for split in pp_layer_splits:
+                split = [x * 2 for x in split]
+                hetero_config = HeteroConfig(mesh=mesh,
+                                             pp_layer_split=split,
+                                             device_types=device_type_list)
+                # hetero_configs.append(hetero_config)
+
+                # save every member of the HeteroConfig to the output file
+                theory_peak_memory_per_stage = calculate_peak_memory_per_stage(hetero_config)
+                oom_error = report_oom_error(
+                    memory_capacity_of_devices=memory_capacity_of_devices,
+                    meshes_config=mesh,
+                    peak_memory_usage_per_stage=theory_peak_memory_per_stage)
+                # if oom_error:
+                #     continue
+                config_data = {
+                    "mesh": hetero_config.mesh,
+                    "device_types": hetero_config.device_types,
+                    "pp_layer_split": hetero_config.pp_layer_split,
+                    "recompute_granularity": hetero_config.recompute_granularity,
+                    "recompute_method": hetero_config.recompute_method,
+                    "recompute_num_layers": hetero_config.recompute_num_layers,
+                    "simulated_time": hetero_config.simulated_time,
+                    "theory_peak_memory": theory_peak_memory_per_stage,
+                    "oom_error": oom_error
+                }
+                config_file.write(json.dumps(config_data) + "\n")  # JSON so read_configs_from_json can parse it
+
+    print(f"Hetero configurations saved to {output_config_file}")
+
+import json
+
+def read_configs_from_json(file_path: str):
+    configs_list = []
+    with open(file_path, "r") as file:
+        for 
line in file:
+            config_data = json.loads(line.strip())
+            configs_list.append(config_data)
+    return configs_list
+
+
+# for test and usage
+if __name__ == "__main__":
+    # hetero_configs = []
+
+    # generate all possible and legal mesh configs, each element of hetero_configs is a mesh list
+    # gen_hetero_configs(
+    #     device_type_list=device_type_list,
+    #     device_num_list=device_num_list,
+    #     global_batch_size=global_batch_size,
+    #     num_layers=num_layers,
+    #     output_config_file = "hetero_configs.json"
+    #     # num_micro_batches=num_micro_batches,
+    #     # hetero_configs=hetero_configs
+    # )
+
+    # assert False
+    # simulation
+    file_path = "hetero_configs.json"
+    hetero_configs = read_configs_from_json(file_path)
+
+    for hetero_config in hetero_configs:
+        print(hetero_config)
+        pp_cost = hetero_config["simulated_time"] = analylize_pipeline_time.analyze_pp_time(
+            scheme="1F1B",
+            num_micro_batches=num_micro_batches,
+            process_mesh=hetero_config['mesh'],
+            pp_layers_split=hetero_config['pp_layer_split']
+        )
+        print(f"pipeline cost: {pp_cost}")
\ No newline at end of file
diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/README.md b/flagscale/runner/auto_tuner/simulator/custom_backend/README.md
new file mode 100644
index 000000000..a5e926808
--- /dev/null
+++ b/flagscale/runner/auto_tuner/simulator/custom_backend/README.md
@@ -0,0 +1,35 @@
+## Build
+
+```python setup.py develop```
+
+## Usage
+
+```python
+import os
+
+import torch
+import dummy_collectives
+
+import torch.distributed as dist
+
+os.environ['MASTER_ADDR'] = 'localhost'
+os.environ['MASTER_PORT'] = '29500'
+
+dist.init_process_group("cpu:gloo,cuda:dummy", rank=0, world_size=1)
+
+# this goes through gloo
+x = torch.ones(6)
+dist.all_reduce(x)
+print(f"cpu allreduce: {x}")
+
+# this goes through dummy
+if torch.cuda.is_available():
+    y = x.cuda()
+    dist.all_reduce(y)
+    print(f"cuda allreduce: {y}")
+
+    try:
+        dist.broadcast(y, 0)
+    except RuntimeError:
+        print("got RuntimeError when calling broadcast")
+```
\ No newline at end of file
diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp b/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp
new file mode 100644
index 000000000..a71eb8536
--- /dev/null
+++ b/flagscale/runner/auto_tuner/simulator/custom_backend/include/dummy.hpp
@@ -0,0 +1,157 @@
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+using AnyType = std::variant;
+
+
+namespace c10d {
+
+class ProcessGroup; // forward declaration
+class Store;        // forward declaration
+
+class BackendDummy : public Backend {
+  public:
+
+    BackendDummy(int rank, int size);
+
+    const std::string getBackendName() const override;
+    void startCoalescing() override;
+    c10::intrusive_ptr endCoalescing() override;
+
+c10::intrusive_ptr reduce_scatter_tensor_coalesced(
+    std::vector& outputTensors,
+    std::vector& inputTensors,
+    const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
+
+c10::intrusive_ptr allgather_into_tensor_coalesced(
+    std::vector& outputTensors/* outputs */,
+    std::vector& inputTensors/* inputs */,
+    const AllgatherOptions& /* opts */ = AllgatherOptions()) override;
+
+c10::intrusive_ptr _reduce_scatter_base(
+    at::Tensor& outputTensors/* outputBuffer */,
+    at::Tensor& inputTensors/* inputBuffer */,
+    const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) override;
+
+c10::intrusive_ptr broadcast(
+    std::vector &data,
+    const BroadcastOptions &opts = BroadcastOptions()) override;
+
+c10::intrusive_ptr allreduce(
+    std::vector 
&tensors, + const AllreduceOptions &opts = AllreduceOptions()) override; + +c10::intrusive_ptr allreduce_coalesced( + std::vector &tensors, + const AllreduceCoalescedOptions &opts = + AllreduceCoalescedOptions()) override; + +c10::intrusive_ptr reduce( + std::vector &tensors, + const ReduceOptions &opts = ReduceOptions()) override; + +c10::intrusive_ptr all_gather_object( + std::vector &outputTensors, + AnyType &inputTensors, + const AllgatherOptions &opts = AllgatherOptions()); + +c10::intrusive_ptr allgather( + std::vector> &outputTensors, + std::vector &inputTensors, + const AllgatherOptions &opts = AllgatherOptions()) override; + +c10::intrusive_ptr _allgather_base( + at::Tensor &outputBuffer, + at::Tensor &inputBuffer, + const AllgatherOptions &opts = AllgatherOptions()) override; + +c10::intrusive_ptr barrier( + const BarrierOptions &opts = BarrierOptions()) override; + +c10::intrusive_ptr gather( + std::vector> &outputTensors, + std::vector &inputTensors, + const GatherOptions &opts = GatherOptions()) override; + +c10::intrusive_ptr scatter( + std::vector &outputTensors, + std::vector> &inputTensors, + const ScatterOptions &opts = ScatterOptions()) override; + +c10::intrusive_ptr reduce_scatter( + std::vector &outputTensors, + std::vector> &inputTensors, + const ReduceScatterOptions &opts = ReduceScatterOptions()) override; + +c10::intrusive_ptr alltoall_base( + at::Tensor &outputTensor, + at::Tensor &inputTensor, + std::vector &outputSplitSizes, + std::vector &inputSplitSizes, + const AllToAllOptions &opts = AllToAllOptions()) override; + +c10::intrusive_ptr alltoall( + std::vector &outputTensors, + std::vector &inputTensors, + const AllToAllOptions &opts = AllToAllOptions()) override; + +c10::intrusive_ptr send( + std::vector &tensors, + int dstRank, + int tag) override; + +c10::intrusive_ptr recv( + std::vector &tensors, + int srcRank, + int tag) override; + +c10::intrusive_ptr recvAnysource( + std::vector &tensors, + int tag) override; + +static c10::intrusive_ptr createBackendDummy( + const c10::intrusive_ptr<::c10d::Store> &store, + int rank, + int size, + const std::chrono::duration &timeout); + +static void BackendDummyConstructor() __attribute__((constructor)) +{ + py::object module = py::module::import("torch.distributed"); + py::object register_backend = + module.attr("Backend").attr("register_backend"); + register_backend("dummy", py::cpp_function(createBackendDummy)); + } +}; + +class WorkDummy : public Work { + friend class BackendDummy; +public: + WorkDummy( + OpType opType, + c10::intrusive_ptr future) // future of the output + : Work( + -1, // rank, only used by recvAnySource, irrelevant in this demo + opType), + future_(std::move(future)) {} + bool isCompleted() override; + bool isSuccess() const override; + bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override; + virtual c10::intrusive_ptr getFuture() override; + +private: + c10::intrusive_ptr future_; +}; + +} // namespace c10d diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py b/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py new file mode 100644 index 000000000..6017f45b0 --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/setup.py @@ -0,0 +1,27 @@ +import os +import torch +from setuptools import setup +from torch.utils import cpp_extension + +sources = ["src/dummy.cpp"] +include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"] + +if torch.cuda.is_available(): + module = cpp_extension.CUDAExtension( + 
name="dummy_collectives", + sources=sources, + include_dirs=include_dirs, + ) +else: + module = cpp_extension.CppExtension( + name="dummy_collectives", + sources=sources, + include_dirs=include_dirs, + ) + +setup( + name="Dummy-Collectives", + version="0.0.1", + ext_modules=[module], + cmdclass={'build_ext': cpp_extension.BuildExtension} +) diff --git a/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp b/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp new file mode 100644 index 000000000..231ef1b1e --- /dev/null +++ b/flagscale/runner/auto_tuner/simulator/custom_backend/src/dummy.cpp @@ -0,0 +1,285 @@ +#include "dummy.hpp" +#include +// #include +// #include +// #include +// #include + +namespace c10d { + + +bool WorkDummy::isCompleted() { + return true; +} + +bool WorkDummy::isSuccess() const { + return true; +} + +bool WorkDummy::wait(std::chrono::milliseconds /* unused */) { + return true; +} + +c10::intrusive_ptr WorkDummy::getFuture() { + return future_; +} + +// If necessary, pass store/rank/size to the ctor and exchange connection +// information here +BackendDummy::BackendDummy(int rank, int size) + : Backend(rank, size) {} + +const std::string BackendDummy::getBackendName() const{ + return "dummy"; +} + +void BackendDummy::startCoalescing(){ + return; + } + +c10::intrusive_ptr BackendDummy::endCoalescing(){ + at::Tensor outputTensors; + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce_scatter_tensor_coalesced( + std::vector& outputTensors, + std::vector& inputTensors, + const ReduceScatterOptions&) { + // printf("dummy reduce_scatter_tensor_coalesced\n"); + for (auto& outputTensor : outputTensors) { + outputTensor.fill_(1); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::allgather_into_tensor_coalesced( + std::vector& outputTensors/* outputs */, + std::vector& inputTensors/* inputs */, + const AllgatherOptions& ) { + // printf("dummy reduce_scatter_tensor_coalesced\n"); + for (auto& outputTensor : outputTensors) { + outputTensor.fill_(1); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::_reduce_scatter_base( + at::Tensor& outputTensors/* outputBuffer */, + at::Tensor& inputTensors/* inputBuffer */, + const ReduceScatterOptions& ) { + // printf("dummy _reduce_scatter_base\n"); + // for (auto& outputTensor : outputTensors) { + // outputTensor.fill_(1); + // } + outputTensors.fill_(1); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +// This is a dummy allgather that sets all output tensors to zero +// Modify the implementation to conduct real communication asynchronously +c10::intrusive_ptr BackendDummy::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const 
AllgatherOptions& /* unused */) { + // printf("dummy allgather\n"); + for (auto& outputTensorVec : outputTensors) { + for (auto& outputTensor : outputTensorVec) { + outputTensor.fill_(1); + } + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + future->markCompleted(c10::IValue(outputTensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::all_gather_object( + std::vector& outputTensors, + AnyType& inputTensors, + const AllgatherOptions& /* unused */) { + // printf("dummy all_gather_object Begin\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::_allgather_base( + at::Tensor& /* unused */, + at::Tensor& /* unused */, + const AllgatherOptions& /* unused */) { + // printf("dummy _allgather_base\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +// This is a dummy allreduce that sets all output tensors to zero +// Modify the implementation to conduct real communication asynchronously +c10::intrusive_ptr BackendDummy::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + // printf("dummy allreduce\n"); + for (auto& tensor : tensors) { + tensor.zero_(); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + future->markCompleted(c10::IValue(tensors)); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::allreduce_coalesced( + std::vector& /* unused */, + const AllreduceCoalescedOptions& /* unused */) { + // printf("dummy allreduce_coalesced\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::alltoall( + std::vector& /* unused */, + std::vector& /* unused */, + const AllToAllOptions& /* unused */) { + // printf("dummy alltoall\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + // printf("dummy alltoall_base\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::barrier( + const BarrierOptions& /* unused */) { + // printf("dummy barrier\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + // printf("dummy broadcast\n"); + for (auto& tensor : tensors) { + tensor.zero_(); + } + + auto future = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + future->markCompleted(c10::IValue(tensors)); + return 
c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::gather( + std::vector>& /* unused */, + std::vector& /* unused */, + const GatherOptions& /* unused */) { + // printf("dummy gather\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce( + std::vector& /* unused */, + const ReduceOptions& /* unused */) { + // printf("dummy reduce\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::reduce_scatter( + std::vector& /* unused */, + std::vector>& /* unused */, + const ReduceScatterOptions& /* unused */) { + // printf("dummy reduce_scatter\n"); + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::scatter( + std::vector& /* unused */, + std::vector>& /* unused */, + const ScatterOptions& /* unused */) { + // printf("dummy scatter\n"); + + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::send( + std::vector& tensors, + int dstRank, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::recv( + std::vector& tensors, + int srcRank, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::recvAnysource( + std::vector& tensors, + int tag) { + auto future = c10::make_intrusive( + c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + return c10::make_intrusive(OpType::ALLGATHER, std::move(future)); +} + +c10::intrusive_ptr BackendDummy::createBackendDummy( + const c10::intrusive_ptr<::c10d::Store>& /* unused */, + int rank, + int size, + const std::chrono::duration& /* unused */) { + return c10::make_intrusive(rank, size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("createBackendDummy", &BackendDummy::createBackendDummy); +} + +} // namespace c10d diff --git a/flagscale/train/arguments.py b/flagscale/train/arguments.py index 4f35026c0..c3320eccf 100644 --- a/flagscale/train/arguments.py +++ b/flagscale/train/arguments.py @@ -14,6 +14,12 @@ ImportWarning, ) +import os +import threading +import datetime +import multiprocessing +import dummy_collectives + from flagscale.train.hetero.parallel_context import RankMapper @@ -52,26 +58,53 @@ def _initialize_distributed(self): # Manually set the device ids. 
if device_count > 0: torch.cuda.set_device(args.local_rank) - device_id = torch.device(f"cuda:{args.local_rank}") + device_id = torch.device(f'cuda:{args.local_rank}') else: device_id = None - - # Call the init process - init_process_group_kwargs = { - "backend": args.distributed_backend, - "world_size": args.world_size, - "rank": args.rank, - "timeout": timedelta(minutes=args.distributed_timeout_minutes), - } - if args.distributed_backend == "flagcx": - init_process_group_kwargs["backend"] = "cpu:gloo,cuda:flagcx" - # for communication based cpu - if args.enable_hetero and args.hetero_use_cpu_communication: - # if not all(device_type == args.hetero_device_types[0] for device_type in args.hetero_device_types): - # init_process_group_kwargs['backend'] = 'cpu:gloo' - # Force the group of backend gloo only support cpu - init_process_group_kwargs["backend"] = "cpu:gloo" - torch.distributed.init_process_group(**init_process_group_kwargs) + if args.enable_simulator: + # Define a function to initialize and run operations with a virtual rank + def run_virtual_rank(rank, world_size, timeout): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "37832" + init_process_group_kwargs = { + 'backend' : args.distributed_backend, + 'world_size': world_size, + 'rank': rank, + 'timeout': datetime.timedelta(minutes=timeout), + } + torch.distributed.init_process_group(**init_process_group_kwargs) + torch.distributed.barrier() + + # Call the init process with multithreads + args.distributed_timeout_minutes = 1 + threads = [] + # Start a thread for each virtual rank + for rank in range(1, 2): # 2 for skipping launching thousands of threads + # for rank in range(1, args.world_size): + thread = threading.Thread(target=run_virtual_rank, args=(rank, args.world_size, args.distributed_timeout_minutes)) + thread.start() + threads.append(thread) + rank = 0 + gpu_task = multiprocessing.Process(target=run_virtual_rank, args=(rank, args.world_size, args.distributed_timeout_minutes)) + gpu_task.start() + # Wait for all threads to complete + for thread in threads: + thread.join() + else: + # Call the init process + init_process_group_kwargs = { + "backend": args.distributed_backend, + "world_size": args.world_size, + "rank": args.rank, + "timeout": timedelta(minutes=args.distributed_timeout_minutes), + } + # for communication based cpu + if args.enable_hetero and args.hetero_use_cpu_communication: + # if not all(device_type == args.hetero_device_types[0] for device_type in args.hetero_device_types): + # init_process_group_kwargs['backend'] = 'gloo' + # Force the group of backend gloo only support cpu + init_process_group_kwargs['backend'] = 'cpu:gloo' + torch.distributed.init_process_group(**init_process_group_kwargs) def _build_rank_mapper(self): self._initialize_distributed() diff --git a/flagscale/train/hetero/parallel_context.py b/flagscale/train/hetero/parallel_context.py index 415d0f787..b9413dd1c 100644 --- a/flagscale/train/hetero/parallel_context.py +++ b/flagscale/train/hetero/parallel_context.py @@ -82,6 +82,7 @@ def __init__(self, args): self._rank_infos = {} self._physical_rank_to_logical_rank = {} self._logical_rank_to_physical_rank = {} + self._enable_simulator = args.enable_simulator self.build_rank_mapping() def build_rank_mapping(self): @@ -91,8 +92,14 @@ def build_rank_mapping(self): all_rank_infos = [None] * world_size cur_rank_info = {'rank': rank, 'device_type': self._hetero_current_device_type} - torch.distributed.all_gather_object( - all_rank_infos, cur_rank_info) + if 
self._enable_simulator: + for index, value in enumerate(all_rank_infos): + corresponding_rank_info = {'rank': index, 'device_type': self._hetero_current_device_type} + all_rank_infos[index] = corresponding_rank_info + else: + torch.distributed.all_gather_object( + all_rank_infos, cur_rank_info) + physical_ranks = [] for info in all_rank_infos: self._rank_infos[info['rank']] = info @@ -271,7 +278,10 @@ def build_process_group( ranks = self._rank_mapper.to_physical_ranks(logical_ranks) group = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, pg_options=pg_options, group_desc=group_name) if gloo: - group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo") + if self._args.enable_simulator: + group_gloo = create_group(ranks, timeout=self._timeout, backend=self._distributed_backend, group_desc=group_name+"_gloo") + else: + group_gloo = create_group(ranks, timeout=self._timeout, backend="gloo", group_desc=group_name+"_gloo") self._all_group_ranks[group_name].append(ranks) if self._rank in ranks: self._group_ranks[group_name] = ranks @@ -637,9 +647,14 @@ def build_all_process_meshes(self): "rank": rank, "process_mesh_idx": self._current_process_mesh_index, } - torch.distributed.all_gather_object( - all_rank_to_process_mesh, cur_rank_to_process_mesh - ) + if self._args.enable_simulator: + for index, value in enumerate(all_rank_to_process_mesh): + corresponding_mesh_info = {'rank': index, 'process_mesh_idx': self._current_process_mesh_index} + all_rank_to_process_mesh[index] = corresponding_mesh_info + else: + torch.distributed.all_gather_object( + all_rank_to_process_mesh, cur_rank_to_process_mesh + ) for item in all_rank_to_process_mesh: self._rank_to_process_mesh[item["rank"]] = self._process_meshes[ item["process_mesh_idx"] @@ -755,7 +770,10 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False): ) ranks = list(itertools.chain.from_iterable(ranks_list)) self._global_all_group_ranks["mp"].append(ranks) - group = create_group(ranks, timeout=self._timeout, use_local_synchronization=True, group_desc="mp") + if self._args.enable_simulator: + group = create_group(ranks, timeout=self._timeout, use_local_synchronization=True, backend=self._args.distributed_backend, group_desc="mp") + else: + group = create_group(ranks, timeout=self._timeout, use_local_synchronization=True, group_desc="mp") if self._rank in ranks: self._global_group_ranks["mp"] = ranks self._global_process_groups["mp"] = group @@ -791,15 +809,20 @@ def _backtrack(mesh_index, prev_rank, path, token = "pp", is_expert=False): else: embedding_ranks = ranks position_embedding_ranks = ranks - group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=True, group_desc="embd") + if self._args.enable_simulator: + group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=True, backend=self._args.distributed_backend, group_desc="embd") + else: + group = create_group(embedding_ranks, timeout=self._timeout, use_local_synchronization=True, group_desc="embd") if self._rank in embedding_ranks: self._global_process_groups["embd"].append(group) self._global_process_group_to_ranks[group] = embedding_ranks if self._rank in ranks: self._global_group_ranks["embd"].append(embedding_ranks) - - group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=True, group_desc="embd_pos") + if self._args.enable_simulator: + group = create_group(position_embedding_ranks, 
timeout=self._timeout, use_local_synchronization=True, backend=self._args.distributed_backend, group_desc="embd_pos") + else: + group = create_group(position_embedding_ranks, timeout=self._timeout, use_local_synchronization=True, group_desc="embd_pos") if self._rank in position_embedding_ranks: self._global_process_groups["embd_pos"].append(group) self._global_process_group_to_ranks[group] = position_embedding_ranks @@ -1617,6 +1640,8 @@ def _build_ddp_config(args): if hasattr(args, f.name): kwargs[f.name] = getattr(args, f.name) kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 + if args.enable_simulator: + args.check_for_nan_in_loss_and_grad = False kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad kwargs['bucket_size'] = args.ddp_bucket_size kwargs['average_in_collective'] = args.ddp_average_in_collective diff --git a/flagscale/train/train.py b/flagscale/train/train.py index 6d7fb28a1..3a3ea1aef 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -2124,6 +2124,9 @@ def build_train_valid_test_data_loaders( args.do_valid = getattr(args, "do_valid", False) or flags[1].item() args.do_test = getattr(args, "do_test", False) or flags[2].item() + if args.enable_simulator: + args.do_train = 1 + return train_dataloader, valid_dataloader, test_dataloader diff --git a/flagscale/train/train_gpt.py b/flagscale/train/train_gpt.py index df2f30009..2596efda8 100644 --- a/flagscale/train/train_gpt.py +++ b/flagscale/train/train_gpt.py @@ -199,6 +199,8 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): # Check individual rank losses are not NaN prior to DP all-reduce. rerun_state_machine = get_rerun_state_machine() + if args.enable_simulator: + args.check_for_nan_in_loss_and_grad = False if args.check_for_nan_in_loss_and_grad: rerun_state_machine.validate_result( result=loss[0], diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py index e654d8bbc..1afb38376 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -1935,7 +1935,7 @@ def _add_distributed_args(parser): default=False, help='if set, overlap pipeline parallel communication in warmup and flush', dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', - choices=['nccl', 'gloo', 'flagcx'], + choices=['nccl', 'gloo', 'flagcx', 'dummy'], help='Which backend to use for distributed training.') group.add_argument('--distributed-timeout-minutes', type=int, default=10, help='Timeout minutes for torch.distributed.') @@ -2529,6 +2529,8 @@ def _add_auto_tuner_args(parser): group.add_argument('--auto-tune', action='store_true', help='use auto tuner') + group.add_argument('--enable-simulator', action='store_true', + help='Use single process to simulate the distributed training.') return parser
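
For reference, the pieces added above are meant to be driven the way the simulator README describes (`python config_gen.py`). Below is a minimal sketch of that flow, assuming the `dummy_collectives` extension has already been built with `python setup.py develop` and that the module-level settings in `config_gen.py` (device types, device counts, batch sizes, layer count) describe the target cluster; the final "pick the cheapest config" step is added here only for illustration and is not part of the code in this change.

```python
# Sketch of the auto-tuning flow implemented by config_gen.py and
# analylize_pipeline_time.py; assumed to run from the simulator directory
# with FlagScale on PYTHONPATH.
import analylize_pipeline_time
import config_gen

# 1) Enumerate legal heterogeneous meshes and layer splits; one config
#    per line is written to hetero_configs.json.
config_gen.gen_hetero_configs(
    device_type_list=config_gen.device_type_list,
    device_num_list=config_gen.device_num_list,
    global_batch_size=config_gen.global_batch_size,
    num_layers=config_gen.num_layers,
    output_config_file="hetero_configs.json",
)

# 2) Replay each candidate through the single-process simulator
#    (--enable-simulator plus the dummy backend) and keep the cheapest one.
best_cost, best_cfg = None, None
for cfg in config_gen.read_configs_from_json("hetero_configs.json"):
    cost = analylize_pipeline_time.analyze_pp_time(
        scheme="1F1B",
        num_micro_batches=config_gen.num_micro_batches,
        process_mesh=cfg["mesh"],
        pp_layers_split=cfg["pp_layer_split"],
    )
    if best_cost is None or cost < best_cost:
        best_cost, best_cfg = cost, cfg

print(f"best pipeline cost: {best_cost}")
print(f"best mesh / layer split: {best_cfg['mesh']} / {best_cfg['pp_layer_split']}")
```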