Skip to content

Commit a9a09d6

Browse files
committed
Fix yaml import
1 parent 0933179 commit a9a09d6

6 files changed

+26
-9
lines changed

demo.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import os
22
import argparse
33
import torch
4-
import ruamel_yaml as yaml
4+
try:
5+
import ruamel_yaml as yaml
6+
except ModuleNotFoundError:
7+
import ruamel.yaml as yaml
8+
59

610
from model.prismer_caption import PrismerCaption
711
from dataset import create_dataset, create_loader

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ pycocotools
1515
geffnet
1616
fire
1717
huggingface_hub
18-
rich
18+
rich
19+
ruamel.yaml

train_caption.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55
# https://github.com/NVlabs/prismer/blob/main/LICENSE
66

77
import argparse
8-
import ruamel_yaml as yaml
98
import numpy as np
109
import random
1110
import time
1211
import functools
1312
import json
1413
import torch
1514
import os
16-
15+
try:
16+
import ruamel_yaml as yaml
17+
except ModuleNotFoundError:
18+
import ruamel.yaml as yaml
19+
1720
from accelerate import Accelerator, FullyShardedDataParallelPlugin
1821
from model.prismer_caption import PrismerCaption
1922
from model.modules.utils import interpolate_pos_embed

train_classification.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
# https://github.com/NVlabs/prismer/blob/main/LICENSE
66

77
import argparse
8-
import ruamel_yaml as yaml
98
import numpy as np
109
import random
1110
import time
1211
import functools
1312
import torch
13+
try:
14+
import ruamel_yaml as yaml
15+
except ModuleNotFoundError:
16+
import ruamel.yaml as yaml
1417

1518
from accelerate import Accelerator, FullyShardedDataParallelPlugin
1619
from model.prismer_caption import PrismerCaption

train_pretrain.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
# https://github.com/NVlabs/prismer/blob/main/LICENSE
66

77
import argparse
8-
import ruamel_yaml as yaml
98
import numpy as np
109
import random
1110
import time
1211
import datetime
1312
import functools
1413
import torch
15-
14+
try:
15+
import ruamel_yaml as yaml
16+
except ModuleNotFoundError:
17+
import ruamel.yaml as yaml
18+
1619
from accelerate import Accelerator, FullyShardedDataParallelPlugin
1720
from model.prismer_caption import PrismerCaption
1821
from dataset import create_dataset, create_loader

train_vqa.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
# https://github.com/NVlabs/prismer/blob/main/LICENSE
66

77
import argparse
8-
import ruamel_yaml as yaml
98
import numpy as np
109
import random
1110
import time
1211
import datetime
1312
import functools
1413
import torch
15-
14+
try:
15+
import ruamel_yaml as yaml
16+
except ModuleNotFoundError:
17+
import ruamel.yaml as yaml
18+
1619
from accelerate import Accelerator, FullyShardedDataParallelPlugin
1720
from model.prismer_vqa import PrismerVQA
1821
from model.modules.utils import interpolate_pos_embed

0 commit comments

Comments
 (0)