From 94c173ff8bb11362e45dd9262751f07bf9293660 Mon Sep 17 00:00:00 2001 From: Guy Tevet Date: Fri, 12 Apr 2024 09:32:47 +0300 Subject: [PATCH] Inference is twice as fast by calling CLIP just once and caching the results --- .gitignore | 2 ++ diffusion/gaussian_diffusion.py | 4 ++++ model/cfg_sampler.py | 1 + model/mdm.py | 5 ++++- 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b6e47617..c2e6a793 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,5 @@ dmypy.json # Pyre type checker .pyre/ + +save/ diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py index fbfb3da0..b825820d 100644 --- a/diffusion/gaussian_diffusion.py +++ b/diffusion/gaussian_diffusion.py @@ -637,6 +637,10 @@ def p_sample_loop( if dump_steps is not None: dump = [] + if 'text' in model_kwargs['y'].keys(): + # encoding once instead of each iteration saves lots of time + model_kwargs['y']['text_embed'] = model.encode_text(model_kwargs['y']['text']) + for i, sample in enumerate(self.p_sample_loop_progressive( model, shape, diff --git a/model/cfg_sampler.py b/model/cfg_sampler.py index 88e12322..6e02517e 100644 --- a/model/cfg_sampler.py +++ b/model/cfg_sampler.py @@ -20,6 +20,7 @@ def __init__(self, model): self.nfeats = self.model.nfeats self.data_rep = self.model.data_rep self.cond_mode = self.model.cond_mode + self.encode_text = self.model.encode_text def forward(self, x, timesteps, y=None): cond_mode = self.model.cond_mode diff --git a/model/mdm.py b/model/mdm.py index 14fd5bda..d874d841 100644 --- a/model/mdm.py +++ b/model/mdm.py @@ -148,7 +148,10 @@ def forward(self, x, timesteps, y=None): force_mask = y.get('uncond', False) if 'text' in self.cond_mode: - enc_text = self.encode_text(y['text']) + if 'text_embed' in y.keys(): # caching option + enc_text = y['text_embed'] + else: + enc_text = self.encode_text(y['text']) emb += self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)) if 'action' in self.cond_mode: 
action_emb = self.embed_action(y['action'])