Commit 0401ecf
Author: saehoonkim
clean-up unnecessary code blocks and update docs for better readability
1 parent 0b5dbb0 · commit 0401ecf

10 files changed (+61 -50 lines)


assets/improved_sr_arch.jpg (1.1 MB, binary image)

assets/improved_sr_arch.png (-4.13 MB, binary file not shown)

configs/decoder_900M_vit_l.yaml (-2)

@@ -19,10 +19,8 @@ model:
     xf_layers: 0
     xf_heads: 0
     xf_final_ln: false
-    xf_padding: false
     resblock_updown: true
     learn_sigma: true
-    cache_text_emb: false
     text_drop: 0.3
     clip_emb_type: image
     clip_emb_drop: 0.1

configs/prior_1B_vit_l.yaml (-2)

@@ -7,10 +7,8 @@ model:
     xf_layers: 20
     xf_heads: 32
     xf_final_ln: true
-    xf_padding: false
     text_drop: 0.2
     clip_dim: 768
-    clip_xf_width: 768

 diffusion:
   steps: 1000
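
The keys dropped from both config files mirror the constructor clean-up further down (prior_model.py and xf.py): the code no longer reads them, so keeping them around would only leave dead configuration. A minimal sketch of how such a yaml block typically becomes the attribute-style `self._model_conf` seen in prior_model.py, assuming OmegaConf and the `model.hparams` key layout (the repo's actual loader and layout may differ):

from omegaconf import OmegaConf

# Hypothetical illustration: load the prior config and read hparams as attributes,
# mirroring the `self._model_conf.xf_layers` style used in prior_model.py below.
conf = OmegaConf.load("configs/prior_1B_vit_l.yaml")
hparams = conf.model.hparams  # assumed key layout

print(hparams.xf_layers, hparams.xf_heads)  # 20, 32 per the diff above
# `xf_padding` and `clip_xf_width` are intentionally absent: the constructors below
# stopped reading them in this commit.
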

karlo/models/decoder_model.py (+7)

@@ -11,6 +11,13 @@


 class Text2ImProgressiveModel(torch.nn.Module):
+    """
+    A decoder that generates 64x64px images based on the text prompt.
+
+    :param config: yaml config to define the decoder.
+    :param tokenizer: tokenizer used in CLIP.
+    """
+
     def __init__(
         self,
         config,

karlo/models/prior_model.py (+9 -2)

@@ -11,6 +11,15 @@


 class PriorDiffusionModel(torch.nn.Module):
+    """
+    A prior that generates a CLIP image feature based on the text prompt.
+
+    :param config: yaml config to define the prior.
+    :param tokenizer: tokenizer used in CLIP.
+    :param clip_mean: mean to normalize the CLIP image feature (zero-mean, unit variance).
+    :param clip_std: std to normalize the CLIP image feature (zero-mean, unit variance).
+    """
+
     def __init__(self, config, tokenizer, clip_mean, clip_std):
         super().__init__()

@@ -40,9 +49,7 @@ def __init__(self, config, tokenizer, clip_mean, clip_std):
             xf_layers=self._model_conf.xf_layers,
             xf_heads=self._model_conf.xf_heads,
             xf_final_ln=self._model_conf.xf_final_ln,
-            xf_padding=self._model_conf.xf_padding,
             clip_dim=self._model_conf.clip_dim,
-            clip_xf_width=self._model_conf.clip_xf_width,
        )

         cf_token, cf_mask = self.set_cf_text_tensor()
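
For reference, `clip_mean` and `clip_std` exist only to whiten the CLIP image features the prior operates on, as the new docstring states. A minimal sketch of that normalization and its inverse, assuming plain tensor arithmetic (the helper names are illustrative, not the repo's API):

import torch

def normalize_clip_feat(feat: torch.Tensor, clip_mean: torch.Tensor, clip_std: torch.Tensor) -> torch.Tensor:
    # Whiten a CLIP image feature to roughly zero mean / unit variance.
    return (feat - clip_mean) / clip_std

def denormalize_clip_feat(feat: torch.Tensor, clip_mean: torch.Tensor, clip_std: torch.Tensor) -> torch.Tensor:
    # Map a sample from the prior back to the original CLIP feature scale
    # before it is handed to the decoder.
    return feat * clip_std + clip_mean
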

karlo/modules/unet.py (+28 -27)

@@ -421,6 +421,7 @@ class UNetModel(nn.Module):
     :param conv_resample: if True, use learned convolutions for upsampling and
         downsampling.
     :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param clip_dim: dimension of the CLIP feature.
     :param num_classes: if specified (as an int), then this model will be
         class-conditional with `num_classes` classes.
     :param use_checkpoint: use gradient checkpointing to reduce memory usage.

@@ -431,6 +432,8 @@
         of heads for upsampling. Deprecated.
     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
     :param resblock_updown: use residual blocks for up/downsampling.
+    :param encoder_channels: used to make the dimensions of query and key/value match in AttentionBlock.
+    :param use_time_embedding: use time embedding for conditioning.
     """

     def __init__(

@@ -672,6 +675,7 @@ class SuperResUNetModel(UNetModel):
     A UNetModel that performs super-resolution.

     Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    Assumes that the low-resolution image and the input have the same shape.
     """

     def __init__(self, *args, **kwargs):

@@ -686,22 +690,21 @@ def __init__(self, *args, **kwargs):

     def forward(self, x, timesteps, low_res=None, **kwargs):
         _, _, new_height, new_width = x.shape
-        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
-        x = th.cat([x, upsampled], dim=1)
+        assert new_height == low_res.shape[2] and new_width == low_res.shape[3]
+
+        x = th.cat([x, low_res], dim=1)
         return super().forward(x, timesteps, **kwargs)


 class PLMImUNet(UNetModel):
     """
-    A UNetModel that conditions on text with an encoding transformer.
-
-    Expects an extra kwarg `tokens` of text.
+    A UNetModel that conditions on text with a pretrained text encoder in CLIP.

     :param text_ctx: number of text tokens to expect.
     :param xf_width: width of the transformer.
-    :param xf_layers: depth of the transformer.
-    :param xf_heads: heads in the transformer.
-    :param xf_final_ln: use a LayerNorm after the output layer.
+    :param clip_emb_mult: number of extra tokens obtained by projecting the CLIP text feature.
+    :param clip_emb_type: type of condition (here, fixed to the CLIP image feature).
+    :param clip_emb_drop: dropout ratio of the CLIP image feature for classifier-free guidance.
     """

     def __init__(

@@ -725,21 +728,21 @@ def __init__(
         else:
             super().__init__(*args, **kwargs, encoder_channels=xf_width)

-        # Project text encoded feat seq from pre-trained LM
+        # Project text encoded feat seq from pre-trained text encoder in CLIP
         self.text_seq_proj = nn.Sequential(
             nn.Linear(self.clip_dim, xf_width),
             LayerNorm(xf_width),
         )
         # Project CLIP text feat
         self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)

-        if self.clip_emb_mult is not None:
-            assert (
-                self.clip_dim is not None
-            ), "CLIP representation dim should be specified"
-            self.clip_tok_proj = nn.Linear(
-                self.clip_dim, self.xf_width * self.clip_emb_mult
-            )
+        assert clip_emb_mult is not None
+        assert clip_emb_type == "image"
+        assert self.clip_dim is not None, "CLIP representation dim should be specified"
+
+        self.clip_tok_proj = nn.Linear(
+            self.clip_dim, self.xf_width * self.clip_emb_mult
+        )
         if self.clip_emb_drop > 0:
             self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))

@@ -761,21 +764,19 @@ def forward(
         bsz = x.shape[0]
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        if self.clip_dim is not None:
-            emb = emb + self.clip_emb(y)
+        emb = emb + self.clip_emb(y)

         xf_out = self.text_seq_proj(txt_feat_seq)
         xf_out = xf_out.permute(0, 2, 1)
         emb = emb + self.text_feat_proj(txt_feat)
-        if self.clip_emb_mult is not None:
-            xf_out = th.cat(
-                [
-                    self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
-                    xf_out,
-                ],
-                dim=2,
-            )
-            mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
+        xf_out = th.cat(
+            [
+                self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
+                xf_out,
+            ],
+            dim=2,
+        )
+        mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
         mask = th.where(mask, 0.0, float("-inf"))

         h = x
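
Note the behavioral change in SuperResUNetModel.forward: the bilinear upsampling of `low_res` moved out of the model, which now only checks shapes and concatenates. A minimal sketch of the caller-side step this implies, with illustrative names (the sampler's actual code may differ):

import torch.nn.functional as F

def prepare_low_res(low_res, target_height, target_width):
    # Resize the low-resolution conditioning image to the target resolution;
    # previously this interpolation happened inside SuperResUNetModel.forward.
    return F.interpolate(low_res, (target_height, target_width), mode="bilinear")

# Illustrative call site:
# low_res_up = prepare_low_res(low_res, x.shape[2], x.shape[3])
# out = sr_model(x, timesteps, low_res=low_res_up)
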

karlo/modules/xf.py (+2 -17)

@@ -138,13 +138,12 @@ class PriorTransformer(nn.Module):
     """
     A Causal Transformer that conditions on CLIP text embedding, text.

-    Expects an extra kwarg `tokens` of text.
-
     :param text_ctx: number of text tokens to expect.
     :param xf_width: width of the transformer.
     :param xf_layers: depth of the transformer.
     :param xf_heads: heads in the transformer.
     :param xf_final_ln: use a LayerNorm after the output layer.
+    :param clip_dim: dimension of the CLIP feature.
     """

     def __init__(

@@ -154,27 +153,23 @@ def __init__(
         xf_layers,
         xf_heads,
         xf_final_ln,
-        xf_padding,
         clip_dim,
-        clip_xf_width,
     ):
         super().__init__()

         self.text_ctx = text_ctx
         self.xf_width = xf_width
         self.xf_layers = xf_layers
         self.xf_heads = xf_heads
-        self.xf_padding = xf_padding
         self.clip_dim = clip_dim
-        self.clip_xf_width = clip_xf_width
         self.ext_len = 4

         self.time_embed = nn.Sequential(
             nn.Linear(xf_width, xf_width),
             nn.SiLU(),
             nn.Linear(xf_width, xf_width),
         )
-        self.text_enc_proj = nn.Linear(clip_xf_width, xf_width)
+        self.text_enc_proj = nn.Linear(clip_dim, xf_width)
         self.text_emb_proj = nn.Linear(clip_dim, xf_width)
         self.clip_img_proj = nn.Linear(clip_dim, xf_width)
         self.out_proj = nn.Linear(xf_width, clip_dim)

@@ -194,12 +189,6 @@ def __init__(
         )
         self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))

-        if self.xf_padding:
-            self.padding_embedding = nn.Parameter(
-                th.empty(text_ctx + self.ext_len, xf_width)
-            )
-            nn.init.normal_(self.padding_embedding, std=0.01)
-
         nn.init.normal_(self.prd_emb, std=0.01)
         nn.init.normal_(self.positional_embedding, std=0.01)

@@ -229,10 +218,6 @@ def forward(
         ]
         input = th.cat(input_seq, dim=1)
         input = input + self.positional_embedding.to(input.dtype)
-        if self.xf_padding:
-            input = th.where(
-                mask[..., None], input, self.padding_embedding[None].to(input.dtype)
-            )

         mask = th.where(mask, 0.0, float("-inf"))
         mask = (mask[:, None, :] + causal_mask).to(input.dtype)
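
With `xf_padding` gone, padded token positions no longer receive a learned padding embedding; they are excluded solely through the additive attention mask built right after the deleted block. A minimal sketch of that conversion, following the diff's convention of True = real token, False = padding (function name is illustrative):

import torch as th

def build_attention_bias(pad_mask: th.Tensor, causal_mask: th.Tensor) -> th.Tensor:
    # Boolean padding mask -> additive bias: kept positions add 0.0, padded
    # positions add -inf so softmax gives them zero attention weight.
    bias = th.where(pad_mask, 0.0, float("-inf"))
    # Broadcast over query positions and combine with the causal mask.
    return bias[:, None, :] + causal_mask
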

karlo/sampler/i2i.py (+8)

@@ -13,6 +13,14 @@


 class I2ISampler(BaseSampler):
+    """
+    A sampler for image variation. In the original unCLIP paper, image variation transforms the noise obtained by DDIM inversion into a sample in RGB space.
+    Here, we simply transform white noise into an image, conditioned on the CLIP image feature.
+
+    :param root_dir: directory for model checkpoints.
+    :param sampling_type: ["default", "fast"]
+    """
+
     def __init__(
         self,
         root_dir: str,

karlo/sampler/t2i.py (+7)

@@ -13,6 +13,13 @@


 class T2ISampler(BaseSampler):
+    """
+    A sampler for text-to-image generation.
+
+    :param root_dir: directory for model checkpoints.
+    :param sampling_type: ["default", "fast"]
+    """
+
     def __init__(
         self,
         root_dir: str,
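
The new sampler docstrings pin down only the two constructor arguments. A minimal, hypothetical construction sketch based on those parameters alone (the repo may expose a dedicated loading helper, and further arguments may be required):

from karlo.sampler.t2i import T2ISampler
from karlo.sampler.i2i import I2ISampler

# Hypothetical: both samplers take a checkpoint directory and a sampling schedule.
t2i = T2ISampler(root_dir="/path/to/checkpoints", sampling_type="fast")     # text-to-image
i2i = I2ISampler(root_dir="/path/to/checkpoints", sampling_type="default")  # image variation
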
