Skip to content

Commit

Permalink
feat(CLI): Creating a new CLI to view checkpoint's info.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Dec 5, 2023
1 parent e00cf32 commit 4b756da
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 0 deletions.
167 changes: 167 additions & 0 deletions everyvoice/base_cli/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
CLI command to inspect EveryVoice's checkpoints.
"""
import json
import sys
import warnings
from enum import Enum
from json import JSONEncoder
from pathlib import Path
from typing import Any, Dict

import typer
import yaml
from pydantic import BaseModel
from typing_extensions import Annotated

from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.model import (
FastSpeech2,
)
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.model import HiFiGAN

# Typer application that hosts the checkpoint inspection command.
app = typer.Typer(
    help="Extract checkpoint's hyperparameters.",
    # Do not dump local variables in Typer's pretty tracebacks.
    pretty_exceptions_show_locals=False,
)


class ExportType(str, Enum):
    """
    Serialization formats in which the configuration can be exported.

    Each member is also a `str`, so it compares equal to its raw value
    (e.g. `ExportType.JSON == "json"`).
    """

    # JavaScript Object Notation output.
    JSON = "json"
    # YAML output.
    YAML = "yaml"


class CheckpointEncoder(JSONEncoder):
    """
    JSON encoder able to serialize the `torch.Tensor` and
    `pydantic.BaseModel` objects found in a checkpoint.
    """

    def default(self, obj: Any):
        """
        Convert objects that the stock JSON encoder cannot handle.

        A `torch.Tensor` is replaced by its shape (a list of ints), a
        `pydantic.BaseModel` by the dictionary obtained from its own JSON
        export.  Any other object is delegated to `JSONEncoder.default`.
        """
        # Imported lazily so that importing this module doesn't load torch.
        import torch

        # Only the shape is reported for tensors, not their values.
        if isinstance(obj, torch.Tensor):
            return [*obj.shape]

        # Round-trip through pydantic's JSON export to get plain python types.
        if isinstance(obj, BaseModel):
            serialized = obj.json()
            return json.loads(serialized)

        return super().default(obj)


def load_checkpoint(model_path: Path) -> Dict[str, Any]:
    """
    Loads a checkpoint on the CPU and strips the bulky entries that are of no
    use when inspecting it.

    The following entries are removed, when present:
      - each optimizer's `state` in `optimizer_states`;
      - the `params` of each of the optimizers' `param_groups`;
      - the top-level `state_dict`;
      - the top-level `loops`.

    Args:
        model_path: path to the checkpoint file.

    Returns:
        The loaded checkpoint without the entries listed above.
    """
    # Imported lazily so that importing this module doesn't load torch.
    import torch

    # NOTE(security): `torch.load` unpickles the file, which can execute
    # arbitrary code.  Only inspect checkpoints coming from a trusted source.
    checkpoint = torch.load(str(model_path), map_location=torch.device("cpu"))

    for optimizer in checkpoint.get("optimizer_states", []):
        # Delete the optimizer history values.
        optimizer.pop("state", None)
        # `params` only lists parameter indices, which carry no information
        # once the optimizer's state is gone.  An optimizer entry without
        # `param_groups` is tolerated instead of raising a `KeyError`.
        for param_group in optimizer.get("param_groups", []):
            param_group.pop("params", None)

    for useless_key in ("state_dict", "loops"):
        checkpoint.pop(useless_key, None)

    return checkpoint


def _load_model(model_path: Path):
    """
    Instantiates the model stored in the checkpoint `model_path`.

    Each supported model class is tried in turn and the first one that
    manages to load the checkpoint is returned.

    Raises:
        NotImplementedError: none of the supported model classes could load
            the checkpoint.  The last loading error is attached as the cause.
    """
    last_error = None
    # Warnings emitted while loading are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for model_class in (HiFiGAN, FastSpeech2):
            try:
                return model_class.load_from_checkpoint(model_path)
            # NOTE if ANY exception is raised, that means the model couldn't
            # be loaded and we want to try another model type.  This is to
            # "ask forgiveness, not permission".
            except Exception as error:
                last_error = error

    raise NotImplementedError(
        "Your checkpoint contains a model type that is not yet supported!"
    ) from last_error


@app.command()
def inspect(
    model_path: Path = typer.Argument(
        ...,
        exists=True,
        dir_okay=False,
        file_okay=True,
        help="The path to your model checkpoint file.",
    ),
    export_type: ExportType = ExportType.YAML,
    show_config: Annotated[
        bool,
        typer.Option(
            "--show-config/--no-show-config",  # noqa
            "-c/-C",  # noqa
            help="Show the configuration used during training in either json or yaml format",  # noqa
        ),
    ] = True,
    show_architecture: Annotated[
        bool,
        typer.Option(
            "--show-architecture/--no-show-architecture",  # noqa
            "-a/-A",  # noqa
            help="Show the model's architecture",  # noqa
        ),
    ] = True,
    show_weights: Annotated[
        bool,
        typer.Option(
            "--show-weights/--no-show-weights",  # noqa
            "-w/-W",  # noqa
            help="Show the number of weights per layer",  # noqa
        ),
    ] = True,
):
    """
    Given an EveryVoice checkpoint, show information about the configuration
    used during training, the model's architecture and the number of weights
    per layer and total weight count.
    """
    # NOTE: the docstring above is kept short and free of developer notes
    # because typer displays it as the command's help text.
    checkpoint = load_checkpoint(model_path)

    if show_config:
        print("Configs:")
        if export_type is ExportType.JSON:
            json.dump(
                checkpoint,
                sys.stdout,
                ensure_ascii=False,
                indent=2,
                cls=CheckpointEncoder,
            )
        elif export_type is ExportType.YAML:
            # Round-trip through json so that yaml only sees plain python
            # types: tensors and pydantic models are converted by
            # CheckpointEncoder.
            output = json.loads(json.dumps(checkpoint, cls=CheckpointEncoder))
            yaml.dump(output, stream=sys.stdout)
        else:
            raise NotImplementedError(f"Unsupported export type {export_type}!")

    # Both the architecture and the weights views need the instantiated
    # model.  It is loaded once, and only when at least one of them is
    # requested, so that `--no-show-architecture --show-weights` works instead
    # of failing on an undefined `model`.
    model = None
    if show_architecture or show_weights:
        model = _load_model(model_path)

    if show_architecture:
        print("\n\nModel Architecture:\n", model, sep="")

    if show_weights:
        # Imported lazily: only needed for this view.
        from torchinfo import summary

        # According to Aidan (1, 80, 50) should be a valid input size but it
        # looks like the model is expecting a Dict which isn't supported by
        # torchsummary, hence no input size is provided.
        statistics = summary(model, None, verbose=0)
        print("\nModel's Weights:\n", statistics)
7 changes: 7 additions & 0 deletions everyvoice/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from everyvoice.model.aligner.wav2vec2aligner.aligner.cli import (
align_single as ctc_segment,
)
from everyvoice.model.e2e.config import EveryVoiceConfig
from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
from everyvoice.base_cli.checkpoint import inspect as inspect_checkpoint
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.cli import (
preprocess as preprocess_fs2,
)
Expand Down Expand Up @@ -201,6 +204,10 @@ def new_project():
short_help="Synthesize using your pre-trained EveryVoice models",
)

# Expose `inspect_checkpoint` as the `inspect-checkpoint` sub-command of `app`.
app.command(
name="inspect-checkpoint",
short_help="Extract structural information from a checkpoint",
)(inspect_checkpoint)

class TestSuites(str, Enum):
all = "all"
Expand Down
6 changes: 6 additions & 0 deletions everyvoice/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def setUp(self) -> None:
"train",
"synthesize",
"preprocess",
"inspect-checkpoint",
]

def test_commands_present(self):
Expand All @@ -46,6 +47,11 @@ def test_update_schema(self):
)
)

def test_inspect_checkpoint(self):
    """The help of `inspect-checkpoint` must display the command's usage."""
    expected_usage = "inspect-checkpoint [OPTIONS] MODEL_PATH"
    cli_arguments = ["inspect-checkpoint", "--help"]
    result = self.runner.invoke(app, cli_arguments)
    self.assertIn(expected_usage, result.stdout)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ simple-term-menu==1.5.2
setuptools==59.5.0 # https://github.com/pytorch/pytorch/issues/69894
tabulate==0.8.10
tensorboard>=2.14.1
torchinfo==1.8.0
typer[all]>=0.9.0

0 comments on commit 4b756da

Please sign in to comment.