TrOCR has not yet been released for French, so we trained a French model as a proof of concept (PoC). Starting from this model, we recommend collecting more data to continue the first-stage training, or fine-tuning it as a second stage.
It is a special case of the English TrOCR model introduced in the paper "TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models" by Li et al. and first released in this repository.
This was possible thanks to daekun-ml and Niels Rogge, whose tutorials and code enabled us to publish this model.
We created training data of ~723k examples by taking random samples of the following datasets:
- MultiLegalPile - 90k
- French book Reviews - 20k
- WikiNeural - 83k
- Multilingual cc news - 119k
- Reviews Amazon Multi - 153k
- Opus Book - 70k
- BerlinText - 38k
We sampled part of each dataset and then randomly cut the sentences to build the final training set.
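The sampling and cutting step can be sketched as follows. This is a minimal illustration, not the script we actually ran: the function names, the per-dataset quotas, and the `max_words` limit are all hypothetical, and the real cutting rule may differ.

```python
import random


def cut_randomly(sentence, rng, min_words=1, max_words=6):
    """Return a random contiguous span of words taken from `sentence`."""
    words = sentence.split()
    if not words:
        return ""
    # Clamp the span length to what the sentence can actually provide.
    low = min(min_words, len(words))
    high = min(max_words, len(words))
    span = rng.randint(low, high)
    start = rng.randint(0, len(words) - span)
    return " ".join(words[start:start + span])


def build_corpus(datasets, quotas, seed=0):
    """Sample `quotas[name]` sentences per dataset, then cut each one."""
    rng = random.Random(seed)
    corpus = []
    for name, sentences in datasets.items():
        k = min(quotas[name], len(sentences))
        for sentence in rng.sample(sentences, k):
            piece = cut_randomly(sentence, rng)
            if piece:
                corpus.append(piece)
    rng.shuffle(corpus)
    return corpus


# Toy data standing in for the real datasets (hypothetical sentences).
datasets = {
    "reviews": [
        "Ce livre est vraiment excellent et très bien écrit .",
        "Je ne recommande pas ce produit .",
    ],
    "news": ["Le gouvernement a annoncé de nouvelles mesures hier soir ."],
}
quotas = {"reviews": 2, "news": 1}
corpus = build_corpus(datasets, quotas, seed=42)
```

Writing one piece per line to a text file gives an input of the shape expected by a text-to-image generator that reads its source strings from a file.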
Image data was generated with TextRecognitionDataGenerator (https://github.com/Belval/TextRecognitionDataGenerator), the tool used in the TrOCR paper. The command below generates the images.
python3 ./trdg/run.py -i ocr_dataset_poc.txt -w 5 -t {num_cores} -f 64 -l ko -c {num_samples} -na 2 --output_dir {dataset_dir}
The encoder is facebook/deit-base-distilled-patch16-384 and the decoder is camembert-base. Starting from these two pretrained checkpoints is easier than starting from the microsoft/trocr-base-stage1 weights.
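For readers who want to reproduce the setup, pairing the two checkpoints might look like the sketch below. It assumes the Hugging Face `transformers` library; the `build_model` helper and the special-token wiring are our illustration of a common pattern for RoBERTa-style decoders, not necessarily the exact training configuration.

```python
ENCODER_NAME = "facebook/deit-base-distilled-patch16-384"
DECODER_NAME = "camembert-base"


def build_model():
    """Pair the pretrained DeiT encoder with the CamemBERT decoder."""
    # Imported lazily so the heavy dependency is only needed at build time.
    from transformers import AutoTokenizer, VisionEncoderDecoderModel

    tokenizer = AutoTokenizer.from_pretrained(DECODER_NAME)
    # The cross-attention layers that connect decoder to encoder are newly
    # initialized here, so the combined model must be trained before use.
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ENCODER_NAME, DECODER_NAME
    )

    # Assumption: usual special-token wiring for generation with this decoder.
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.vocab_size = model.config.decoder.vocab_size
    return model, tokenizer
```

Calling `build_model()` downloads both checkpoints and returns a model whose encoder and decoder weights are pretrained while the connection between them still has to be learned.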
We used heuristic parameters without separate hyperparameter tuning.
- learning_rate = 4e-5
- epochs = 25
- fp16 = True
- max_length = 32
On the held-out evaluation set, we obtained the following results:
- size of the evaluation set: 72k examples
- CER: 0.13
- WER: 0.26
- Val Loss: 0.424
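For reference, CER and WER are edit-distance error rates at the character and word level. The sketch below shows the standard definitions in plain Python; aggregating errors over the whole corpus is an assumption on our side, and the reported numbers may have been computed with a library that averages differently.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (strings or word lists)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def cer(references, predictions):
    """Character error rate: character edits / reference characters."""
    errors = sum(edit_distance(r, p) for r, p in zip(references, predictions))
    return errors / sum(len(r) for r in references)


def wer(references, predictions):
    """Word error rate: word edits / reference words."""
    errors = sum(
        edit_distance(r.split(), p.split())
        for r, p in zip(references, predictions)
    )
    return errors / sum(len(r.split()) for r in references)


# Toy example: one wrong character in 19 reference characters,
# which is also one wrong word in 4 reference words.
references = ["le chat dort", "bonjour"]
predictions = ["le chat dors", "bonjour"]
cer_value = cer(references, predictions)
wer_value = wer(references, predictions)
```

In this toy example a single character substitution costs one character edit but a whole word, which is why WER is typically higher than CER, as in the results above.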
# Inference example: recognize the text in a sample image.
from io import BytesIO

import requests
from PIL import Image
from transformers import AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel

# The processor preprocesses the image; the tokenizer decodes the French output.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("agomberto/trocr-base-printed-fr")
tokenizer = AutoTokenizer.from_pretrained("agomberto/trocr-base-printed-fr")

# Use the raw file URL: the /blob/ URL returns an HTML page, not the image.
url = "https://github.com/agombert/trocr-base-printed-fr/raw/main/sample_imgs/0.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert("RGB")

# Preprocess, generate token ids (max_length matches training), then decode.
pixel_values = processor(img, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values, max_length=32)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)