Qwen2.5-VL-7B eagle3 train #102

Merged

sleepcoo merged 28 commits into sgl-project:main from Lzhang-hub:qwen-vl on Aug 22, 2025
Conversation

Contributor

@Lzhang-hub Lzhang-hub commented Aug 1, 2025

Motivation

This is a draft PR adding training support for the Qwen2.5-VL-7B model.

Modifications

prepare data

  • dataset: FreedomIntelligence/ALLaVA-4V.
  • process data: in addition to input_ids, loss_mask, and attention_mask, each processed sample now carries pixel_values and image_grid_thw.
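As an illustration of the shape of one processed sample (field names follow the PR description; the helper and all values below are purely illustrative, not the actual prepare_data.py code):

```python
# Illustrative sketch of one processed VLM training sample. build_vlm_sample
# and every value here are made up for illustration only.

def build_vlm_sample(text_ids, loss_mask, pixels, grid_thw):
    """Pack one sample: the first three fields existed for text-only
    training; pixel_values and image_grid_thw are the VLM additions."""
    return {
        "input_ids": text_ids,                  # chat-template token ids
        "attention_mask": [1] * len(text_ids),  # no padding in this toy case
        "loss_mask": loss_mask,                 # 1 only on assistant tokens
        "pixel_values": pixels,                 # flattened image patch features
        "image_grid_thw": grid_thw,             # per-image (t, h, w) patch grid
    }

sample = build_vlm_sample(
    text_ids=[101, 102, 103],
    loss_mask=[0, 0, 1],
    pixels=[[0.1, 0.2, 0.3, 0.4]],  # one placeholder patch
    grid_thw=[[1, 2, 2]],           # one image: 1 frame, 2x2 patches
)
print(sorted(sample))
# -> ['attention_mask', 'image_grid_thw', 'input_ids', 'loss_mask', 'pixel_values']
```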

train

  • Add QwenVLOnlineEagle3Model in core/eagle3.py; the main difference is that the draft model's input is not input_ids but input embeds that integrate the image embeds.
  • Add VLM training support in train_eagle3_online.py.
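A toy sketch of the embedding merge described above (plain Python lists instead of tensors; IMAGE_TOKEN_ID and the helper are illustrative stand-ins, not the actual QwenVLOnlineEagle3Model code):

```python
IMAGE_TOKEN_ID = 151655  # illustrative image-placeholder token id

def merge_image_embeds(input_ids, text_embeds, image_embeds):
    """Replace the embedding at each image-placeholder position with the
    next vision embedding; the draft model then consumes the resulting
    inputs_embeds instead of raw input_ids."""
    merged, img_iter = [], iter(image_embeds)
    for tok, emb in zip(input_ids, text_embeds):
        merged.append(next(img_iter) if tok == IMAGE_TOKEN_ID else emb)
    return merged

ids = [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]
text = [[0.0], [0.0], [0.0], [0.0]]     # toy per-token text embeddings
vision = [[1.0], [2.0]]                 # toy vision-encoder outputs
print(merge_image_embeds(ids, text, vision))
# -> [[0.0], [1.0], [2.0], [0.0]]
```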

acc

benchmark: (figure omitted)

loss metrics: (figure omitted)

acc metrics: (figure omitted)

speedup

server: sglang with qwen2.5-vl eagle3 inference
benchmark scripts: the mmstar benchmark

Note: the draft model Rayzl/qwen2.5-vl-7b-eagle3-sgl was trained on only 30k VQA samples; training with more data is still in progress.

  • with eagle

server cmd:

python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --speculative-draft Rayzl/qwen2.5-vl-7b-eagle3-sgl --trust-remote-code --chat-template qwen2-vl --chunked-prefill-size -1 --cuda-graph-max-bs 1 --speculative-algo EAGLE3 --speculative-num-steps 4 --speculative-eagle-topk 6 --speculative-num-draft-tokens 24 --tp 1 --mem-fraction-static 0.7 --host 0.0.0.0 --port 8080

benchmark:
python run_mmstar.py --host http://0.0.0.0 --port 8080 --parallel 1 --num-questions 100

result:

Latency: 34.241 s
Output throughput: 181.069 token/s
Accept length: 3.219
  • without eagle

server cmd:

python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --trust-remote-code --chat-template qwen2-vl --chunked-prefill-size -1 --cuda-graph-max-bs 1 --tp 1 --mem-fraction-static 0.7 --host 0.0.0.0 --port 8080

benchmark:
python run_mmstar.py --host http://0.0.0.0 --port 8080 --parallel 1 --num-questions 100

result:

Latency: 54.813 s
Output throughput: 121.230 token/s
Accept length: 1.000

e2e speedup: ~1.5x
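As a sanity check, the ~1.5x figure can be recomputed from the two benchmark runs above (latency ratio ≈ 1.60, throughput ratio ≈ 1.49):

```python
# Speedup ratios computed from the reported benchmark numbers above.
lat_eagle, lat_base = 34.241, 54.813        # seconds
tput_eagle, tput_base = 181.069, 121.230    # tokens/s

print(round(lat_base / lat_eagle, 2))    # latency ratio  -> 1.6
print(round(tput_eagle / tput_base, 2))  # throughput ratio -> 1.49
```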

Train scripts

## env dep
pip install qwen_vl_utils

## prepare data
python scripts/prepare_data.py --dataset allava4v --sample-size 100000 --split-eval

## train
bash examples/run_qwen2_5_vl_eagle3_online.sh

Note:

  • the current PR only supports tp=1

TODO

  • support tp>1
  • offline training

Checklist

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @Lzhang-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented initial support for training the Qwen2.5-VL-7B model using the Eagle3 method. This involves significant changes across data preparation, model definition, and training scripts to accommodate the unique requirements of vision-language models, such as handling image inputs and specialized rotary embeddings. The goal is to enable efficient training of this multimodal model within our existing framework.

Highlights

  • Qwen2.5-VL-7B Model Support: I've added comprehensive support for training the Qwen2.5-VL-7B model, enabling it to leverage its multimodal capabilities within our framework.
  • Eagle3 Draft Model Integration for VLMs: I've integrated the Eagle3 draft model specifically for Qwen2.5-VL, including a new Qwen2_5_VLForCausalLMEagle3 model and QwenVLOnlineEagle3Model for online training, which handles the unique multimodal rotary embedding of Qwen2.5-VL.
  • Enhanced VLM Data Preparation: I've updated the data preparation scripts to support sharegpt4v and allava4v datasets, ensuring that pixel_values and image_grid_thw are correctly processed for VLM training.
  • VLM Training Script Adaptations: I've modified the training script to incorporate VLM-specific logic, such as loading AutoProcessor and passing image-related inputs (pixel_values, image_grid_thw) through the training loop.
  • New Configuration and Training Script: I've added a new configuration file and a dedicated shell script to streamline the setup and execution of Qwen2.5-VL-7B Eagle3 training runs.

@Lzhang-hub Lzhang-hub changed the title Qwen2.5-VL-7B egale3 draft pr [Draft] Qwen2.5-VL-7B egale3 draft pr Aug 1, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for training the Qwen2.5-VL-7B model, which is a significant and complex addition. The changes are extensive, touching data preparation, model definition, training scripts, and core components. While the overall approach seems sound, I've identified several critical issues that need to be addressed. These include a missing dependency, incorrect handling of position_ids in the multimodal rotary position embedding logic which will likely cause runtime errors and incorrect model behavior, and bugs in the training and testing scripts. I've also included some suggestions for refactoring to improve code maintainability. Addressing these points will be crucial for the stability and correctness of the new VLM training capabilities.

Comment on lines +678 to +681
input_ids = padding(input_ids, left=False)
target = padding(target, left=False)
loss_mask = padding(loss_mask, left=False)

critical

The position_ids are computed once before the TTT loop but are not updated within the loop. As input_ids, target, and loss_mask are shifted in each iteration using padding(..., left=False), position_ids should also be updated similarly to maintain correct positional information for subsequent TTT steps. Without this, the rotary embeddings will be computed with stale position information.

Suggested change
input_ids = padding(input_ids, left=False)
target = padding(target, left=False)
loss_mask = padding(loss_mask, left=False)
input_ids = padding(input_ids, left=False)
target = padding(target, left=False)
loss_mask = padding(loss_mask, left=False)
position_ids = padding(position_ids, left=False)
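To see why the shift matters, here is a toy stand-in for the padding helper (plain Python lists; the real helper operates on tensors and its exact semantics are assumed, not verified against the repo): every tensor that participates in a TTT step, including position_ids, has to be shifted the same way or tokens get paired with stale positions.

```python
# Toy stand-in for the padding helper used in the TTT loop; semantics
# (left=False drops the first element and pads the back) are an assumption.
def padding(seq, left=False, pad=0):
    """Shift a sequence by one step: left=True pads the front,
    left=False drops the first element and pads the back."""
    return [pad] + seq[:-1] if left else seq[1:] + [pad]

input_ids = [11, 12, 13, 14]
position_ids = [0, 1, 2, 3]

# If only input_ids were shifted, token 12 would still be paired with
# position 0; shifting both keeps the step aligned.
print(padding(input_ids, left=False))     # -> [12, 13, 14, 0]
print(padding(position_ids, left=False))  # -> [1, 2, 3, 0]
```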

from tqdm import tqdm
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer,ImageProcessingMixin
from qwen_vl_utils import process_vision_info

critical

This file imports process_vision_info from qwen_vl_utils. However, qwen_vl_utils does not seem to be part of this repository or a listed dependency. This will cause an ImportError at runtime. Please ensure this utility is added to the repository or included as a dependency.


else:
lck = len(cache_hidden[0])
cos, sin = self.rotary_emb(query_states, position_ids+ lck)

critical

The logic for updating position_ids within the TTT loop by adding lck is incorrect for multimodal inputs. position_ids has a shape of (3, batch_size, seq_len) for Qwen-VL, where each of the 3 components corresponds to different modalities (text, image height, image width). Adding a scalar lck will broadcast incorrectly. Only the text-related position IDs (at index 0) should be offset.

A correct update would be to modify only the text-related part of the position IDs. However, a better approach would be to handle the position updates in the QwenVLOnlineEagle3Model.forward loop, which would simplify the logic here.

Suggested change
cos, sin = self.rotary_emb(query_states, position_ids+ lck)
cos, sin = self.rotary_emb(query_states, position_ids)
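A toy sketch of the update the comment describes (plain Python nested lists standing in for the (3, batch_size, seq_len) tensor; the assumption that only component 0 advances follows the review comment and is not verified against the Qwen2.5-VL source):

```python
def offset_text_positions(position_ids, lck):
    """position_ids: nested lists shaped (3, batch, seq), mimicking the
    Qwen-VL mrope layout per the review comment. Only the text component
    (index 0) is shifted by the cached length lck; the image height/width
    components stay fixed."""
    text, h, w = position_ids
    shifted = [[p + lck for p in row] for row in text]
    return [shifted, h, w]

pos = [
    [[0, 1, 2]],  # text positions, batch of 1
    [[0, 0, 0]],  # image-height positions
    [[0, 0, 0]],  # image-width positions
]
print(offset_text_positions(pos, 2))
# -> [[[2, 3, 4]], [[0, 0, 0]], [[0, 0, 0]]]
```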

Comment on lines +405 to +407
eval_logdict[f"train/ploss_{i}"] = plosses[i].item()
for i in range(len(acces)):
eval_logdict[f"train/acc_{i}"] = acces[i]

high

The wandb logging for evaluation metrics seems to be using the wrong keys. The metrics are logged under train/ploss_{i} and train/acc_{i} which is misleading during the evaluation phase. This should be corrected to eval/ploss_{i} and eval/acc_{i} to accurately reflect that these are evaluation metrics.

Suggested change
eval_logdict[f"train/ploss_{i}"] = plosses[i].item()
for i in range(len(acces)):
eval_logdict[f"train/acc_{i}"] = acces[i]
eval_logdict[f"eval/ploss_{i}"] = plosses[i].item()
for i in range(len(acces)):
eval_logdict[f"eval/acc_{i}"] = acces[i]


class Qwen2_5_VLForCausalLMEagle3(Eagle3DraftModel):

config_class = LlamaConfig

high

The config_class is set to LlamaConfig, but for the Qwen2_5_VLForCausalLMEagle3 model, it should be Qwen2_5_VLConfig to ensure the correct configuration is used when the model is instantiated via AutoConfig.

Suggested change
config_class = LlamaConfig
config_class = Qwen2_5_VLConfig

print(f"Loss mask sum: {processed['loss_mask'][0].sum()}")
loss_mask = processed["loss_mask"][0].squeeze(0).tolist()
input_ids = input_ids.squeeze(0)
current_mask = input_ids[0]

high

The current_mask is initialized with input_ids[0], which is the first token ID. It should be initialized with the first value from loss_mask to correctly compare and group tokens for colored printing.

Suggested change
current_mask = input_ids[0]
current_mask = loss_mask[0]

Comment on lines +326 to +339
if args.is_vlm:
plosses, _, acces = eagle3_model(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
pixel_values=data["pixel_values"].cuda(),
image_grid_thw=data["image_grid_thw"].cuda(),
)
else:
plosses, _, acces = eagle3_model(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
)

medium

The model call is duplicated in the training loop for the VLM and non-VLM cases. This can be refactored to reduce code duplication and improve readability by constructing a dictionary of model inputs and then unpacking it for the model call.

            model_inputs = {
                "input_ids": data["input_ids"].cuda(),
                "attention_mask": data["attention_mask"].cuda(),
                "loss_mask": data["loss_mask"].cuda(),
            }
            if args.is_vlm:
                model_inputs["pixel_values"] = data["pixel_values"].cuda()
                model_inputs["image_grid_thw"] = data["image_grid_thw"].cuda()

            plosses, _, acces = eagle3_model(**model_inputs)

Collaborator

sleepcoo commented Aug 1, 2025

Great job!!!!

@Lzhang-hub Lzhang-hub marked this pull request as ready for review August 5, 2025 10:20
@Lzhang-hub Lzhang-hub changed the title [Draft] Qwen2.5-VL-7B egale3 draft pr Qwen2.5-VL-7B egale3 train Aug 5, 2025

LugerW-A commented Aug 6, 2025

The dataset.map step is very slow, and it hangs when num_proc is greater than 1.

@Lzhang-hub
Contributor Author

The dataset.map step is very slow, and it hangs when num_proc is greater than 1.

Can you provide your training script? I processed 30,000 samples (image max size 2k), which takes about 15 minutes.
I am trying to use ImageProcessorFast instead of ImageProcessor.

@KerwinKai KerwinKai mentioned this pull request Aug 7, 2025
6 tasks

LugerW-A commented Aug 7, 2025

@Lzhang-hub I use the default command with my own data.
I found that setting the temperature to 0 produces different inference results than the model without Eagle (using sgl-project/sglang#8801).
Maybe something is wrong with mrope?
Log:
Missing validation function mapping in ROPE_VALIDATION_FUNCTIONS for 'rope_type'='mrope'
[2025-08-07 20:54:10] Warning: User-specified context_length (128000) is greater than the derived context_length (8192). This may lead to incorrect model outputs or CUDA errors.

@ChiikawaSama

@ChiikawaSama Did you solve it ?

No, but the overall acc seems correct.

@sleepcoo sleepcoo merged commit 1618db7 into sgl-project:main Aug 22, 2025
shimizust pushed a commit to shimizust/SpecForge that referenced this pull request Aug 22, 2025
* support qwen2_5_vl online

* delete nohup

* add qwen2.5-vl eagle model

* add todo

* clean dev code

* support batch and fix position_ids bug

* add eval wandb metrics

* fix eval bug

* fix  eval dataloader bug

* add comment

* merge main

* rename vlm online eagle3 model name

* clean code

* fix ttt input embeds bug

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>

* fix eval metrics bug

* merge qwen-vl draft model to llama3

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>

* fix qwen vl train shell

* add timeout config

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>

* qwenvl draft input without image embedding

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>

* qwenvl draft input without image embedding
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>

* Revert "qwenvl draft input without image embedding"

This reverts commit 1e8eab8.

* fix gitignore

* fix wandb error

* fix lint

---------

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Contributor

mmdbhs commented Aug 25, 2025

Does it work for qwen2.5-vl-3B?

@Lzhang-hub
Contributor Author

Does it work for qwen2.5-vl-3B?

@mmdbhs I haven’t tried it, but theoretically it should work. You can give it a try, and if you encounter any problems, feel free to provide feedback at any time.


oswen commented Sep 9, 2025

Great job, does it work for qwen2.5-vl-32B?

@Lzhang-hub
Contributor Author

Great job, does it work for qwen2.5-vl-32B

It may need TP, which is not supported yet.


oswen commented Sep 12, 2025

When I use the model and startup script you provided to launch the sglang inference instance, I get this error (screenshot omitted). Is it due to my transformers or sglang version?
Here is my shell:

CUDA_VISIBLE_DEVICES=0 python \
  -m sglang.launch_server \
  --model-path /ckpt/qwen2.5-vl-ckpts/Qwen2.5-VL-7B-Instruct \
  --speculative-draft /ckpt/qwen2.5-vl-ckpts/Qwen2.5-vl-7b-eagle3-sgl \
  --trust-remote-code --chat-template qwen2-vl \
  --chunked-prefill-size -1 \
  --cuda-graph-max-bs 1 \
  --speculative-algo EAGLE3 \
  --speculative-num-steps 4 \
  --speculative-eagle-topk 6 \
  --speculative-num-draft-tokens 24 \
  --tp 1 \
  --mem-fraction-static 0.7 \
  --host 0.0.0.0 \
  --port 7891

My transformers and sglang version:
sglang 0.5.1
transformers 4.55.2

@Lzhang-hub
Contributor Author

@oswen You need to install sglang from source; v0.5.1 does not support qwen-vl eagle3 inference.


oswen commented Sep 12, 2025

@oswen You need install sglang from source. v0.5.1 not support qwen-vl eagle3 infer.

Thanks for replying. Which branch of sglang should I choose? master?


icicle4 commented Oct 23, 2025

@oswen You need install sglang from source. v0.5.1 not support qwen-vl eagle3 infer.

thanks for replying,so which branch of sglang should I chose?the master?

pip uninstall sglang

pip install sglang==0.5.3


icicle4 commented Oct 28, 2025

After following your procedure above, the ACC of my draft model reached around 0.6. Similarly, I deployed it using SGLang, but when I tested the Accept Length, I found that with the model you provided, the Accept Length could exceed 3.0, whereas mine only reached about 2.2. Do you have any idea what might be causing this difference? @Lzhang-hub

Train Script

prepare data

python scripts/prepare_data.py --dataset allava4v --sample-size 100000 --split-eval

train

bash examples/run_qwen2_5_vl_eagle3_online.sh

Acc result

(accuracy figure omitted)

Test Script

Test your model

python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --speculative-draft-model Rayzl/qwen2.5-vl-7b-eagle3-sgl --trust-remote-code --chat-template qwen2-vl --chunked-prefill-size -1 --cuda-graph-max-bs 1 --speculative-algo EAGLE3 --speculative-num-steps 4 --speculative-eagle-topk 6 --speculative-num-draft-tokens 24 --tp 1 --mem-fraction-static 0.7 --host localhost --port 9001

python run_mmstar.py --host http://localhost --port 9001 --parallel 1 --num-questions 100

Result:
Latency: 45.612 s
Output throughput: 141.148 token/s
Accept length: 3.186

Test My trained model

python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --speculative-draft-model ${my_local_dir}/epoch_9 --trust-remote-code --chat-template qwen2-vl --chunked-prefill-size -1 --cuda-graph-max-bs 1 --speculative-algo EAGLE3 --speculative-num-steps 4 --speculative-eagle-topk 6 --speculative-num-draft-tokens 24 --tp 1 --mem-fraction-static 0.7 --host localhost --port 9001

python run_mmstar.py --host http://localhost --port 9001 --parallel 1 --num-questions 100

Latency: 60.094 s
Output throughput: 106.566 token/s
Accept length: 2.267


icicle4 commented Oct 30, 2025

(quoted previous comment omitted)

Resolved. It was caused by incorrect data preprocessing.


cranehuang commented Nov 10, 2025

The dataset.map step is very slow, and it hangs when num_proc is greater than 1.

I have the same issue. Have you solved it? Thanks @LugerW-A


icicle4 commented Nov 10, 2025

The dataset.map step is very slow, and it hangs when num_proc is greater than 1.

I have the same issue,have you solved it? thanks

No idea. Just accept it, use num_proc=0.

@cranehuang
Copy link
Copy Markdown

(quoted previous exchange omitted)

Okay, thanks.

@Abigbigbig

May I ask whether this model can be trained on a 48GB A6000 GPU? I hit a resource-limit-exceeded error during kernel compilation on this GPU.

@330205812

@LugerW-A @icicle4
Use this code to avoid the hang while loading data in the main function:

def main():
    # ================================================
    # 1. Initialize
    # ================================================
    parser, args = parse_args()
    set_seed(args.seed)
    init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
    sanity_check(args)
    print_with_rank("Initialized distributed environment")

    # # ================================================
    # # 2. Build models
    # # ================================================
    # draft_model_config, draft_model = build_draft_model(args)
    # target_model, processor = build_target_model(args, draft_model_config)

    # ================================================
    # 2. Pre-load Processor (CPU only) & Config
    # ================================================
    if args.draft_model_config is None:
        auto_config_path = create_draft_config_from_target(
            target_model_path=args.target_model_path, cache_dir=args.cache_dir
        )
        draft_model_config = AutoDraftModelConfig.from_file(auto_config_path)
    else:
        draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)

    processor = None
    if args.is_vlm:
        processor = AutoProcessor.from_pretrained(
            args.target_model_path,
            min_pixels=args.min_pixels,
            max_pixels=args.max_pixels,
        )

    # ================================================
    # 3. Build dataloader
    # ================================================

    train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders(
        args, draft_model_config, processor
    )
    
    _, draft_model = build_draft_model(args) 
    
    target_model, _ = build_target_model(args, draft_model_config)
    
    # we load the vocab mapping then
    draft_model.load_vocab_mapping(vocab_mapping_path)
    print_with_rank("Loaded vocab mapping")

    # Calculate total steps if not provided
    if args.total_steps is None:
        steps_per_epoch = math.ceil(
            len(train_dataloader) / args.draft_accumulation_steps
        )
        args.total_steps = args.num_epochs * steps_per_epoch
        print_with_rank(
            f"Auto-calculated total_steps: {args.total_steps} (num_epochs={args.num_epochs} * steps_per_epoch={steps_per_epoch})"
        )
    else:
        print_with_rank(f"Using provided total_steps: {args.total_steps}")
    
    ...other
