From 61c9a7d83e2cf31148578bb0ff9cbb9154c54cf3 Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Tue, 1 Aug 2023 13:00:14 +0800 Subject: [PATCH 1/8] zero-shot CLIP --- .../clip_zs/clip-vit-base-patch16_cifar100.py | 67 ++++ configs/clip_zs/clip-vit-base-patch16_in1k.py | 69 ++++ .../clip-vit-large-patch14_cifar100.py | 67 ++++ .../clip_zs/clip-vit-large-patch14_in1k.py | 69 ++++ mmpretrain/datasets/categories.py | 221 ++++++++++++ mmpretrain/models/multimodal/__init__.py | 3 +- .../models/multimodal/clip_zs/__init__.py | 5 + mmpretrain/models/multimodal/clip_zs/clip.py | 324 ++++++++++++++++++ .../multimodal/clip_zs/clip_transformer.py | 128 +++++++ mmpretrain/models/multimodal/clip_zs/utils.py | 105 ++++++ 10 files changed, 1057 insertions(+), 1 deletion(-) create mode 100644 configs/clip_zs/clip-vit-base-patch16_cifar100.py create mode 100644 configs/clip_zs/clip-vit-base-patch16_in1k.py create mode 100644 configs/clip_zs/clip-vit-large-patch14_cifar100.py create mode 100644 configs/clip_zs/clip-vit-large-patch14_in1k.py create mode 100644 mmpretrain/models/multimodal/clip_zs/__init__.py create mode 100644 mmpretrain/models/multimodal/clip_zs/clip.py create mode 100644 mmpretrain/models/multimodal/clip_zs/clip_transformer.py create mode 100644 mmpretrain/models/multimodal/clip_zs/utils.py diff --git a/configs/clip_zs/clip-vit-base-patch16_cifar100.py b/configs/clip_zs/clip-vit-base-patch16_cifar100.py new file mode 100644 index 00000000000..170942c9863 --- /dev/null +++ b/configs/clip_zs/clip-vit-base-patch16_cifar100.py @@ -0,0 +1,67 @@ +cl_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, +) + +test_pipeline = [ + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='CIFAR100', + data_root='/public/DATA/qbw/img_cls_dataset/cifar100', + split='test', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIP_zs', + vision_backbone=dict( + type='CLIPVisionTransformer', + input_resolution=224, + patch_size=16, + width=768, + layers=12, + heads=12, + output_dim=512, + ), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + text_prototype='cifar100', + text_prompt='openai_cifar100', + context_length=77, +) diff --git a/configs/clip_zs/clip-vit-base-patch16_in1k.py b/configs/clip_zs/clip-vit-base-patch16_in1k.py new file mode 100644 index 00000000000..e620fb293b6 --- /dev/null +++ b/configs/clip_zs/clip-vit-base-patch16_in1k.py @@ -0,0 +1,69 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=True, +) + +test_pipeline = [ + 
dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='ImageNet', + data_root= + '/public/DATA/qbw/img_cls_dataset/in1k/imagenet-1k-huggingface/data/', + split='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIP_zs', + vision_backbone=dict( + type='CLIPVisionTransformer', + input_resolution=224, + patch_size=16, + width=768, + layers=12, + heads=12, + output_dim=512, + ), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + text_prototype='imagenet', + text_prompt='openai_imagenet', + context_length=77, +) diff --git a/configs/clip_zs/clip-vit-large-patch14_cifar100.py b/configs/clip_zs/clip-vit-large-patch14_cifar100.py new file mode 100644 index 00000000000..6c97a451874 --- /dev/null +++ b/configs/clip_zs/clip-vit-large-patch14_cifar100.py @@ -0,0 +1,67 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, +) + +test_pipeline = [ + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='CIFAR100', + data_root='/public/DATA/qbw/img_cls_dataset/cifar100', + split='test', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIP_zs', + vision_backbone=dict( + type='CLIPVisionTransformer', + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768, + ), + text_backbone=dict( + type='CLIPTransformer', + width=768, + layers=12, + heads=12, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-large-patch14', + use_fast=False), + vocab_size=49408, + transformer_width=768, + proj_dim=768, + text_prototype='cifar100', + text_prompt='openai_cifar100', + context_length=77, +) diff --git a/configs/clip_zs/clip-vit-large-patch14_in1k.py b/configs/clip_zs/clip-vit-large-patch14_in1k.py new file mode 100644 index 00000000000..0559eee6f41 --- /dev/null +++ b/configs/clip_zs/clip-vit-large-patch14_in1k.py @@ -0,0 +1,69 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( 
+ type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='ImageNet', + data_root= + '/public/DATA/qbw/img_cls_dataset/in1k/imagenet-1k-huggingface/data/', + split='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIP_zs', + vision_backbone=dict( + type='CLIPVisionTransformer', + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768, + ), + text_backbone=dict( + type='CLIPTransformer', + width=768, + layers=12, + heads=12, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-large-patch14', + use_fast=False), + vocab_size=49408, + transformer_width=768, + proj_dim=768, + text_prototype='imagenet', + text_prompt='openai_imagenet', + context_length=77, +) diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py index 011ee5c1609..9e75f7953b8 100644 --- a/mmpretrain/datasets/categories.py +++ b/mmpretrain/datasets/categories.py @@ -1438,3 +1438,224 @@ '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒', '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼', '柳树', '狼', '女人', '蠕虫') + +IMAGENET_SIMPLE_CATEGORIES = ( + 'tench', 'goldfish', 'great white shark', 'tiger shark', + 'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen', + 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', + 'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee', + 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', + 'great grey owl', 'fire salamander', 'smooth newt', 'newt', + 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', + 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana', + 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', + 'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile', + 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', + 'garter snake', 'water snake', 'vine snake', 'night snake', + 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', + 'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake', + 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', + 'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede', + 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', + 'quail', 'partridge', 'african grey parrot', 'macaw', + 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill', + 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', + 'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', + 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', + 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', + 'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', + 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', + 'hermit 
crab', 'isopod', 'white stork', 'black stork', 'spoonbill', + 'flamingo', 'little blue heron', 'great egret', 'bittern bird', + 'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', + 'ruddy turnstone', 'dunlin', 'common redshank', 'dowitcher', + 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', + 'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon', + 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', + 'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', + 'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound', + 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', + 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', + 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', + 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', + 'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', + 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier', + 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', + 'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', + 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', + 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever', + 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', + 'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', + 'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog', + 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', + 'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', + 'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie', + 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', + 'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler', + 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', + 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', + 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', + 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. 
Bernard', 'husky', + 'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher', + 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog', + 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', + 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', + 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', + 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo', + 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', + 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat', + 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', + 'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', + 'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', + 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', + 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', + 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', + 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', + 'damselfly', 'red admiral butterfly', 'ringlet butterfly', + 'monarch butterfly', 'small white butterfly', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', + 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', + 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', + 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', + 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', + 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', + 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', + 'white-headed capuchin', 'howler monkey', 'titi monkey', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', + 'indri', 'Asian elephant', 'African bush elephant', 'red panda', + 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', + 'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', + 'abaya', 'academic gown', 'accordion', 'acoustic guitar', + 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', + 'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', + 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', + 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', + 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', + 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', + 'military hat (bearskin or shako)', 'beer bottle', 'beer glass', + 'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', + 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie', + 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', + 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', + 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', + 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', + 'cardboard box / carton', 'car wheel', 'automated teller machine', + 'cassette', 'cassette player', 'castle', 
'catamaran', 'CD player', 'cello', + 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', + 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet', + 'Christmas stocking', 'church', 'movie theater', 'cleaver', + 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', + 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', + 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet', + 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', + 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', + 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', + 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', + 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', + 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick', + 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', + 'electric locomotive', 'entertainment center', 'envelope', + 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', + 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', + 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', + 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', + 'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', + 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', + 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', + 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', + 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', + 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass', + 'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', + 'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask', + 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', + 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped', + 'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', + 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', + 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', + 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart', + 'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', + 'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', + 'parachute', 'parallel bars', 'park bench', 'parking meter', + 'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case', + 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', + 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', + 'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship', + 'drink pitcher', 'block plane', 'planetarium', 'plastic 
bag', 'plate rack', + 'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', + 'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill', + 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', + 'radiator', 'radio', 'radio telescope', 'rain barrel', + 'recreational vehicle', 'fishing casting reel', 'reflex camera', + 'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle', + 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', + 'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker', + 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', + 'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', + 'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store', + 'shoji screen / room divider', 'shopping basket', 'shopping cart', + 'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask', + 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', + 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', + 'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', + 'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', + 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', + 'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen', + 'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing', + 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', + 'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof', + 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', + 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', + 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray', + 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', + 'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', + 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', + 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', + 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', + 'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', + 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', + 'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', + 'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', + 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', + 'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate', + 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle', + 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', + 'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini', + 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', + 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', + 'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', + 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', + 'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito', + 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff', + 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', + 'valley', 'volcano', 
'baseball player', 'bridegroom', 'scuba diver', + 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip', + 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', + 'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom', + 'bolete', 'corn cob', 'toilet paper') diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index 072c0f84f72..cb8acfc657d 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -5,6 +5,7 @@ from .blip import * # noqa: F401,F403 from .blip2 import * # noqa: F401,F403 from .chinese_clip import * # noqa: F401, F403 + from .clip_zs import * # noqa: F401, F403 from .flamingo import * # noqa: F401, F403 from .llava import * # noqa: F401, F403 from .minigpt4 import * # noqa: F401, F403 @@ -17,5 +18,5 @@ register_multimodal_placeholder([ 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', - 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter' + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP_zs' ], MODELS) diff --git a/mmpretrain/models/multimodal/clip_zs/__init__.py b/mmpretrain/models/multimodal/clip_zs/__init__.py new file mode 100644 index 00000000000..3f98bbd4ed6 --- /dev/null +++ b/mmpretrain/models/multimodal/clip_zs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..clip_zs.clip import CLIP_zs +from ..clip_zs.clip_transformer import CLIPTransformer, CLIPVisionTransformer + +__all__ = ['CLIP_zs', 'CLIPTransformer', 'CLIPVisionTransformer'] diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip_zs/clip.py new file mode 100644 index 00000000000..bb4e99b872b --- /dev/null +++ b/mmpretrain/models/multimodal/clip_zs/clip.py @@ -0,0 +1,324 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES, + IMAGENET_SIMPLE_CATEGORIES) +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT + +CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES] +PROTOTYPE_MAP = { + 'imagenet': IMAGENET_SIMPLE_CATEGORIES, + 'cifar100': CIFAR100_CATEGORIES +} +PROMPT_MAP = { + 'openai_imagenet': OPENAI_IMAGENET_PROMPT, + 'openai_cifar100': OPENAI_CIFAR100_PROMPT, + 'vanilla': [lambda c: f'a photo of a {c}'] +} + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +@MODELS.register_module() +class CLIP_zs(BaseModel): + """The implementation of `ChineseCLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. Defaults to 'openai'. 
+ context_length (int): The context length to use. Defaults to 52. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + text_prototype: Union[str, List[str]], + text_prompt: str = 'vanilla', + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.context_length = context_length + + self.visual = MODELS.build(vision_backbone) + text_backbone['attn_mask'] = self.build_attention_mask() + self.transformer = MODELS.build(text_backbone) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, proj_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + self.tokenizer = TOKENIZER.build(tokenizer) + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + self.tokenizer.vocab = self.tokenizer.get_vocab( + ) # CLIPTokenizer has no attribute named 'vocab', so manually + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, + # with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. 
+ The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + # return self.vision_backbone(images)[-1] @ self.vision_projection + return self.visual(images)[0] + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + x = self.token_embedding(texts) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x)[0] + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + texts.argmax(dim=-1)] @ self.text_projection + + return x + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' + if images is None: + return self.extract_text_feat(texts) + elif texts is None: + return self.extract_image_feat(images) + + image_features = self.extract_image_feat(images) + text_features = self.extract_text_feat(texts) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features + + def compute_similarity(self, images, texts): + """Extract images and texts features and compute cosine similarity.""" + image_features, text_features = self.extract_feat( + images=images, texts=texts) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape (N, N) + return logits_per_image, logits_per_text + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + """Predict the classes of the input images. + + The prediction is for zero-shot classification and the text prototypes + will be prepared in thisfunction. + + Args: + images (torch.Tensor): The input images. + data_samples (DataSample): The data samples with information from + dataset. + + Returns: + DataSample: The results of prediction. 
+ """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. + """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + # text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['<|startoftext|>'] + ] + # <|startoftext|>代表[CLS] token + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['<|endoftext|>']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/mmpretrain/models/multimodal/clip_zs/clip_transformer.py b/mmpretrain/models/multimodal/clip_zs/clip_transformer.py new file mode 100644 index 00000000000..3726239d10b --- /dev/null +++ b/mmpretrain/models/multimodal/clip_zs/clip_transformer.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from typing import Optional, Tuple + +import torch +from torch import nn + +from mmpretrain.models.utils.clip_generator_helper import ( + LayerNorm, ResidualAttentionBlock) +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CLIPTransformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. 
+ """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +@MODELS.register_module() +class CLIPVisionTransformer(nn.Module): + """Vision Transformer for CLIP. + + Args: + input_resolution (int): The image size. + patch_size (int): The patch size. + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + out_dim (int): The output dimension. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. + """ + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + finetune=False, + average_targets: int = 1) -> None: + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = CLIPTransformer(width, layers, heads) + + self.finetune = finetune + if finetune is False: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.average_targets = average_targets + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function.""" + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attention, z = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + if self.proj is not None: + x = x @ self.proj + + return x, attention diff --git a/mmpretrain/models/multimodal/clip_zs/utils.py b/mmpretrain/models/multimodal/clip_zs/utils.py new file mode 100644 index 00000000000..f442019a1b9 --- /dev/null +++ b/mmpretrain/models/multimodal/clip_zs/utils.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +OPENAI_CIFAR100_PROMPT = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +OPENAI_IMAGENET_PROMPT = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo 
of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] From 750d8b7307f8ff00b93714870fc7477865c4558d Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Tue, 1 Aug 2023 14:20:01 +0800 Subject: [PATCH 2/8] modify zero-shot clip config --- configs/clip_zs/clip-vit-base-patch16_cifar100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/clip_zs/clip-vit-base-patch16_cifar100.py b/configs/clip_zs/clip-vit-base-patch16_cifar100.py index 170942c9863..81825aa3ae3 100644 --- a/configs/clip_zs/clip-vit-base-patch16_cifar100.py +++ b/configs/clip_zs/clip-vit-base-patch16_cifar100.py @@ -1,4 +1,4 @@ -cl_base_ = '../_base_/default_runtime.py' +_base_ = '../_base_/default_runtime.py' # data settings data_preprocessor = dict( From e6d980b9113925e4358ab746ced2103dcfb0b906 Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Thu, 10 Aug 2023 00:00:27 +0800 Subject: [PATCH 3/8] add in1k_sub_prompt(8 prompts) for improvement --- configs/clip_zs/clip-vit-base-patch16_in1k.py | 2 +- configs/clip_zs/clip-vit-large-patch14_in1k.py | 2 +- mmpretrain/models/multimodal/clip_zs/clip.py | 8 +++++--- mmpretrain/models/multimodal/clip_zs/utils.py | 10 ++++++++++ 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/configs/clip_zs/clip-vit-base-patch16_in1k.py b/configs/clip_zs/clip-vit-base-patch16_in1k.py index e620fb293b6..cc80ca4cb67 100644 --- a/configs/clip_zs/clip-vit-base-patch16_in1k.py +++ b/configs/clip_zs/clip-vit-base-patch16_in1k.py @@ -64,6 +64,6 @@ transformer_width=512, proj_dim=512, text_prototype='imagenet', - text_prompt='openai_imagenet', + text_prompt='openai_imagenet_sub', context_length=77, ) diff --git a/configs/clip_zs/clip-vit-large-patch14_in1k.py b/configs/clip_zs/clip-vit-large-patch14_in1k.py index 0559eee6f41..b566381c139 100644 --- a/configs/clip_zs/clip-vit-large-patch14_in1k.py +++ b/configs/clip_zs/clip-vit-large-patch14_in1k.py @@ -64,6 +64,6 @@ transformer_width=768, proj_dim=768, text_prototype='imagenet', - text_prompt='openai_imagenet', + text_prompt='openai_imagenet_sub', context_length=77, ) diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip_zs/clip.py index bb4e99b872b..7aca462c559 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip.py +++ b/mmpretrain/models/multimodal/clip_zs/clip.py @@ -12,17 +12,19 @@ from mmpretrain.registry import MODELS, TOKENIZER from mmpretrain.structures import DataSample from mmpretrain.utils import track_on_main_process -from .utils import OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT +from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT, + OPENAI_IMAGENET_PROMPT_SUB) CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES] PROTOTYPE_MAP = { 'imagenet': IMAGENET_SIMPLE_CATEGORIES, - 'cifar100': CIFAR100_CATEGORIES + 'cifar100': CIFAR100_CATEGORIES, } PROMPT_MAP = { 'openai_imagenet': OPENAI_IMAGENET_PROMPT, 'openai_cifar100': OPENAI_CIFAR100_PROMPT, - 'vanilla': [lambda c: f'a photo of a {c}'] + 'vanilla': [lambda c: f'a photo of a {c}'], + 'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB } diff --git a/mmpretrain/models/multimodal/clip_zs/utils.py 
b/mmpretrain/models/multimodal/clip_zs/utils.py index f442019a1b9..65239bc37d6 100644 --- a/mmpretrain/models/multimodal/clip_zs/utils.py +++ b/mmpretrain/models/multimodal/clip_zs/utils.py @@ -21,6 +21,16 @@ lambda c: f'a photo of the big {c}.', ] +OPENAI_IMAGENET_PROMPT_SUB = [ + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +] + OPENAI_IMAGENET_PROMPT = [ lambda c: f'a bad photo of a {c}.', lambda c: f'a photo of many {c}.', From 5b087043f680ca3869850537f9e6f73d620bbe1f Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Fri, 11 Aug 2023 22:50:17 +0800 Subject: [PATCH 4/8] add some annotations doc --- configs/clip_zs/clip-vit-base-patch16_in1k.py | 2 +- configs/clip_zs/clip-vit-large-patch14_in1k.py | 2 +- mmpretrain/models/multimodal/clip_zs/clip.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/configs/clip_zs/clip-vit-base-patch16_in1k.py b/configs/clip_zs/clip-vit-base-patch16_in1k.py index cc80ca4cb67..ab33bda78aa 100644 --- a/configs/clip_zs/clip-vit-base-patch16_in1k.py +++ b/configs/clip_zs/clip-vit-base-patch16_in1k.py @@ -64,6 +64,6 @@ transformer_width=512, proj_dim=512, text_prototype='imagenet', - text_prompt='openai_imagenet_sub', + text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub context_length=77, ) diff --git a/configs/clip_zs/clip-vit-large-patch14_in1k.py b/configs/clip_zs/clip-vit-large-patch14_in1k.py index b566381c139..66d0957c467 100644 --- a/configs/clip_zs/clip-vit-large-patch14_in1k.py +++ b/configs/clip_zs/clip-vit-large-patch14_in1k.py @@ -64,6 +64,6 @@ transformer_width=768, proj_dim=768, text_prototype='imagenet', - text_prompt='openai_imagenet_sub', + text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub context_length=77, ) diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip_zs/clip.py index 7aca462c559..877f6cf7af3 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip.py +++ b/mmpretrain/models/multimodal/clip_zs/clip.py @@ -49,8 +49,9 @@ class CLIP_zs(BaseModel): proj_dim (int): Projection dimension for similarity computation. text_prototype (str): Text prototype, which can be a key in `PROTOTYPE_MAP` or list of text. - text_prompt (str): The prompt for text prototype. Defaults to 'openai'. - context_length (int): The context length to use. Defaults to 52. + text_prompt (str): The prompt for text prototype. + Defaults to 'vanilla',which refers to "a photo of {cls}". + context_length (int): The context length to use. Defaults to 77. data_preprocessor (Union[dict, nn.Module], optional): The config for preprocessing input data. If None or no specified type, it will use "MultiModalDataPreprocessor" as type. 
@@ -83,6 +84,8 @@ def __init__(self, self.context_length = context_length self.visual = MODELS.build(vision_backbone) + + # build attn_mask for casual-attn text_backbone['attn_mask'] = self.build_attention_mask() self.transformer = MODELS.build(text_backbone) From e92c1d7f2058f2f356bcb68cd552fa0189eeadda Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Sat, 19 Aug 2023 03:11:19 +0800 Subject: [PATCH 5/8] clip base class & clip_zs sub-class --- mmpretrain/models/multimodal/clip_zs/clip.py | 127 +++++++++++-------- 1 file changed, 77 insertions(+), 50 deletions(-) diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip_zs/clip.py index 877f6cf7af3..15b27894271 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip.py +++ b/mmpretrain/models/multimodal/clip_zs/clip.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod from typing import List, Optional, Tuple, Union import numpy as np @@ -38,8 +39,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return ret.type(orig_type) -@MODELS.register_module() -class CLIP_zs(BaseModel): +class CLIP(BaseModel): """The implementation of `ChineseCLIP `_. Args: @@ -68,8 +68,6 @@ def __init__(self, vocab_size: int, transformer_width: int, proj_dim: int, - text_prototype: Union[str, List[str]], - text_prompt: str = 'vanilla', context_length: int = 77, data_preprocessor: Optional[dict] = None, init_cfg: Optional[dict] = None): @@ -103,16 +101,6 @@ def __init__(self, self.tokenizer = TOKENIZER.build(tokenizer) - # for zero-shot classification - if isinstance(text_prototype, - str) and text_prototype in PROTOTYPE_MAP.keys(): - self.prototype = PROTOTYPE_MAP[text_prototype] - else: - self.prototype = text_prototype - self.text_prototype_embeds = None - - self.prompt = PROMPT_MAP[text_prompt] - self.tokenizer.vocab = self.tokenizer.get_vocab( ) # CLIPTokenizer has no attribute named 'vocab', so manually @@ -233,6 +221,81 @@ def compute_similarity(self, images, texts): # shape (N, N) return logits_per_image, logits_per_text + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. 
+ """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + # text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['<|startoftext|>'] + ] + # <|startoftext|>代表[CLS] token + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['<|endoftext|>']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +@MODELS.register_module() +class CLIP_zs(CLIP): + + def __init__( + self, + vision_backbone: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + text_prototype: Union[str, List[str]] = 'imagenet', + text_prompt: str = 'vanilla', + ): + super(CLIP_zs, + self).__init__(vision_backbone, text_backbone, tokenizer, + vocab_size, transformer_width, proj_dim, + context_length, data_preprocessor, init_cfg) + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + def predict(self, images: torch.Tensor, data_samples: DataSample = None) -> DataSample: @@ -291,39 +354,3 @@ def prepare_text_prototype(self, device) -> None: class_embeddings.append(class_feature) self.text_prototype_embeds = torch.stack( class_embeddings, dim=1).to(device) - - def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: - """Returns the tokenized representation of given input string(s) - - Args: - texts (Union[str, List[str]]): An input string or a list of input - strings to tokenize - context_length (int): The context length to use. Defaults to 52. - - Returns: - torch.Tensor: Resulting tokens. 
- """ - if isinstance(texts, str): - texts = [texts] - - all_tokens = [] - for text in texts: - # adapt the text to Chinese BERT vocab - # text = text.lower().replace('“', "\"").replace('”', "\"") - - # add special tokens - all_tokens.append( - [self.tokenizer.vocab['<|startoftext|>'] - ] + # <|startoftext|>代表[CLS] token - self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(text))[:self.context_length - 2] + - [self.tokenizer.vocab['<|endoftext|>']]) - - result = torch.zeros( - len(all_tokens), self.context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - assert len(tokens) <= self.context_length - result[i, :len(tokens)] = torch.tensor(tokens) - - return result From c48b3cd400ce6a984e6253a0192deb2e93348a60 Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Sun, 20 Aug 2023 03:45:29 +0800 Subject: [PATCH 6/8] some modifications of details after review --- configs/clip_zs/clip-vit-base-patch16_cifar100.py | 2 +- configs/clip_zs/clip-vit-base-patch16_in1k.py | 3 +-- configs/clip_zs/clip-vit-large-patch14_cifar100.py | 2 +- configs/clip_zs/clip-vit-large-patch14_in1k.py | 3 +-- mmpretrain/models/multimodal/clip_zs/__init__.py | 4 ++-- mmpretrain/models/multimodal/clip_zs/clip.py | 2 +- 6 files changed, 7 insertions(+), 9 deletions(-) diff --git a/configs/clip_zs/clip-vit-base-patch16_cifar100.py b/configs/clip_zs/clip-vit-base-patch16_cifar100.py index 81825aa3ae3..51e35f4812e 100644 --- a/configs/clip_zs/clip-vit-base-patch16_cifar100.py +++ b/configs/clip_zs/clip-vit-base-patch16_cifar100.py @@ -23,7 +23,7 @@ num_workers=8, dataset=dict( type='CIFAR100', - data_root='/public/DATA/qbw/img_cls_dataset/cifar100', + data_root='data/cifar100', split='test', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), diff --git a/configs/clip_zs/clip-vit-base-patch16_in1k.py b/configs/clip_zs/clip-vit-base-patch16_in1k.py index ab33bda78aa..8aeb07484e1 100644 --- a/configs/clip_zs/clip-vit-base-patch16_in1k.py +++ b/configs/clip_zs/clip-vit-base-patch16_in1k.py @@ -24,8 +24,7 @@ num_workers=8, dataset=dict( type='ImageNet', - data_root= - '/public/DATA/qbw/img_cls_dataset/in1k/imagenet-1k-huggingface/data/', + data_root='data/imagenet', split='val', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), diff --git a/configs/clip_zs/clip-vit-large-patch14_cifar100.py b/configs/clip_zs/clip-vit-large-patch14_cifar100.py index 6c97a451874..043ee1f60c1 100644 --- a/configs/clip_zs/clip-vit-large-patch14_cifar100.py +++ b/configs/clip_zs/clip-vit-large-patch14_cifar100.py @@ -23,7 +23,7 @@ num_workers=8, dataset=dict( type='CIFAR100', - data_root='/public/DATA/qbw/img_cls_dataset/cifar100', + data_root='data/cifar100', split='test', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), diff --git a/configs/clip_zs/clip-vit-large-patch14_in1k.py b/configs/clip_zs/clip-vit-large-patch14_in1k.py index 66d0957c467..a33bd68545d 100644 --- a/configs/clip_zs/clip-vit-large-patch14_in1k.py +++ b/configs/clip_zs/clip-vit-large-patch14_in1k.py @@ -24,8 +24,7 @@ num_workers=8, dataset=dict( type='ImageNet', - data_root= - '/public/DATA/qbw/img_cls_dataset/in1k/imagenet-1k-huggingface/data/', + data_root='data/imagenet', split='val', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), diff --git a/mmpretrain/models/multimodal/clip_zs/__init__.py b/mmpretrain/models/multimodal/clip_zs/__init__.py index 3f98bbd4ed6..a0aee351765 100644 --- 
a/mmpretrain/models/multimodal/clip_zs/__init__.py +++ b/mmpretrain/models/multimodal/clip_zs/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from ..clip_zs.clip import CLIP_zs +from ..clip_zs.clip import CLIP, CLIP_zs from ..clip_zs.clip_transformer import CLIPTransformer, CLIPVisionTransformer -__all__ = ['CLIP_zs', 'CLIPTransformer', 'CLIPVisionTransformer'] +__all__ = ['CLIP', 'CLIP_zs', 'CLIPTransformer', 'CLIPVisionTransformer'] diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip_zs/clip.py index 15b27894271..7f1c8de46c8 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip.py +++ b/mmpretrain/models/multimodal/clip_zs/clip.py @@ -40,7 +40,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class CLIP(BaseModel): - """The implementation of `ChineseCLIP `_. + """The implementation of `CLIP `_. Args: vision_backbone (dict): Config dict for vision backbone. From 3bc444239ca3f5aafaba19f996d0da89defaeb89 Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Mon, 21 Aug 2023 22:54:31 +0800 Subject: [PATCH 7/8] convert into and use mmpretrain-vit --- .../clip_zs/clip-vit-base-patch16_cifar100.py | 13 ++- configs/clip_zs/clip-vit-base-patch16_in1k.py | 13 ++- .../clip-vit-large-patch14_cifar100.py | 13 ++- .../clip_zs/clip-vit-large-patch14_in1k.py | 13 ++- mmpretrain/models/multimodal/__init__.py | 2 +- .../models/multimodal/clip_zs/__init__.py | 4 +- mmpretrain/models/multimodal/clip_zs/clip.py | 18 ++- .../multimodal/clip_zs/clip_transformer.py | 107 +++++++----------- .../openai-clip_to_mmpretrain-clip.py | 77 +++++++++++++ 9 files changed, 160 insertions(+), 100 deletions(-) create mode 100644 tools/model_converters/openai-clip_to_mmpretrain-clip.py diff --git a/configs/clip_zs/clip-vit-base-patch16_cifar100.py b/configs/clip_zs/clip-vit-base-patch16_cifar100.py index 51e35f4812e..dd1a08af8e5 100644 --- a/configs/clip_zs/clip-vit-base-patch16_cifar100.py +++ b/configs/clip_zs/clip-vit-base-patch16_cifar100.py @@ -39,14 +39,15 @@ model = dict( type='CLIP_zs', vision_backbone=dict( - type='CLIPVisionTransformer', - input_resolution=224, + type='VisionTransformer', + arch='base', + img_size=224, patch_size=16, - width=768, - layers=12, - heads=12, - output_dim=512, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, ), + projection=dict(type='CLIPProjection', in_channels=768, out_channels=512), text_backbone=dict( type='CLIPTransformer', width=512, diff --git a/configs/clip_zs/clip-vit-base-patch16_in1k.py b/configs/clip_zs/clip-vit-base-patch16_in1k.py index 8aeb07484e1..6f28b9de80b 100644 --- a/configs/clip_zs/clip-vit-base-patch16_in1k.py +++ b/configs/clip_zs/clip-vit-base-patch16_in1k.py @@ -40,14 +40,15 @@ model = dict( type='CLIP_zs', vision_backbone=dict( - type='CLIPVisionTransformer', - input_resolution=224, + type='VisionTransformer', + arch='base', + img_size=224, patch_size=16, - width=768, - layers=12, - heads=12, - output_dim=512, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, ), + projection=dict(type='CLIPProjection', in_channels=768, out_channels=512), text_backbone=dict( type='CLIPTransformer', width=512, diff --git a/configs/clip_zs/clip-vit-large-patch14_cifar100.py b/configs/clip_zs/clip-vit-large-patch14_cifar100.py index 043ee1f60c1..96c2d45cc30 100644 --- a/configs/clip_zs/clip-vit-large-patch14_cifar100.py +++ b/configs/clip_zs/clip-vit-large-patch14_cifar100.py @@ -39,14 +39,15 @@ model = dict( type='CLIP_zs', 
vision_backbone=dict( - type='CLIPVisionTransformer', - input_resolution=224, + type='VisionTransformer', + arch='large', + img_size=224, patch_size=14, - width=1024, - layers=24, - heads=16, - output_dim=768, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, ), + projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768), text_backbone=dict( type='CLIPTransformer', width=768, diff --git a/configs/clip_zs/clip-vit-large-patch14_in1k.py b/configs/clip_zs/clip-vit-large-patch14_in1k.py index a33bd68545d..893538bcd12 100644 --- a/configs/clip_zs/clip-vit-large-patch14_in1k.py +++ b/configs/clip_zs/clip-vit-large-patch14_in1k.py @@ -40,14 +40,15 @@ model = dict( type='CLIP_zs', vision_backbone=dict( - type='CLIPVisionTransformer', - input_resolution=224, + type='VisionTransformer', + arch='large', + img_size=224, patch_size=14, - width=1024, - layers=24, - heads=16, - output_dim=768, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, ), + projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768), text_backbone=dict( type='CLIPTransformer', width=768, diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index cb8acfc657d..b2d87d9d0ae 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -18,5 +18,5 @@ register_multimodal_placeholder([ 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', - 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP_zs' + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', 'CLIP_zs' ], MODELS) diff --git a/mmpretrain/models/multimodal/clip_zs/__init__.py b/mmpretrain/models/multimodal/clip_zs/__init__.py index a0aee351765..d46214a6211 100644 --- a/mmpretrain/models/multimodal/clip_zs/__init__.py +++ b/mmpretrain/models/multimodal/clip_zs/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from ..clip_zs.clip import CLIP, CLIP_zs -from ..clip_zs.clip_transformer import CLIPTransformer, CLIPVisionTransformer +from ..clip_zs.clip_transformer import CLIPProjection, CLIPTransformer -__all__ = ['CLIP', 'CLIP_zs', 'CLIPTransformer', 'CLIPVisionTransformer'] +__all__ = ['CLIP', 'CLIP_zs', 'CLIPTransformer', 'CLIPProjection'] diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip_zs/clip.py index 7f1c8de46c8..3d2eaa5e405 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip.py +++ b/mmpretrain/models/multimodal/clip_zs/clip.py @@ -63,6 +63,7 @@ class CLIP(BaseModel): def __init__(self, vision_backbone: dict, + projection: dict, text_backbone: dict, tokenizer: dict, vocab_size: int, @@ -81,10 +82,16 @@ def __init__(self, self.context_length = context_length + # build the vision transformer self.visual = MODELS.build(vision_backbone) + # build the visual projection + self.visual_proj = MODELS.build(projection) + # build attn_mask for casual-attn text_backbone['attn_mask'] = self.build_attention_mask() + + # build the text transformer self.transformer = MODELS.build(text_backbone) self.vocab_size = vocab_size @@ -163,8 +170,7 @@ def forward( def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: """The function to extract image latent features.""" - # return self.vision_backbone(images)[-1] @ self.vision_projection - return self.visual(images)[0] + return self.visual_proj(self.visual(images))[0] def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: """The function to extract text latent features.""" @@ -270,6 +276,7 @@ class CLIP_zs(CLIP): def __init__( self, vision_backbone: dict, + projection: dict, text_backbone: dict, tokenizer: dict, vocab_size: int, @@ -282,9 +289,10 @@ def __init__( text_prompt: str = 'vanilla', ): super(CLIP_zs, - self).__init__(vision_backbone, text_backbone, tokenizer, - vocab_size, transformer_width, proj_dim, - context_length, data_preprocessor, init_cfg) + self).__init__(vision_backbone, projection, text_backbone, + tokenizer, vocab_size, transformer_width, + proj_dim, context_length, data_preprocessor, + init_cfg) # for zero-shot classification if isinstance(text_prototype, diff --git a/mmpretrain/models/multimodal/clip_zs/clip_transformer.py b/mmpretrain/models/multimodal/clip_zs/clip_transformer.py index 3726239d10b..4b5f76661cb 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip_transformer.py +++ b/mmpretrain/models/multimodal/clip_zs/clip_transformer.py @@ -3,10 +3,11 @@ from typing import Optional, Tuple import torch +from mmengine.model import BaseModule from torch import nn -from mmpretrain.models.utils.clip_generator_helper import ( - LayerNorm, ResidualAttentionBlock) +from mmpretrain.models.utils.clip_generator_helper import \ + ResidualAttentionBlock from mmpretrain.registry import MODELS @@ -55,74 +56,44 @@ def forward( @MODELS.register_module() -class CLIPVisionTransformer(nn.Module): - """Vision Transformer for CLIP. +class CLIPProjection(BaseModule): + """Neck with CLIP Projection. Args: - input_resolution (int): The image size. - patch_size (int): The patch size. - width (int): The feature dimension. - layers (int): The number of layers. - heads (int): The number of attention heads. - out_dim (int): The output dimension. - fineturn (bool): Whether to fineturn the model. - average_target (bool): Whether to average the target. + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. 
+ init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. """ def __init__(self, - input_resolution: int, - patch_size: int, - width: int, - layers: int, - heads: int, - output_dim: int, - finetune=False, - average_targets: int = 1) -> None: - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False) - - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn( - (input_resolution // patch_size)**2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = CLIPTransformer(width, layers, heads) - - self.finetune = finetune - if finetune is False: - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - self.average_targets = average_targets - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward function.""" - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], - -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([ - self.class_embedding.to(x.dtype) + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x - ], - dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x, attention, z = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - if self.proj is not None: - x = x @ self.proj - - return x, attention + in_channels: int, + out_channels: int, + init_cfg: Optional[dict] = None): + super(CLIPProjection, self).__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + scale = in_channels**-0.5 + self.proj = nn.Parameter(scale * + torch.randn(in_channels, out_channels)) + + def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Tuple): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + Returns: + Tuple(torch.Tensor)): A tuple of reducted features. + """ + if isinstance(inputs, tuple): + inputs = inputs[-1] + out = inputs @ self.proj + elif isinstance(inputs, torch.Tensor): + out = inputs @ self.proj + else: + raise TypeError( + '`CLIPProjection` neck inputs should be tuple or torch.tensor') + return (out, ) diff --git a/tools/model_converters/openai-clip_to_mmpretrain-clip.py b/tools/model_converters/openai-clip_to_mmpretrain-clip.py new file mode 100644 index 00000000000..72725502551 --- /dev/null +++ b/tools/model_converters/openai-clip_to_mmpretrain-clip.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
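+# Converts an OpenAI-format CLIP checkpoint into the key layout expected by
+# the mmpretrain `VisionTransformer` backbone plus `CLIPProjection` neck.
+# Typical invocation (``src``/``dst`` are the two positional args below):
+#     python tools/model_converters/openai-clip_to_mmpretrain-clip.py \
+#         ${OPENAI_CKPT} ${OUT_CKPT}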
+import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_clip(ckpt): + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('visual.conv1'): + new_k = k.replace('conv1', 'patch_embed.projection') + elif k.startswith('visual.positional_embedding'): + new_k = k.replace('positional_embedding', 'pos_embed') + new_v = v.unsqueeze(dim=0) + elif k.startswith('visual.class_embedding'): + new_k = k.replace('class_embedding', 'cls_token') + new_v = v.unsqueeze(dim=0).unsqueeze(dim=0) + elif k.startswith('visual.ln_pre'): + new_k = k.replace('ln_pre', 'pre_norm') + elif k.startswith('visual.transformer.resblocks'): + new_k = k.replace('transformer.resblocks', 'layers') + if 'ln_1' in k: + new_k = new_k.replace('ln_1', 'ln1') + elif 'ln_2' in k: + new_k = new_k.replace('ln_2', 'ln2') + elif 'mlp.c_fc' in k: + new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0') + elif 'mlp.c_proj' in k: + new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1') + elif 'attn.in_proj_weight' in k: + new_k = new_k.replace('in_proj_weight', 'qkv.weight') + elif 'attn.in_proj_bias' in k: + new_k = new_k.replace('in_proj_bias', 'qkv.bias') + elif 'attn.out_proj' in k: + new_k = new_k.replace('out_proj', 'proj') + elif k.startswith('visual.ln_post'): + new_k = k.replace('ln_post', 'ln1') + elif k.startswith('visual.proj'): + new_k = k.replace('visual.proj', 'visual_proj.proj') + else: + new_k = k + + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in pretrained clip ' + 'models to mmpretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. 
+ parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + weight = convert_clip(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + print('Done!!') + + +if __name__ == '__main__': + main() From be23c550bdd62d23a1b7dad4beb876e4cad658b9 Mon Sep 17 00:00:00 2001 From: HustQBW <995020860@qq.com> Date: Fri, 1 Sep 2023 01:50:03 +0800 Subject: [PATCH 8/8] modify names of some files and directories --- .../clip_vit-base-p16_zeroshot-cls_cifar100.py} | 2 +- .../clip_vit-base-p16_zeroshot-cls_in1k.py} | 2 +- .../clip_vit-large-p14_zeroshot-cls_cifar100.py} | 2 +- .../clip_vit-large-p14_zeroshot-cls_in1k.py} | 2 +- mmpretrain/models/multimodal/__init__.py | 5 +++-- mmpretrain/models/multimodal/clip/__init__.py | 5 +++++ mmpretrain/models/multimodal/{clip_zs => clip}/clip.py | 4 ++-- .../models/multimodal/{clip_zs => clip}/clip_transformer.py | 0 mmpretrain/models/multimodal/{clip_zs => clip}/utils.py | 0 mmpretrain/models/multimodal/clip_zs/__init__.py | 5 ----- 10 files changed, 14 insertions(+), 13 deletions(-) rename configs/{clip_zs/clip-vit-base-patch16_cifar100.py => clip/clip_vit-base-p16_zeroshot-cls_cifar100.py} (98%) rename configs/{clip_zs/clip-vit-base-patch16_in1k.py => clip/clip_vit-base-p16_zeroshot-cls_in1k.py} (98%) rename configs/{clip_zs/clip-vit-large-patch14_cifar100.py => clip/clip_vit-large-p14_zeroshot-cls_cifar100.py} (98%) rename configs/{clip_zs/clip-vit-large-patch14_in1k.py => clip/clip_vit-large-p14_zeroshot-cls_in1k.py} (98%) create mode 100644 mmpretrain/models/multimodal/clip/__init__.py rename mmpretrain/models/multimodal/{clip_zs => clip}/clip.py (99%) rename mmpretrain/models/multimodal/{clip_zs => clip}/clip_transformer.py (100%) rename mmpretrain/models/multimodal/{clip_zs => clip}/utils.py (100%) delete mode 100644 mmpretrain/models/multimodal/clip_zs/__init__.py diff --git a/configs/clip_zs/clip-vit-base-patch16_cifar100.py b/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py similarity index 98% rename from configs/clip_zs/clip-vit-base-patch16_cifar100.py rename to configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py index dd1a08af8e5..dd684a50a31 100644 --- a/configs/clip_zs/clip-vit-base-patch16_cifar100.py +++ b/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py @@ -37,7 +37,7 @@ # model settings model = dict( - type='CLIP_zs', + type='CLIPZeroShot', vision_backbone=dict( type='VisionTransformer', arch='base', diff --git a/configs/clip_zs/clip-vit-base-patch16_in1k.py b/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py similarity index 98% rename from configs/clip_zs/clip-vit-base-patch16_in1k.py rename to configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py index 6f28b9de80b..80c4fde82f5 100644 --- a/configs/clip_zs/clip-vit-base-patch16_in1k.py +++ b/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py @@ -38,7 +38,7 @@ # model settings model = dict( - type='CLIP_zs', + type='CLIPZeroShot', vision_backbone=dict( type='VisionTransformer', arch='base', diff --git a/configs/clip_zs/clip-vit-large-patch14_cifar100.py b/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py similarity index 98% rename from configs/clip_zs/clip-vit-large-patch14_cifar100.py rename to configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py index 96c2d45cc30..a6dd7c11412 100644 --- 
a/configs/clip_zs/clip-vit-large-patch14_cifar100.py +++ b/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py @@ -37,7 +37,7 @@ # model settings model = dict( - type='CLIP_zs', + type='CLIPZeroShot', vision_backbone=dict( type='VisionTransformer', arch='large', diff --git a/configs/clip_zs/clip-vit-large-patch14_in1k.py b/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py similarity index 98% rename from configs/clip_zs/clip-vit-large-patch14_in1k.py rename to configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py index 893538bcd12..10500017a93 100644 --- a/configs/clip_zs/clip-vit-large-patch14_in1k.py +++ b/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py @@ -38,7 +38,7 @@ # model settings model = dict( - type='CLIP_zs', + type='CLIPZeroShot', vision_backbone=dict( type='VisionTransformer', arch='large', diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index b2d87d9d0ae..73645f0f5e6 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -5,7 +5,7 @@ from .blip import * # noqa: F401,F403 from .blip2 import * # noqa: F401,F403 from .chinese_clip import * # noqa: F401, F403 - from .clip_zs import * # noqa: F401, F403 + from .clip import * # noqa: F401, F403 from .flamingo import * # noqa: F401, F403 from .llava import * # noqa: F401, F403 from .minigpt4 import * # noqa: F401, F403 @@ -18,5 +18,6 @@ register_multimodal_placeholder([ 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', - 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', 'CLIP_zs' + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', + 'CLIPZeroShot' ], MODELS) diff --git a/mmpretrain/models/multimodal/clip/__init__.py b/mmpretrain/models/multimodal/clip/__init__.py new file mode 100644 index 00000000000..f7a117ea7ca --- /dev/null +++ b/mmpretrain/models/multimodal/clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from ..clip.clip import CLIP, CLIPZeroShot +from ..clip.clip_transformer import CLIPProjection, CLIPTransformer + +__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection'] diff --git a/mmpretrain/models/multimodal/clip_zs/clip.py b/mmpretrain/models/multimodal/clip/clip.py similarity index 99% rename from mmpretrain/models/multimodal/clip_zs/clip.py rename to mmpretrain/models/multimodal/clip/clip.py index 3d2eaa5e405..b509a63b3be 100644 --- a/mmpretrain/models/multimodal/clip_zs/clip.py +++ b/mmpretrain/models/multimodal/clip/clip.py @@ -271,7 +271,7 @@ def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: @MODELS.register_module() -class CLIP_zs(CLIP): +class CLIPZeroShot(CLIP): def __init__( self, @@ -288,7 +288,7 @@ def __init__( text_prototype: Union[str, List[str]] = 'imagenet', text_prompt: str = 'vanilla', ): - super(CLIP_zs, + super(CLIPZeroShot, self).__init__(vision_backbone, projection, text_backbone, tokenizer, vocab_size, transformer_width, proj_dim, context_length, data_preprocessor, diff --git a/mmpretrain/models/multimodal/clip_zs/clip_transformer.py b/mmpretrain/models/multimodal/clip/clip_transformer.py similarity index 100% rename from mmpretrain/models/multimodal/clip_zs/clip_transformer.py rename to mmpretrain/models/multimodal/clip/clip_transformer.py diff --git a/mmpretrain/models/multimodal/clip_zs/utils.py b/mmpretrain/models/multimodal/clip/utils.py similarity index 100% rename from mmpretrain/models/multimodal/clip_zs/utils.py rename to mmpretrain/models/multimodal/clip/utils.py diff --git a/mmpretrain/models/multimodal/clip_zs/__init__.py b/mmpretrain/models/multimodal/clip_zs/__init__.py deleted file mode 100644 index d46214a6211..00000000000 --- a/mmpretrain/models/multimodal/clip_zs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from ..clip_zs.clip import CLIP, CLIP_zs -from ..clip_zs.clip_transformer import CLIPProjection, CLIPTransformer - -__all__ = ['CLIP', 'CLIP_zs', 'CLIPTransformer', 'CLIPProjection']
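The tokenization scheme used throughout this series can be sanity-checked in isolation. The sketch below mirrors the padding/truncation logic of CLIP.tokenize: a start-of-text id, at most context_length - 2 body token ids, an end-of-text id, all zero-padded to context_length (77 by default). The token ids and special-token values here are illustrative stand-ins, not the real vocabulary loaded by the AutoTokenizer in the configs.

import torch

# Illustrative special-token ids; the real values come from the HuggingFace
# CLIP tokenizer vocab ('<|startoftext|>' / '<|endoftext|>').
SOT, EOT = 49406, 49407
CONTEXT_LENGTH = 77


def toy_tokenize(token_ids_per_text):
    """Mirror of CLIP.tokenize: [SOT] + body[:75] + [EOT], zero-padded."""
    all_tokens = [[SOT] + ids[:CONTEXT_LENGTH - 2] + [EOT]
                  for ids in token_ids_per_text]
    result = torch.zeros(len(all_tokens), CONTEXT_LENGTH, dtype=torch.long)
    for i, tokens in enumerate(all_tokens):
        assert len(tokens) <= CONTEXT_LENGTH
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result


out = toy_tokenize([[320, 1125, 539, 320, 2368]])  # fake ids for one prompt
print(out.shape)    # torch.Size([1, 77])
print(out[0, :7])   # tensor([49406,  320, 1125,  539,  320, 2368, 49407])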
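With the switch from the bespoke CLIPVisionTransformer to mmpretrain's VisionTransformer plus a CLIPProjection neck, image features are produced by self.visual_proj(self.visual(images))[0]. The following is a minimal shape check of the neck on fake backbone output, assuming mmpretrain with these patches is installed so that the new module path is importable; the channel sizes follow the ViT-B/16 config (768 -> 512).

import torch

from mmpretrain.models.multimodal.clip.clip_transformer import CLIPProjection

proj = CLIPProjection(in_channels=768, out_channels=512)  # ViT-B/16 dims

# Fake backbone output: mmpretrain backbones return a tuple of stage outputs;
# the neck only consumes the last stage.
fake_feats = (torch.randn(2, 768), )
out = proj(fake_feats)
print(out[0].shape)  # torch.Size([2, 512])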
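The renaming rules applied by the checkpoint converter can be spot-checked on a handful of visual-tower keys. The snippet restates a subset of convert_clip() (patch embedding, class token, pre-norm and the visual projection); the sample key names are typical of OpenAI-style CLIP state dicts and are listed only for illustration.

# A few visual-tower keys as they appear in an OpenAI-style CLIP state dict.
openai_keys = [
    'visual.conv1.weight',
    'visual.class_embedding',
    'visual.positional_embedding',
    'visual.ln_pre.weight',
    'visual.proj',
]


# Condensed restatement of the rules in convert_clip(); the real tool also
# rewrites the transformer blocks (resblocks -> layers, ln_1 -> ln1,
# mlp.c_fc -> ffn.layers.0.0, attn.in_proj_* -> qkv.*, ...) and unsqueezes
# cls_token / pos_embed to add the leading batch dimension.
def rename(k: str) -> str:
    if k.startswith('visual.conv1'):
        return k.replace('conv1', 'patch_embed.projection')
    if k.startswith('visual.positional_embedding'):
        return k.replace('positional_embedding', 'pos_embed')
    if k.startswith('visual.class_embedding'):
        return k.replace('class_embedding', 'cls_token')
    if k.startswith('visual.ln_pre'):
        return k.replace('ln_pre', 'pre_norm')
    if k.startswith('visual.proj'):
        return k.replace('visual.proj', 'visual_proj.proj')
    return k


for k in openai_keys:
    print(f'{k} -> {rename(k)}')
# visual.conv1.weight -> visual.patch_embed.projection.weight
# visual.class_embedding -> visual.cls_token
# visual.positional_embedding -> visual.pos_embed
# visual.ln_pre.weight -> visual.pre_norm.weight
# visual.proj -> visual_proj.proj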
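Finally, on how the pieces fit together at test time: CLIPZeroShot builds the text prototypes once (prompt templates x class names -> tokenize -> text features, stacked along dim=1 into a [proj_dim, num_classes] matrix) and then scores each image feature against them. The snippet below is only a schematic of the standard CLIP zero-shot recipe with L2-normalized features and a dot-product; it is not a copy of the predict() implementation, which may additionally apply a learned logit scale.

import torch
import torch.nn.functional as F

proj_dim, num_classes, batch = 512, 100, 4

# Stand-ins for extract_image_feat(images) and text_prototype_embeds.
image_feats = torch.randn(batch, proj_dim)
text_prototype_embeds = torch.randn(proj_dim, num_classes)

# Cosine similarity between each image and each class prototype.
image_feats = F.normalize(image_feats, dim=-1)
prototypes = F.normalize(text_prototype_embeds, dim=0)
logits = image_feats @ prototypes   # [batch, num_classes]
pred = logits.argmax(dim=-1)
print(logits.shape, pred.shape)     # torch.Size([4, 100]) torch.Size([4])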