Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@
AudioSpecialTokens = _mod2.AudioSpecialTokens


class SyntheticAcousticTransformerArgs:
"""Mimics AcousticTransformerArgs interface."""

def __init__(self):
self.n_decoding_steps = 7


class SyntheticModelArgs:
"""Mimics MultimodalAudioModelArgs interface."""

Expand All @@ -96,6 +103,7 @@ class SyntheticAcousticTransformer(nn.Module):
def __init__(self):
super().__init__()
self.model_args = SyntheticModelArgs()
self.acoustic_transformer_args = SyntheticAcousticTransformerArgs()
self.acoustic_embeddings_levels = ACOUSTIC_EMBEDDINGS_LEVELS

# semantic_codebook_output: hidden_dim -> padded_codebook_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ def _remap_mistral_audio_args(self, config_dict: dict) -> dict:
audio_tokenizer_args = config_dict["multimodal"].pop("audio_tokenizer_args", None)
audio_config = {}
if encoder_args is not None:
# Default n_decoding_steps if not provided
acoustic_args = encoder_args.get("acoustic_transformer_args", {})
if acoustic_args.get("n_decoding_steps") is None:
logger.warning(
"n_decoding_steps not provided in acoustic_transformer_args, defaulting to 7. "
"Please add 'n_decoding_steps' to params.json under acoustic_transformer_args."
)
acoustic_args["n_decoding_steps"] = 7

audio_config = {
"sampling_rate": encoder_args["audio_encoding_args"]["sampling_rate"],
"codec_args": audio_tokenizer_args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.acoustic_embeddings_levels = self.acoustic_transformer.acoustic_embeddings_levels

self.cfg_alpha = 1.2
self.n_steps = 8
self.n_steps = self.acoustic_transformer.acoustic_transformer_args.n_decoding_steps

# Graph storage
self.graphs: dict[int, CUDAGraph] = {}
Expand All @@ -72,7 +72,7 @@ def _warmup_and_capture(self, device: torch.device, dtype: torch.dtype, hidden_d
)

# Pre-create persistent buffers
self.timesteps = torch.linspace(0, 1, self.n_steps, device=device, dtype=dtype)
self.timesteps = torch.linspace(0, 1, self.n_steps + 1, device=device, dtype=dtype)
self.fake_eos_one = torch.tensor(1.0, dtype=dtype, device=device)
self.fake_eos_zero = torch.tensor(0.0, dtype=dtype, device=device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class AcousticTransformerArgs:
use_biases: bool = False
norm_eps: float = 1e-5
sigma: float = 1e-5 # was 0.01 in beta version
n_decoding_steps: int | None = None # Number of Euler ODE steps for flow matching


@dataclass
Expand Down Expand Up @@ -436,14 +437,13 @@ def __init__(
self._empty_audio_token_id = AudioSpecialTokens.id(AudioSpecialTokens.empty_audio)

# Flow matching constants
# TODO(chenyo): hardcoded, need to fix
self._acoustic_decode_iters = 8
self._n_steps = args.n_decoding_steps
# TODO(chenyo): hardcoded, need to fix
self._cfg_alpha = 1.2
self._noise_scale = 1.0
self.register_buffer(
"_timesteps",
torch.linspace(0, 1, self._acoustic_decode_iters),
torch.linspace(0, 1, self._n_steps + 1),
persistent=False,
)

Expand Down
Loading