Skip to content

Commit 135fa49

Browse files
authored
Small speed improvements to --async-offload (#10593)
* ops: don't take an offload stream if you don't need one * ops: prioritize mem transfer The async offload stream's reason for existence is to transfer from RAM to GPU. The post-processing compute steps are a bonus on the side stream, but if the compute stream is running a long kernel, it can stall the side stream, as it waits to type-cast the bias before transferring the weight. So do a pure xfer of the weight straight up, then do everything bias-related, then go back to fix the weight type and apply the weight patches.
1 parent 44869ff commit 135fa49

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

comfy/ops.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
8484
if device is None:
8585
device = input.device
8686

87-
if offloadable:
87+
if offloadable and (device != s.weight.device or
88+
(s.bias is not None and device != s.bias.device)):
8889
offload_stream = comfy.model_management.get_offload_stream(device)
8990
else:
9091
offload_stream = None
@@ -94,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
9495
else:
9596
wf_context = contextlib.nullcontext()
9697

97-
bias = None
9898
non_blocking = comfy.model_management.device_supports_non_blocking(device)
99+
100+
weight_has_function = len(s.weight_function) > 0
101+
bias_has_function = len(s.bias_function) > 0
102+
103+
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
104+
105+
bias = None
99106
if s.bias is not None:
100-
has_function = len(s.bias_function) > 0
101-
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
107+
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
102108

103-
if has_function:
109+
if bias_has_function:
104110
with wf_context:
105111
for f in s.bias_function:
106112
bias = f(bias)
107113

108-
has_function = len(s.weight_function) > 0
109-
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
110-
if has_function:
114+
weight = weight.to(dtype=dtype)
115+
if weight_has_function:
111116
with wf_context:
112117
for f in s.weight_function:
113118
weight = f(weight)

0 commit comments

Comments
 (0)