
Many short rows with financial table structure model #178

Open · abielr opened this issue Apr 10, 2024 · 2 comments

abielr commented Apr 10, 2024

I am finding that when using the latest TATR-v1.1-Fin model, the initial TSR phase detects many very thin rows in the middle of a table, even when the table is simple and comes from the original training set. The TATR-v1.1-All model works fine on the same table, so I'm not sure whether I'm doing something incorrect when preprocessing the data.

The screenshots below demonstrate the problem on the image AAL_2014_page_192_table_0.jpg from the FinTabNet.c training set. The example code below can be used to re-create the issue; just toggle between `model_type = 'all'` and `model_type = 'fin'` at the top of the script.

Original table: [image]

Fin model: [image]

All model: [image]

```python
from pathlib import Path

import torch
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import TableTransformerForObjectDetection

model_type = 'fin'  # can be 'fin' or 'all'
PATH = Path(__file__).parent.resolve()
cropped_table = Image.open(PATH / "data/AAL_2014_page_192_table_0.jpg").convert("RGB")

#############
# Setup and utility functions
#############

device = "cuda" if torch.cuda.is_available() else "cpu"

class MaxResize(object):
    """Resize an image so its longer side equals max_size, preserving aspect ratio."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))

        return resized_image

structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Per-class confidence thresholds; 'no object' is given a threshold above 1
# so that class can never be kept.
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10
}

def box_cxcywh_to_xyxy(x):
    # Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max).
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # Scale normalized boxes back to absolute pixel coordinates of the original image.
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def outputs_to_objects(outputs, img_size, class_idx2name):
    # For each query, take the highest-scoring class and its score.
    m = outputs['logits'].softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = class_idx2name[int(label)]
        # Keep only real classes that clear their confidence threshold.
        if class_label != 'no object' and score >= structure_class_thresholds[class_label]:
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects

def draw_bboxes(bboxes, page_image, color='red'):
    page_image = page_image.copy()
    draw = ImageDraw.Draw(page_image)

    for bbox in bboxes:
        draw.rectangle(bbox, outline=color)

    return page_image

# Load both structure recognition checkpoints so we can toggle between them.
structure_models = {
    'all': TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all"),
    'fin': TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-fin")
}
for k in structure_models:
    structure_models[k] = structure_models[k].to(device)

# Both checkpoints share the same label map; append the implicit 'no object' class.
structure_id2label = structure_models['all'].config.id2label
structure_id2label[len(structure_id2label)] = "no object"

#############
# Running the model
#############

pixel_values = structure_transform(cropped_table).unsqueeze(0)
pixel_values = pixel_values.to(device)

with torch.no_grad():
    structure_outputs = structure_models[model_type](pixel_values)

structure_objects = outputs_to_objects(structure_outputs, cropped_table.size, structure_id2label)

# Draw all predicted boxes; with model_type = 'fin', many thin rows show up.
annotated = draw_bboxes([x['bbox'] for x in structure_objects], cropped_table)
annotated.show()
```
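As a stopgap while this is investigated, the thin rows can also be filtered out after the fact. A minimal sketch, assuming a heuristic cutoff (the `drop_thin_rows` helper and the 0.5 median-height ratio below are my own guesses, not part of TATR):

```python
import statistics

def drop_thin_rows(objects, min_ratio=0.5):
    # Heuristic: discard 'table row' boxes much thinner than the median row.
    # min_ratio is an assumed cutoff to tune, not anything the model prescribes.
    rows = [o for o in objects if o['label'] == 'table row']
    if len(rows) < 2:
        return objects
    median_h = statistics.median(o['bbox'][3] - o['bbox'][1] for o in rows)
    return [
        o for o in objects
        if o['label'] != 'table row'
        or (o['bbox'][3] - o['bbox'][1]) >= min_ratio * median_h
    ]

# e.g. structure_objects = drop_thin_rows(structure_objects)
```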
@Oleksii94

@NielsRogge commented:
That's an interesting, weird issue. Could NMS (non-maximum suppression) help here?
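Following up on the NMS idea, here is a minimal sketch of per-class NMS over the objects returned by `outputs_to_objects` above, using `torchvision.ops.nms`. The `apply_class_nms` helper and the 0.1 IoU threshold are assumptions to tune, not anything TATR prescribes, and note that thin slivers that barely overlap a full row may still survive IoU-based suppression:

```python
import torch
from torchvision.ops import nms

def apply_class_nms(objects, iou_threshold=0.1):
    # Hypothetical helper: run NMS separately for each predicted class so
    # that, e.g., overlapping 'table row' boxes suppress one another.
    kept = []
    for label in {obj['label'] for obj in objects}:
        group = [obj for obj in objects if obj['label'] == label]
        boxes = torch.tensor([obj['bbox'] for obj in group], dtype=torch.float32)
        scores = torch.tensor([obj['score'] for obj in group], dtype=torch.float32)
        keep = nms(boxes, scores, iou_threshold)  # indices of boxes to keep
        kept.extend(group[i] for i in keep.tolist())
    return kept

# e.g. structure_objects = apply_class_nms(structure_objects)
```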
