From cdbba08274395afe2690f51de5c6628153934193 Mon Sep 17 00:00:00 2001 From: Niels Date: Sun, 8 Sep 2024 14:44:16 +0200 Subject: [PATCH 1/2] First draft --- model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/model.py b/model.py index 4c0d2e8..79e01b1 100644 --- a/model.py +++ b/model.py @@ -3,6 +3,8 @@ import torch from torch import Tensor, nn +from huggingface_hub import PyTorchModelHubMixin + from modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding) @@ -24,7 +26,7 @@ class FluxParams: guidance_embed: bool -class Flux(nn.Module): +class Flux(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/feizc/FluxMusic", pipeline_tag="text-to-audio", tags=["text-to-music"], license="apache-2.0"): """ Transformer model for flow matching on sequences. """ From 4f4d146d3ea5e78033111a782897ed357582b0a4 Mon Sep 17 00:00:00 2001 From: Niels Date: Sun, 8 Sep 2024 14:48:36 +0200 Subject: [PATCH 2/2] Update sample --- sample.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sample.py b/sample.py index ab6ad40..61de365 100644 --- a/sample.py +++ b/sample.py @@ -53,11 +53,9 @@ def main(args): latent_size = (256, 16) - model = build_model(args.version).to(device) - local_path = args.ckpt_path - state_dict = torch.load(local_path, map_location=lambda storage, loc: storage) - model.load_state_dict(state_dict['ema']) - model.eval() # important! + repo_id = f"feizhengcong/FluxMusic-{args.version}" + model.from_pretrained(repo_id) + model.to(device) diffusion = RF() # Setup VAE @@ -112,7 +110,6 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument("--version", type=str, default="small") parser.add_argument("--prompt_file", type=str, default='config/example.txt') - parser.add_argument("--ckpt_path", type=str, default='musicflow_s.pt') parser.add_argument("--audioldm2_model_path", type=str, default='/maindata/data/shared/multimodal/public/dataset_music/audioldm2' ) parser.add_argument("--seed", type=int, default=2024) args = parser.parse_args()