Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/stable-diffusion/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
opencv-python
+ imagesize
3 changes: 2 additions & 1 deletion examples/stable-diffusion/training/run_1x.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ python train_text_to_image_sdxl.py \
--validation_prompt="a robotic cat with wings" \
--validation_epochs 48 \
--checkpointing_steps 2500 \
- --logging_step 10 --discount_chkpoint_saving_in_throughput 2>&1 | tee log_1x_r512.txt
+ --logging_step 10 \
+ --adjust_throughput 2>&1 | tee log_1x_r512.txt
2 changes: 1 addition & 1 deletion examples/stable-diffusion/training/run_8x.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py
--validation_epochs 48 \
--checkpointing_steps 336 \
--mediapipe dataset_sdxl_pokemon \
- --discount_chkpoint_saving_in_throughput 2>&1 | tee log_8x_r512.txt
+ --adjust_throughput 2>&1 | tee log_8x_r512.txt
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,10 @@ def parse_args(input_args=None):
case 3: a non empty path is passed -> images from that location are used ",
)
parser.add_argument(
- "--discount_chkpoint_saving_in_throughput",
+ "--adjust_throughput",
default=False,
action="store_true",
- help="Checkpoitn saving takes a lot of time. Ignore time for checkpoint saving for throughput calculations"
+ help="Checkpoint saving takes a lot of time. Ignore time for checkpoint saving for throughput calculations"
)


Expand Down Expand Up @@ -1340,7 +1340,7 @@ def compute_time_ids(original_size, crops_coords_top_left):

del pipeline

- duration = time.perf_counter() - t0 - (checkpoint_time if args.discount_chkpoint_saving_in_throughput else 0)
+ duration = time.perf_counter() - t0 - (checkpoint_time if args.adjust_throughput else 0)
ttt = time.perf_counter() - t_start
throughput = (args.max_train_steps - args.throughput_warmup_steps) * total_batch_size / duration

Expand Down