
Commit

lab 9
sergeyk committed Apr 16, 2021
1 parent 7c5ef42 commit ae4e75d
Showing 87 changed files with 18,664 additions and 104 deletions.
12 changes: 9 additions & 3 deletions lab1/text_recognizer/util.py
@@ -1,14 +1,15 @@
 """Utility functions for text_recognizer module."""
 from io import BytesIO
 from pathlib import Path
 from typing import Union
 from urllib.request import urlretrieve

-# import base64
+import base64
 import hashlib

 from PIL import Image
 from tqdm import tqdm
 import numpy as np
+import smart_open


@@ -17,7 +18,12 @@ def to_categorical(y, num_classes):


 def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image:
-    with Image.open(image_uri) as image:
+    with smart_open.open(image_uri, "rb") as image_file:
+        return read_image_pil_file(image_file, grayscale)
+
+
+def read_image_pil_file(image_file, grayscale=False) -> Image:
+    with Image.open(image_file) as image:
         if grayscale:
             image = image.convert(mode="L")
         else:
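Note on this change: routing read_image_pil through smart_open means the helper now accepts remote URIs (HTTP, S3, ...) as well as local paths, and the new read_image_pil_file helper applies the same PIL handling to any open file object. A minimal usage sketch, assuming smart_open is installed and the path/URL (borrowed from the lab 7 example later in this commit) are reachable:

```python
import text_recognizer.util as util

# Local file and remote URL now share one code path: smart_open yields a
# binary file object, and read_image_pil_file hands it to PIL.
local_img = util.read_image_pil("text_recognizer/tests/support/paragraphs/a01-077.png", grayscale=True)
remote_img = util.read_image_pil(
    "https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png",
    grayscale=True,
)
print(local_img.size, remote_img.size)
```

The same change is mirrored in each lab's copy of util.py below.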
12 changes: 9 additions & 3 deletions lab2/text_recognizer/util.py
@@ -1,14 +1,15 @@
 """Utility functions for text_recognizer module."""
 from io import BytesIO
 from pathlib import Path
 from typing import Union
 from urllib.request import urlretrieve

-# import base64
+import base64
 import hashlib

 from PIL import Image
 from tqdm import tqdm
 import numpy as np
+import smart_open


@@ -17,7 +18,12 @@ def to_categorical(y, num_classes):


 def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image:
-    with Image.open(image_uri) as image:
+    with smart_open.open(image_uri, "rb") as image_file:
+        return read_image_pil_file(image_file, grayscale)
+
+
+def read_image_pil_file(image_file, grayscale=False) -> Image:
+    with Image.open(image_file) as image:
         if grayscale:
             image = image.convert(mode="L")
         else:
12 changes: 9 additions & 3 deletions lab3/text_recognizer/util.py
@@ -1,14 +1,15 @@
 """Utility functions for text_recognizer module."""
 from io import BytesIO
 from pathlib import Path
 from typing import Union
 from urllib.request import urlretrieve

-# import base64
+import base64
 import hashlib

 from PIL import Image
 from tqdm import tqdm
 import numpy as np
+import smart_open


@@ -17,7 +18,12 @@ def to_categorical(y, num_classes):


 def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image:
-    with Image.open(image_uri) as image:
+    with smart_open.open(image_uri, "rb") as image_file:
+        return read_image_pil_file(image_file, grayscale)
+
+
+def read_image_pil_file(image_file, grayscale=False) -> Image:
+    with Image.open(image_file) as image:
         if grayscale:
             image = image.convert(mode="L")
         else:
6 changes: 5 additions & 1 deletion lab4/text_recognizer/lit_models/transformer.py
@@ -1,5 +1,9 @@
 import torch.nn as nn
-import wandb
+try:
+    import wandb
+except ModuleNotFoundError:
+    pass


 from .metrics import CharacterErrorRate
 from .base import BaseLitModel
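Wrapping the wandb import in try/except makes wandb an optional dependency, so the lit model can be imported in environments (e.g., a serving container) where wandb is not installed. Since the except branch is a bare pass, the name wandb is simply undefined when the import fails, so any later use has to be guarded. A hedged sketch of that pattern (the helper and its arguments are hypothetical, not code from this repo):

```python
import sys

try:
    import wandb
except ModuleNotFoundError:
    pass


def log_examples(rows):
    """Hypothetical helper: log a table of predictions only if wandb imported and a run is active."""
    # Short-circuit keeps this safe: wandb is never referenced unless the import succeeded.
    if "wandb" in sys.modules and wandb.run is not None:
        wandb.log({"examples": wandb.Table(columns=["pred", "truth"], data=rows)})
```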
12 changes: 9 additions & 3 deletions lab4/text_recognizer/util.py
@@ -1,14 +1,15 @@
 """Utility functions for text_recognizer module."""
 from io import BytesIO
 from pathlib import Path
 from typing import Union
 from urllib.request import urlretrieve

-# import base64
+import base64
 import hashlib

 from PIL import Image
 from tqdm import tqdm
 import numpy as np
+import smart_open


@@ -17,7 +18,12 @@ def to_categorical(y, num_classes):


 def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image:
-    with Image.open(image_uri) as image:
+    with smart_open.open(image_uri, "rb") as image_file:
+        return read_image_pil_file(image_file, grayscale)
+
+
+def read_image_pil_file(image_file, grayscale=False) -> Image:
+    with Image.open(image_file) as image:
         if grayscale:
             image = image.convert(mode="L")
         else:
6 changes: 5 additions & 1 deletion lab5/text_recognizer/lit_models/transformer.py
@@ -1,5 +1,9 @@
 import torch.nn as nn
-import wandb
+try:
+    import wandb
+except ModuleNotFoundError:
+    pass


 from .metrics import CharacterErrorRate
 from .base import BaseLitModel
12 changes: 9 additions & 3 deletions lab5/text_recognizer/util.py
@@ -1,14 +1,15 @@
 """Utility functions for text_recognizer module."""
 from io import BytesIO
 from pathlib import Path
 from typing import Union
 from urllib.request import urlretrieve

-# import base64
+import base64
 import hashlib

 from PIL import Image
 from tqdm import tqdm
 import numpy as np
+import smart_open


@@ -17,7 +18,12 @@ def to_categorical(y, num_classes):


 def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image:
-    with Image.open(image_uri) as image:
+    with smart_open.open(image_uri, "rb") as image_file:
+        return read_image_pil_file(image_file, grayscale)
+
+
+def read_image_pil_file(image_file, grayscale=False) -> Image:
+    with Image.open(image_file) as image:
         if grayscale:
             image = image.convert(mode="L")
         else:
6 changes: 5 additions & 1 deletion lab7/text_recognizer/lit_models/transformer.py
@@ -1,5 +1,9 @@
 import torch.nn as nn
-import wandb
+try:
+    import wandb
+except ModuleNotFoundError:
+    pass


 from .metrics import CharacterErrorRate
 from .base import BaseLitModel
2 changes: 1 addition & 1 deletion lab7/text_recognizer/models/resnet_transformer.py
@@ -173,7 +173,7 @@ def predict(self, x: torch.Tensor) -> torch.Tensor:
             y = output_tokens[:, :Sy]  # (B, Sy)
             output = self.decode(x, y)  # (Sy, B, C)
             output = torch.argmax(output, dim=-1)  # (Sy, B)
-            output_tokens[:, Sy] = output[-1:]  # Set the last output token
+            output_tokens[:, Sy : Sy + 1] = output[-1:]  # Set the last output token

             # Early stopping of prediction loop to speed up prediction
             if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
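The only change here is on the assignment target: an integer column index becomes a one-element slice. Integer indexing drops a dimension while slicing keeps the tensor 2-D, which presumably matters once the lit model is scripted with TorchScript elsewhere in this commit, since scripted indexing and assignment are stricter than eager mode. A small shape-only illustration (not the repo's assignment itself):

```python
import torch

B, S = 1, 10
output_tokens = torch.zeros(B, S, dtype=torch.long)

print(output_tokens[:, 3].shape)    # torch.Size([1])    -- integer index drops the column dim
print(output_tokens[:, 3:4].shape)  # torch.Size([1, 1]) -- slice keeps the (B, 1) shape
```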
44 changes: 35 additions & 9 deletions lab7/text_recognizer/paragraph_text_recognizer.py
@@ -1,13 +1,15 @@
-import argparse
 from pathlib import Path
-from typing import Sequence
+from typing import Sequence, Union
+import argparse
+import json

+from PIL import Image
 import torch

 from text_recognizer.data import IAMParagraphs
 from text_recognizer.data.iam_paragraphs import resize_image, IMAGE_SCALE_FACTOR, get_transform
-from text_recognizer.models import ResnetTransformer
 from text_recognizer.lit_models import TransformerLitModel
+from text_recognizer.models import ResnetTransformer
 import text_recognizer.util as util


@@ -33,19 +35,43 @@ def __init__(self):
             checkpoint_path=CONFIG_AND_WEIGHTS_DIRNAME / "model.pt", args=args, model=model
         )
         self.lit_model.eval()
+        self.scripted_model = self.lit_model.to_torchscript(method="script", file_path=None)

     @torch.no_grad()
-    def predict(self, image_filename: Path) -> str:
-        """Predict/infer text in input image filename."""
-        pil_img = util.read_image_pil(image_filename, grayscale=True)
-        pil_img = resize_image(pil_img, IMAGE_SCALE_FACTOR)  # ideally resize should have been part of transform
-        img_tensor = self.transform(pil_img)
+    def predict(self, image: Union[str, Path, Image.Image]) -> str:
+        """Predict/infer text in input image (which can be a file path)."""
+        image_pil = image
+        if not isinstance(image, Image.Image):
+            image_pil = util.read_image_pil(image, grayscale=True)
+
+        image_pil = resize_image(image_pil, IMAGE_SCALE_FACTOR)
+        image_tensor = self.transform(image_pil)

-        y_pred = self.lit_model(img_tensor.unsqueeze(axis=0))[0]
+        y_pred = self.scripted_model(image_tensor.unsqueeze(axis=0))[0]
         pred_str = convert_y_label_to_string(y=y_pred, mapping=self.mapping, ignore_tokens=self.ignore_tokens)

         return pred_str


 def convert_y_label_to_string(y: torch.Tensor, mapping: Sequence[str], ignore_tokens: Sequence[int]) -> str:
     return "".join([mapping[i] for i in y if i not in ignore_tokens])


+def main():
+    """
+    Example runs:
+    ```
+    python text_recognizer/paragraph_text_recognizer.py text_recognizer/tests/support/paragraphs/a01-077.png
+    python text_recognizer/paragraph_text_recognizer.py https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png
+    ```
+    """
+    parser = argparse.ArgumentParser(description="Recognize handwritten text in an image file.")
+    parser.add_argument("filename", type=str)
+    args = parser.parse_args()
+
+    text_recognizer = ParagraphTextRecognizer()
+    pred_str = text_recognizer.predict(args.filename)
+    print(pred_str)
+
+
+if __name__ == "__main__":
+    main()
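Besides the new CLI above, predict() now accepts an already-loaded PIL image as well as a path or URL, and inference runs through the TorchScript-scripted model built in __init__. A usage sketch, assuming the lab's bundled config and weights are in place (the test image path comes from the docstring above):

```python
from PIL import Image

from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer

recognizer = ParagraphTextRecognizer()

# A path or URL is loaded via util.read_image_pil (smart_open under the hood) ...
print(recognizer.predict("text_recognizer/tests/support/paragraphs/a01-077.png"))

# ... while a PIL image skips the loading step; convert to grayscale first,
# since predict() only applies grayscale conversion when it loads the file itself.
image = Image.open("text_recognizer/tests/support/paragraphs/a01-077.png").convert("L")
print(recognizer.predict(image))
```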
12 changes: 9 additions & 3 deletions lab7/text_recognizer/util.py
@@ -1,14 +1,15 @@
 """Utility functions for text_recognizer module."""
 from io import BytesIO
 from pathlib import Path
 from typing import Union
 from urllib.request import urlretrieve

-# import base64
+import base64
 import hashlib

 from PIL import Image
 from tqdm import tqdm
 import numpy as np
+import smart_open


@@ -17,7 +18,12 @@ def to_categorical(y, num_classes):


 def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image:
-    with Image.open(image_uri) as image:
+    with smart_open.open(image_uri, "rb") as image_file:
+        return read_image_pil_file(image_file, grayscale)
+
+
+def read_image_pil_file(image_file, grayscale=False) -> Image:
+    with Image.open(image_file) as image:
         if grayscale:
             image = image.convert(mode="L")
         else: