Binary file added .DS_Store
2 changes: 0 additions & 2 deletions conf/dataset/your_dataset.yml

This file was deleted.

25 changes: 0 additions & 25 deletions conf/model/HarmonicCNN.yml

This file was deleted.

12 changes: 12 additions & 0 deletions conf/mtat/model/HarmonicCNN.yaml
@@ -0,0 +1,12 @@
version: harmoniccnn
type: HarmonicCNN
params:
  # CNN Parameters
  n_channels: 128
  sample_rate: 16000
  n_fft: 513
  n_mels: 128
  n_class: 50
  n_harmonic: 6
  semitone_scale: 2
  learn_bw: only_Q
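For reference, a minimal sketch of how this model config is consumed, mirroring the pattern in evaluate.py below; the hard-coded path is the only assumption:

```python
# Minimal sketch: load the model config and unpack its params into the model,
# as evaluate.py does via get_config()/main().
from omegaconf import OmegaConf

from src.model.net import HarmonicCNN

model_config = OmegaConf.load("conf/mtat/model/HarmonicCNN.yaml")
model = HarmonicCNN(**model_config.params)  # params map 1:1 onto __init__
```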
11 changes: 11 additions & 0 deletions conf/mtat/pipeline/pv00.yaml
@@ -0,0 +1,11 @@
version: pv00
type: DataPipeline
dataset:
type: MTATDataset
path: ../dataset/mtat
input_length: 80000
dataloader:
type: DataLoader
params:
batch_size: 16
num_workers: 8
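A short sketch of how this pipeline config feeds the DataPipeline module defined in src/task/pipeline.py below; the config path is an assumption:

```python
# Sketch: build the LightningDataModule from this pipeline config.
from omegaconf import OmegaConf

from src.task.pipeline import DataPipeline

pipeline_config = OmegaConf.load("conf/mtat/pipeline/pv00.yaml")
pipeline = DataPipeline(pipeline_config=pipeline_config)
pipeline.setup("fit")                        # builds TRAIN and VALID datasets
train_loader = pipeline.train_dataloader()   # batch_size=16, num_workers=8
```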
19 changes: 19 additions & 0 deletions conf/mtat/runner/rv00.yaml
@@ -0,0 +1,19 @@
version: rv00
type: AutotaggingRunner
optimizer:
type: Adam
params:
learning_rate: 1e-5
scale_factor: 5
scheduler:
type: ExponentialLR
params:
gamma: 0.95
trainer:
type: Trainer
params:
max_epochs: 100
gpus: 1
distributed_backend: dp # train.py: ddp, evaluate.py: dp
benchmark: False
deterministic: True
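The AutotaggingRunner that consumes this file appears in evaluate.py below, but its source (src/task/runner.py) is not part of this diff. The following is therefore only an assumed sketch of how the optimizer and scheduler sections above could map onto a LightningModule's configure_optimizers; it is not the actual implementation:

```python
# Hedged sketch (assumption, not the real AutotaggingRunner code): wire
# rv00.yaml's optimizer/scheduler params into torch.optim objects.
import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(
        self.parameters(), lr=self.hparams.learning_rate  # 1e-5 in rv00.yaml
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=self.hparams.gamma  # 0.95 in rv00.yaml
    )
    return [optimizer], [scheduler]
```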
92 changes: 88 additions & 4 deletions evaluate.py
@@ -1,4 +1,88 @@
"""
This script was made by Nick at 19/07/20.
To implement code for evaluating your model.
"""
from argparse import ArgumentParser, Namespace
import json
from pathlib import Path

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer, seed_everything
import torch

from src.model.net import HarmonicCNN
from src.task.pipeline import DataPipeline
from src.task.runner import AutotaggingRunner


def get_config(args: Namespace) -> DictConfig:
parent_config_dir = Path("conf")
child_config_dir = parent_config_dir / args.dataset
model_config_dir = child_config_dir / "model"
pipeline_config_dir = child_config_dir / "pipeline"
runner_config_dir = child_config_dir / "runner"

config = OmegaConf.create()
model_config = OmegaConf.load(model_config_dir / f"{args.model}.yaml")
pipeline_config = OmegaConf.load(pipeline_config_dir / f"{args.pipeline}.yaml")
runner_config = OmegaConf.load(runner_config_dir / f"{args.runner}.yaml")
config.update(model=model_config, pipeline=pipeline_config, runner=runner_config)
return config

def main(args) -> None:
seed_everything(42)
config = get_config(args)

# prepare dataloader
    pipeline = DataPipeline(pipeline_config=config.pipeline)

dataset = pipeline.get_dataset(
pipeline.dataset_builder,
config.pipeline.dataset.path,
args.type,
config.pipeline.dataset.input_length
)
dataloader = pipeline.get_dataloader(
dataset,
shuffle=False,
drop_last=True,
**pipeline.pipeline_config.dataloader.params,
)
model = HarmonicCNN(**config.model.params)
runner = AutotaggingRunner(model, config.runner)

checkpoint_path = (
f"exp/{args.dataset}/{args.model}/{args.runner}/{args.checkpoint}.ckpt"
)
state_dict = torch.load(checkpoint_path)
runner.load_state_dict(state_dict.get("state_dict"))

trainer = Trainer(
**config.runner.trainer.params, logger=False, checkpoint_callback=False
)
results_path = Path(f"exp/{args.dataset}/{args.model}/{args.runner}/results.json")

if results_path.exists():
with open(results_path, mode="r") as io:
results = json.load(io)

result = trainer.test(runner, test_dataloaders=dataloader)
results.update({"checkpoint": args.checkpoint, f"{args.type}": result})

else:
results = {}
result = trainer.test(runner, test_dataloaders=dataloader)
results.update({"checkpoint": args.checkpoint, f"{args.type}": result})
Comment on lines +61 to +71
Member
Could you explain the intent of this construct?
with open(
f"exp/{args.dataset}/{args.model}/{args.runner}/results.json", mode="w"
) as io:
json.dump(results, io, indent=4)

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--type", default="TEST", type=str, choices=["TRAIN", "VALID", "TEST"])
parser.add_argument("--model", default="HarmonicCNN", type=str)
parser.add_argument("--dataset", default="mtat", type=str, choices=["mtat"])
parser.add_argument("--pipeline", default="pv00", type=str)
parser.add_argument("--runner", default="rv00", type=str)
parser.add_argument("--reproduce", default=False, action="store_true")
parser.add_argument("--checkpoint", default="epoch=37-roc_auc=0.8806-pr_auc=0.3905", type=str)
args = parser.parse_args()
main(args)
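Usage note: with the argparse defaults above, running `python evaluate.py --type TEST` loads exp/mtat/HarmonicCNN/rv00/epoch=37-roc_auc=0.8806-pr_auc=0.3905.ckpt and writes the resulting metrics into exp/mtat/HarmonicCNN/rv00/results.json (shown below).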
8 changes: 8 additions & 0 deletions exp/mtat/HarmonicCNN/rv00/hparams.yaml
@@ -0,0 +1,8 @@
benchmark: false
deterministic: true
distributed_backend: dp
gamma: 0.95
gpus: 1
learning_rate: 1.0e-05
max_epochs: 100
scale_factor: 5
10 changes: 10 additions & 0 deletions exp/mtat/HarmonicCNN/rv00/results.json
@@ -0,0 +1,10 @@
{
"checkpoint": "epoch=37-roc_auc=0.8806-pr_auc=0.3905",
"TEST": [
{
"val_loss": 0.15560948848724365,
"roc_auc": 0.8677473068237305,
"pr_auc": 0.3685624301433563
}
]
}
51 changes: 0 additions & 51 deletions hparams.py

This file was deleted.

13 changes: 1 addition & 12 deletions src/data.py
@@ -55,15 +55,4 @@ def __getitem__(self, index):
return audio_tensor.to(dtype=torch.float32), tag_binary.astype("float32")

def __len__(self):
return len(self.fl)


def get_audio_loader(root, batch_size, input_length, split="TRAIN", num_workers=0):
data_loader = data.DataLoader(
dataset=MTATDataset(root, split=split, input_length=input_length),
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers,
)
return data_loader
return len(self.fl)
18 changes: 14 additions & 4 deletions src/metric.py
@@ -1,4 +1,14 @@
"""
This script was made by Nick at 19/07/20.
To implement code for metric (e.g. NLL loss).
"""
import torch.nn as nn
from pytorch_lightning.metrics.sklearns import AUROC, AveragePrecision

roc_auc = AUROC(average='macro')
average_precision = AveragePrecision(average='macro')

def get_auc(y_score, y_true):
    # skip during the validation sanity check, where only a single batch is seen
    if y_true.shape[0] == 1:
        return 0, 0
    else:
        roc_aucs = roc_auc(y_score.flatten(0, 1), y_true.flatten(0, 1))
        pr_aucs = average_precision(y_score.flatten(0, 1), y_true.flatten(0, 1))
        return roc_aucs, pr_aucs
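The flatten(0, 1) calls imply get_auc expects stacked batch outputs of shape (num_batches, batch_size, n_class); a small usage sketch, with shapes assumed from that convention:

```python
# Usage sketch: scores and multi-hot labels stacked as (num_batches, batch, n_class).
import torch

from src.metric import get_auc

y_score = torch.rand(4, 16, 50)                 # e.g. 4 batches, 16 clips, 50 tags
y_true = (torch.rand(4, 16, 50) > 0.5).float()  # multi-hot ground truth
roc_auc_value, pr_auc_value = get_auc(y_score, y_true)
```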
19 changes: 9 additions & 10 deletions src/model/net.py
@@ -13,18 +13,17 @@ class HarmonicCNN(nn.Module):
Won et al. 2020
Data-driven harmonic filters for audio representation learning.
Trainable harmonic band-pass filters.
https://github.com/minzwon/sota-music-tagging-models
"""
def __init__(self,
n_channels=128,
sample_rate=16000,
n_fft=512,
f_min=0.0,
f_max=8000.0,
n_mels=128,
n_class=50,
n_harmonic=6,
semitone_scale=2,
learn_bw='only_Q'):
n_channels: int,
sample_rate: int,
n_fft: int,
n_mels: int,
n_class: int,
n_harmonic: int,
semitone_scale: int,
learn_bw: str):
"""Instantiating HarmonicCNN class
Args:
n_channels(int) : number of channels
69 changes: 69 additions & 0 deletions src/task/pipeline.py
@@ -0,0 +1,69 @@
from typing import Callable, Optional

from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from ..data import MTATDataset

class DataPipeline(LightningDataModule):
    def __init__(self, pipeline_config: DictConfig) -> None:
        super(DataPipeline, self).__init__()
        self.pipeline_config = pipeline_config
self.dataset_builder = MTATDataset

def setup(self, stage: Optional[str] = None):
if stage == "fit" or stage is None:
self.train_dataset = DataPipeline.get_dataset(
self.dataset_builder,
self.pipeline_config.dataset.path,
"TRAIN",
self.pipeline_config.dataset.input_length
)

self.val_dataset = DataPipeline.get_dataset(self.dataset_builder,
self.pipeline_config.dataset.path,
"VALID",
self.pipeline_config.dataset.input_length)

if stage == "test" or stage is None:
self.test_dataset = DataPipeline.get_dataset(self.dataset_builder,
self.pipeline_config.dataset.path,
"TEST",
self.pipeline_config.dataset.input_length)

def train_dataloader(self) -> DataLoader:
return DataPipeline.get_dataloader(self.train_dataset,
batch_size=self.pipeline_config.dataloader.params.batch_size,
num_workers=self.pipeline_config.dataloader.params.num_workers,
drop_last=True,
shuffle=True)

def val_dataloader(self) -> DataLoader:
return DataPipeline.get_dataloader(self.val_dataset,
batch_size=self.pipeline_config.dataloader.params.batch_size,
num_workers=self.pipeline_config.dataloader.params.num_workers,
drop_last=True,
shuffle=False)

def test_dataloader(self) -> DataLoader:
return DataPipeline.get_dataloader(self.test_dataset,
batch_size=self.pipeline_config.dataloader.params.batch_size,
num_workers=self.pipeline_config.dataloader.params.num_workers,
drop_last=True,
shuffle=False)

@classmethod
    def get_dataset(cls, dataset_builder: Callable, root, split, length) -> Dataset:
dataset = dataset_builder(root, split, length)
return dataset

@classmethod
def get_dataloader(cls, dataset: Dataset, batch_size: int, num_workers: int, shuffle: bool, drop_last: bool,
**kwargs) -> DataLoader:
return DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=shuffle,
drop_last=drop_last,
**kwargs)
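train.py is referenced in the runner config comment but is not included in this diff; as a hedged sketch, the module above would plug into training roughly like this:

```python
# Hedged sketch of a training entry point (train.py itself is not in this diff).
from omegaconf import OmegaConf
from pytorch_lightning import Trainer

from src.model.net import HarmonicCNN
from src.task.pipeline import DataPipeline
from src.task.runner import AutotaggingRunner

model_config = OmegaConf.load("conf/mtat/model/HarmonicCNN.yaml")
pipeline_config = OmegaConf.load("conf/mtat/pipeline/pv00.yaml")
runner_config = OmegaConf.load("conf/mtat/runner/rv00.yaml")

pipeline = DataPipeline(pipeline_config=pipeline_config)
model = HarmonicCNN(**model_config.params)
runner = AutotaggingRunner(model, runner_config)
Trainer(**runner_config.trainer.params).fit(runner, datamodule=pipeline)
```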