From 6dd81bcc424c00204a608b4b1f88cbb4a8d5de1b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 17 Sep 2021 10:46:57 +0100 Subject: [PATCH] Better styling --- pytorch_lightning/callbacks/rich_model_summary.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index a2981829e04bc..0751bca92cf8f 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -75,18 +75,17 @@ def summarize( console = Console() - table = Table(title="Model Summary") - - table.add_column(" ") - table.add_column("Name", justify="left", style="cyan", no_wrap=True) - table.add_column("Type", style="magenta") - table.add_column("Params", justify="right", style="green") + table = Table(header_style="bold magenta") + table.add_column(" ", style="dim") + table.add_column("Name", justify="left", no_wrap=True) + table.add_column("Type") + table.add_column("Params", justify="right") column_names = list(zip(*summary_data))[0] for column_name in ["In sizes", "Out sizes"]: if column_name in column_names: - table.add_column(column_name, justify="right", style="green") + table.add_column(column_name, justify="right", style="white") rows = list(zip(*(arr[1] for arr in summary_data))) for row in rows: