Skip to content

Commit

Permalink
optimize a tiny bit
Browse files Browse the repository at this point in the history
  • Loading branch information
A4P7J1N7M05OT committed May 12, 2024
1 parent 511b9f8 commit 95f4c52
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4880,9 +4880,9 @@ def save_sd_model_on_train_end_common(
def schedule_timesteps(max_calls, min_timestep, max_timestep, restarts=0, b_size=(1,)):
if not hasattr(schedule_timesteps, 'current_step'):
schedule_timesteps.current_step = 1
schedule_timesteps.max_calls_split = max_calls // (restarts + 1)

max_calls_split = max_calls // (restarts + 1)
interpolation_factor = schedule_timesteps.current_step / max_calls_split
interpolation_factor = schedule_timesteps.current_step / schedule_timesteps.max_calls_split

# Calculate mode for the triangular distribution
mode = (max_timestep * interpolation_factor) + min_timestep
Expand All @@ -4893,7 +4893,7 @@ def schedule_timesteps(max_calls, min_timestep, max_timestep, restarts=0, b_size
# Generate a random timestep using a triangular distribution
timestep = torch.from_numpy(np.random.triangular(min_timestep, mode, max_timestep, size=b_size)).long()

if schedule_timesteps.current_step >= max_calls_split:
if schedule_timesteps.current_step >= schedule_timesteps.max_calls_split:
schedule_timesteps.current_step = 0
schedule_timesteps.current_step += 1

Expand Down

0 comments on commit 95f4c52

Please sign in to comment.