Skip to content

Commit

Permalink
[test bugs] uda_mvtec (PaddlePaddle#3133)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sunting78 authored Apr 7, 2023
1 parent 3a94093 commit 80f9551
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions contrib/QualityInspector/qinspector/uad/datasets/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(self,
class_name='bottle',
is_train=True,
resize=[256, 256],
cropsize=[224, 224]):
cropsize=[224, 224],
is_predict=False):
assert class_name in CLASS_NAMES, 'class_name: {}, should be in {}'.format(
class_name, CLASS_NAMES)
self.dataset_root_path = dataset_root_path
Expand All @@ -49,7 +50,7 @@ def __init__(self,
self.cropsize = cropsize

# load dataset
if is_train:
if not is_predict:
self.x, self.y, self.mask = self.load_dataset_folder()

# set transforms
Expand Down
2 changes: 1 addition & 1 deletion contrib/QualityInspector/tools/uad/padim/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def main():
print("Testing model for {} with sigle picture".format(class_name))

# build datasets
MVTecDataset = mvtec.MVTecDataset(is_train=False)
MVTecDataset = mvtec.MVTecDataset(is_predict=True)
transform_x = MVTecDataset.get_transform_x()
x = Image.open(args.img_path).convert('RGB')
x = transform_x(x).unsqueeze(0)
Expand Down
2 changes: 1 addition & 1 deletion contrib/QualityInspector/tools/uad/patchcore/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def main():
model.eval()

# build data
MVTecDataset = mvtec.MVTecDataset(is_train=False)
MVTecDataset = mvtec.MVTecDataset(is_predict=True)
transform_x = MVTecDataset.get_transform_x()
x = Image.open(args.img_path).convert('RGB')
x = transform_x(x).unsqueeze(0)
Expand Down

0 comments on commit 80f9551

Please sign in to comment.