How to force model to generate image? #17

Open
haochuan-li opened this issue Dec 15, 2023 · 2 comments

haochuan-li commented Dec 15, 2023

Hi! Great work.

I see there's a "force image generation" option in the Gradio demo.
I wonder how to implement this in code. Can anyone enlighten me on this?

Thanks.

sijeh (Collaborator) commented Jan 23, 2024

Sorry for the late reply. Forced image generation can be achieved by manually appending the BOI (begin-of-image) token to the input text; the relevant code is at the following link:

input_text += BOI_TOKEN
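
For example, a minimal sketch of forcing an image continuation, assuming the tokenizer/model objects and the token string used in the demo scripts (all names here are illustrative and may differ from the repo):

# Illustrative sketch only: `tokenizer`, `model`, and `device` are assumed to be
# the SEED-LLaMA tokenizer/model already loaded as in the Gradio demo, and the
# BOI token string is an assumption that may differ between versions.
BOI_TOKEN = '<img>'  # begin-of-image special token (assumed value)

prompt = 'Generate an image of a cat sitting on a sofa.'
input_text = prompt + BOI_TOKEN  # appending BOI forces the LLM to continue with visual tokens

input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)
output_ids = model.generate(input_ids=input_ids, max_new_tokens=120, do_sample=False)
# the visual token ids generated after BOI (up to EOI) can then be decoded back
# into an image with the SEED tokenizer's image decoder.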

haochuan-li (Author) commented Jan 25, 2024

Thanks for the reply!

@sijeh, I have another question, related to the zero-shot retrieval evaluation: I cannot reproduce the Table 1 results from the SEED-LLaMA paper.

Here is my code for preparing the text and image embeddings for Flickr30k:

"""
Setting: Using Seed-LLaMA Tokenizer 2
"""
import hydra
from omegaconf import OmegaConf
from lavis.models import load_model
device = 'cuda'

tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer_hf.yaml'
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
seed_tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=False)

"""Preparing Flickr Text Embedding, simply follow blip2 retrieval"""
blip2_model = load_model("blip2", "pretrain")
blip2_model.eval().to(device)

text_emb = []
blip_text = blip2_model.tokenizer(captions, padding='max_length', truncation=True, max_length=32, return_tensors='pt')

blip_dataset = TextDataset(blip_text)
blip_dataloader = DataLoader(blip_dataset, 
                                shuffle=False, 
                                drop_last=False, 
                                num_workers=8,
                                pin_memory=True, 
                                batch_size=args.batch_size)

for (input_ids, attention_mask) in tqdm(blip_dataloader, desc='text', unit='text'):
    qformer_output = blip2_model.Qformer.bert(input_ids.to(device), attention_mask=attention_mask.to(device), return_dict=True).last_hidden_state[:,0,:]
    text_emb.append(qformer_output.detach().cpu())
text_emb = torch.concat(text_emb) # Text Emb for Retrieval, shape=[5000, 768]


"""Preparing Flickr Image Embedding"""
causal_code_pt = []
causal_emb_pt = []
for im in tqdm(imgs_gt, desc="tokenizing img", unit='img'):
    _, causal_code, causal_emb = seed_tokenizer.encode_image(image_torch=transform(im).to(device))
    causal_code_pt.append(causal_code[0][-1].squeeze())  # take the final embedding
    causal_emb_pt.append(causal_emb[0][-1].squeeze()) # take the final embedding

causal_code_pt = torch.stack(causal_code_pt) # Causal Code For Retrieval, shape=[1000, 768]
causal_emb_pt = torch.stack(causal_emb_pt) # Causal Emb For Retrieval, shape=[1000,768]

"""
The Detail about how to get causal code and causal emb, 
I modified the code in models/seed_qformer/qformer_quantizer.py
"""

def get_codebook_indices(self, image):
    with torch.no_grad():
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        print("image embeds", image_embeds.shape) # [1,257,1408]
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ) 
        # query_output hidden shape=[1,32,768]
        # query output down shape=[1,32,32]
        # query output up shape=[1,32,768]
    
        query_output_down = self.encode_task_layer(query_output.last_hidden_state)
        quant, loss_embed, embed_ind = self.quantize(query_output_down)
        embed_ind = embed_ind.reshape(quant.shape[0], -1)
        
        query_output_up = self.decode_task_layer(quant)
    return embed_ind, query_output_up, query_output.last_hidden_state

"""Compute Similarity Matrix"""
causal_code /= causal_code.norm(dim=-1, keepdim=True)
causal_emb /= causal_emb.norm(dim=-1, keepdim=True)

blip_causal_code_sim = (text_emb @ causal_code.T) 
blip_causal_emb_sim = (text_emb @ causal_emb.T)
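
For reference, here is a sketch of how recall@K can be computed from such a similarity matrix. It assumes the standard Flickr30k test ordering where caption i belongs to image i // 5; that ordering is an assumption on my part and may not match your data loading:

def t2i_recall_at_k(sim, ks=(1, 5, 10), captions_per_image=5):
    # sim: [num_captions, num_images] text-to-image similarity matrix; caption i
    # is assumed to belong to image i // captions_per_image (standard Flickr30k order).
    gt = torch.arange(sim.shape[0], device=sim.device) // captions_per_image
    topk = sim.topk(max(ks), dim=-1).indices  # [num_captions, max(ks)] retrieved image indices
    hits = topk == gt[:, None]                # True where the ground-truth image is retrieved
    return {f'R@{k}': hits[:, :k].any(dim=-1).float().mean().item() for k in ks}

print(t2i_recall_at_k(blip_causal_code_sim))  # text-to-image retrieval with the causal codes
print(t2i_recall_at_k(blip_causal_emb_sim))   # text-to-image retrieval with the causal embeddings

For image-to-text retrieval the matrix is transposed and a hit is counted if any of an image's 5 captions appears in the top k.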

Results in the paper

[screenshot: Table 1 retrieval results from the SEED-LLaMA paper]

Reproduced results

[screenshot: my reproduced retrieval results]

Question

[screenshot]

I'm not sure whether this is the right way to get the text and image embeddings described in the SEED-LLaMA paper. Please correct me if I'm wrong.

Looking forward to your reply.

Thanks
