Skip to content

Commit

Permalink
Merge pull request #12 from Doggettx/autocast-improvements
Browse files Browse the repository at this point in the history
Performance boost and fix sigmoid for higher resolutions
  • Loading branch information
Doggettx authored Sep 12, 2022
2 parents cd3d653 + 5fe97c6 commit ab0bff6
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
13 changes: 7 additions & 6 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
Expand Down Expand Up @@ -162,7 +162,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
Expand Down Expand Up @@ -190,14 +189,16 @@ def forward(self, x, context=None, mask=None):
mem_free_total = mem_free_cuda + mem_free_torch

gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
mem_required = tensor_size * 2.5
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1


if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
Expand All @@ -209,7 +210,7 @@ def forward(self, x, context=None, mask=None):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

s2 = s1.softmax(dim=-1)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
Expand Down
10 changes: 7 additions & 3 deletions ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim):

def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
t = torch.sigmoid(x)
x *= t
del t

return x


def Normalize(in_channels, num_groups=32):
Expand Down Expand Up @@ -215,7 +219,7 @@ def forward(self, x):
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch

tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1

Expand All @@ -229,7 +233,7 @@ def forward(self, x):
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2)
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2

# attend to values
Expand Down
3 changes: 3 additions & 0 deletions scripts/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def main():
opt = parser.parse_args()
seed_everything(opt.seed)

# needed when model is in half mode, remove if not using half mode
torch.set_default_tensor_type(torch.HalfTensor)

config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
model = model.half()
Expand Down
3 changes: 3 additions & 0 deletions scripts/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ def main():

seed_everything(opt.seed)

# needed when model is in half mode, remove if not using half mode
torch.set_default_tensor_type(torch.HalfTensor)

config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
model = model.half()
Expand Down

0 comments on commit ab0bff6

Please sign in to comment.