"""
HuggingFace Space Demo for TextSyncMimi
Speech Editing with Token-Level Embedding Swapping

This demo loads the model from HuggingFace Hub and allows:
- Generating speech with different voices using OpenAI TTS
- Swapping speech embeddings at specific token positions
- Real-time speech editing

Prerequisites:
- Set OPENAI_API_KEY in Space secrets
- Model will be loaded from HuggingFace Hub
"""

import os
import json
import tempfile
import argparse
from typing import List, Tuple, Optional
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import soundfile as sf
import gradio as gr
from openai import OpenAI
from transformers import (
    AutoModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    MimiModel,
)
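

# `spaces` is only available on HuggingFace Spaces (ZeroGPU hardware). When
# running locally it is absent, so we fall back to a no-op stub that keeps
# the @spaces.GPU decorator below valid.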
try:
    import spaces
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    class spaces:
        @staticmethod
        def GPU(func):
            return func
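

# Audio/codec settings: Mimi operates on 24 kHz audio at 12.5 latent frames
# per second, so MAX_Z_TOKENS = 50 caps each text token's speech segment at
# roughly 4 seconds. END_TOKEN_THRESHOLD is the sigmoid cutoff for the
# end-of-segment classifier used during autoregressive generation.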
SAMPLE_RATE = 24000
FRAME_RATE = 12.5
TTS_VOICES = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer", "verse"]
MAX_Z_TOKENS = 50
END_TOKEN_THRESHOLD = 0.5

# Global model state, populated once by initialize_models()
model = None
mimi_model = None
tokenizer = None
feature_extractor = None
device = None
openai_client = None


def load_audio_to_inputs(feature_extractor, audio_path: str, sample_rate: int) -> torch.Tensor:
    """Load an audio file and convert it to feature-extractor inputs."""
    import librosa  # local import: librosa is only needed for loading audio

    audio, sr = librosa.load(audio_path, sr=sample_rate, mono=True)
    audio_inputs = feature_extractor(raw_audio=audio, return_tensors="pt", sampling_rate=sample_rate)
    return audio_inputs.input_values


def initialize_models(model_id: str, tokenizer_id: str = "meta-llama/Llama-3.1-8B-Instruct", hf_token: Optional[str] = None):
    """Initialize all models from HuggingFace Hub."""
    global model, mimi_model, tokenizer, feature_extractor, device, openai_client

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading TextSyncMimi model from {model_id}...")
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()

    # Fall back to the public Mimi checkpoint if the config does not pin one
    mimi_model_id = getattr(model.config, "mimi_model_id", "kyutai/mimi")

    print("Loading Mimi model...")
    mimi_model = MimiModel.from_pretrained(mimi_model_id, token=hf_token)
    mimi_model.to(device)
    mimi_model.eval()

    print(f"Loading tokenizer from {tokenizer_id}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading feature extractor...")
    feature_extractor = AutoFeatureExtractor.from_pretrained(mimi_model_id, token=hf_token)

    print("Initializing OpenAI client...")
    openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    print("All models loaded successfully!")
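

# Pipeline overview (as implemented below):
#   1. Tokenize the transcript and embed it (model.text_token_embedding).
#   2. Encode each TTS waveform into Mimi speech representations.
#   3. Cross-attend text onto speech to get one speech embedding per text token.
#   4. Autoregressively generate Mimi latents ("z tokens") per text token and
#      decode them back to a waveform.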
@torch.no_grad()
def compute_cross_attention_s(
    model,
    text_embeddings: torch.Tensor,
    input_values: torch.Tensor,
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute projected text embeddings and cross-attended speech embeddings."""
    audio_attention_mask = torch.ones(1, input_values.shape[-1], dtype=torch.bool, device=device)
    text_attention_mask = torch.ones(1, text_embeddings.shape[1], dtype=torch.bool, device=device)

    # Encode audio into speech representations of shape (batch, seq, dim)
    speech_embeddings = model.encode_audio_to_representation(
        input_values.to(device),
        audio_attention_mask=audio_attention_mask,
    ).transpose(1, 2)

    # Project text embeddings into the shared space
    text_proj = model.text_proj(text_embeddings.to(device))

    # Additive causal mask over text positions: 0 where attention is allowed,
    # -inf where it is not
    batch_size, text_seq_len = text_proj.shape[:2]
    causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_proj.dtype))
    causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
    pad_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
    formatted_text_attention_mask = torch.where((causal_mask * pad_mask).bool(), 0.0, float("-inf"))

    # Speech side is fully visible (no padding in this single-utterance demo)
    speech_seq_len = speech_embeddings.shape[1]
    speech_mask = torch.ones(batch_size, speech_seq_len, dtype=torch.bool, device=device)
    formatted_speech_attention_mask = torch.where(
        speech_mask.view(batch_size, 1, 1, speech_seq_len), 0.0, float("-inf")
    )

    # Text queries attend over speech keys/values, yielding one
    # speech-informed embedding per text token
    cross_out = model.cross_attention_transformer(
        hidden_states=text_proj,
        encoder_hidden_states=speech_embeddings,
        attention_mask=formatted_text_attention_mask,
        encoder_attention_mask=formatted_speech_attention_mask,
        alignment_chunk_sizes=None,
    ).last_hidden_state

    return text_proj, cross_out, text_attention_mask
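

# The AR transformer consumes one flat sequence of embeddings per utterance:
#   [latent, t_0, s_0, <start>, z..., <end>, t_1, s_1, <start>, z..., <end>, ...]
# where t_i / s_i are the projected text / cross-attended speech embeddings
# for token i, and z tokens are emitted until the end-token classifier fires
# (or MAX_Z_TOKENS is reached).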
@torch.no_grad()
def ar_generate_and_decode(
    model,
    mimi_model,
    text_proj: torch.Tensor,
    s_tokens: torch.Tensor,
    text_attention_mask: torch.Tensor,
    max_z_tokens: int,
    end_token_threshold: float,
    device: str,
) -> np.ndarray:
    """Generate audio autoregressively and decode to waveform."""
    batch_size, text_seq_len = text_proj.shape[:2]

    # Learned marker embeddings, looked up at index 0
    text_speech_latent_emb = model.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
    time_speech_start_emb = model.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
    time_speech_end_emb = model.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))

    generated_z_tokens: List[torch.Tensor] = []

    for b in range(batch_size):
        if text_attention_mask is not None:
            valid_text_len = int(text_attention_mask[b].sum().item())
        else:
            valid_text_len = text_seq_len

        sequence: List[torch.Tensor] = [text_speech_latent_emb]

        for i in range(valid_text_len):
            t_i = text_proj[b, i:i+1]
            s_i = s_tokens[b, i:i+1]

            sequence.extend([t_i, s_i])
            sequence.append(time_speech_start_emb)

            # Generate z tokens for this text token until the end classifier fires
            z_count = 0
            while z_count < max_z_tokens:
                current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)
                ar_attention_mask = torch.ones(1, current_sequence.shape[1], dtype=torch.bool, device=device)

                ar_outputs = model.ar_transformer(
                    hidden_states=current_sequence,
                    attention_mask=ar_attention_mask,
                )
                last_prediction = ar_outputs.last_hidden_state[0, -1:, :]

                end_token_logit = model.end_token_classifier(last_prediction).squeeze(-1)
                end_token_prob = torch.sigmoid(end_token_logit).item()

                if end_token_prob >= end_token_threshold:
                    break
                sequence.append(last_prediction)
                generated_z_tokens.append(last_prediction.squeeze(0))
                z_count += 1

            sequence.append(time_speech_end_emb)

    # Decode the accumulated z tokens with Mimi's decoder stack
    if len(generated_z_tokens) == 0:
        audio_tensor = torch.zeros(1, 1, 1000, device=device)  # fallback: short silence
    else:
        z_tokens_batch = torch.stack(generated_z_tokens, dim=0).unsqueeze(0)  # (1, T, D)
        embeddings_bct = z_tokens_batch.transpose(1, 2)                        # (1, D, T)
        embeddings_upsampled = mimi_model.upsample(embeddings_bct)
        decoder_outputs = mimi_model.decoder_transformer(embeddings_upsampled.transpose(1, 2), return_dict=True)
        embeddings_after_dec = decoder_outputs.last_hidden_state.transpose(1, 2)
        audio_tensor = mimi_model.decoder(embeddings_after_dec)

    audio_numpy = audio_tensor.squeeze().detach().cpu().numpy()
    if np.isnan(audio_numpy).any() or np.isinf(audio_numpy).any():
        audio_numpy = np.nan_to_num(audio_numpy)
    if audio_numpy.ndim > 1:
        audio_numpy = audio_numpy.flatten()
    return audio_numpy
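

# OpenAI's gpt-4o-mini-tts model accepts an `instructions` parameter for
# style control; the older tts-1 model does not, hence the branch below.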
def generate_tts_audio(text: str, voice: str, instructions: Optional[str] = None) -> str:
    """Generate TTS audio using OpenAI and return the file path."""
    if not openai_client:
        raise RuntimeError("OpenAI client not initialized")

    if instructions and instructions.strip():
        response = openai_client.audio.speech.create(
            model="gpt-4o-mini-tts",
            voice=voice,
            input=text,
            instructions=instructions.strip(),
        )
    else:
        response = openai_client.audio.speech.create(
            model="tts-1",
            voice=voice,
            input=text,
        )

    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
        response.stream_to_file(temp_file.name)
        return temp_file.name


@spaces.GPU
def process_inputs(transcript_text: str, voice1: str, voice2: str, instructions1: str = "", instructions2: str = ""):
    """Process inputs and generate audio."""
    if not all([model, mimi_model, tokenizer, feature_extractor, openai_client]):
        return "Please initialize models first!", None, None, None, None, None, None, None

    if not transcript_text.strip():
        return "Please provide a transcript!", None, None, None, None, None, None, None

    if not voice1 or not voice2:
        return "Please select voices for both audio samples!", None, None, None, None, None, None, None

    # Tokenize the transcript
    tokens = tokenizer(transcript_text.strip(), return_tensors="pt", add_special_tokens=False)
    text_token_ids_cpu = tokens.input_ids.squeeze(0).tolist()
    text_token_strs = tokenizer.convert_ids_to_tokens(text_token_ids_cpu)
    text_token_ids = tokens.input_ids.to(device)

    token_display = ""
    for i, tok in enumerate(text_token_strs):
        token_display += f"Token {i}: {tok}\n"

    # Synthesize both voice renditions of the same transcript
    print(f"Generating TTS audio with voice '{voice1}'...")
    audio1_path = generate_tts_audio(transcript_text.strip(), voice1, instructions1)
    print(f"Generating TTS audio with voice '{voice2}'...")
    audio2_path = generate_tts_audio(transcript_text.strip(), voice2, instructions2)

    input_values_utt1 = load_audio_to_inputs(feature_extractor, audio1_path, SAMPLE_RATE)
    input_values_utt2 = load_audio_to_inputs(feature_extractor, audio2_path, SAMPLE_RATE)

    with torch.no_grad():
        text_embeddings = model.text_token_embedding(text_token_ids)

    # Per-token speech embeddings for each rendition
    t1_proj, s1_cross, text_attention_mask = compute_cross_attention_s(
        model, text_embeddings, input_values_utt1, device
    )
    _, s2_cross, _ = compute_cross_attention_s(
        model, text_embeddings, input_values_utt2, device
    )

    # Baseline: resynthesize utterance 1 from its own embeddings
    baseline_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s1_cross,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, baseline_audio, SAMPLE_RATE)
        baseline_path = f.name

    # Embeddings are serialized to JSON so they can be carried in gr.State
    # between the two GPU calls
    return (
        "Processing completed successfully!",
        token_display,
        audio1_path,
        audio2_path,
        baseline_path,
        json.dumps({
            "t1_proj": t1_proj.cpu().numpy().tolist(),
            "s1_cross": s1_cross.cpu().numpy().tolist(),
            "s2_cross": s2_cross.cpu().numpy().tolist(),
            "text_attention_mask": text_attention_mask.cpu().numpy().tolist(),
            "num_tokens": len(text_token_strs),
        }),
        audio1_path,
        audio2_path,
    )
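

# Swapping the cross-attended speech embedding at token index i makes the AR
# decoder realize that token with voice 2's delivery while the surrounding
# tokens keep voice 1's.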
@spaces.GPU
def swap_embeddings(embeddings_json: str, swap_indices: str):
    """Perform embedding swap at specified token indices."""
    if not embeddings_json:
        return "Please process inputs first!", None

    if not swap_indices.strip():
        return "Please specify token indices to swap (e.g., 0,2,5)!", None

    # Restore tensors from the JSON-serialized state
    embeddings_data = json.loads(embeddings_json)
    t1_proj = torch.tensor(embeddings_data["t1_proj"]).to(device)
    s1_cross = torch.tensor(embeddings_data["s1_cross"]).to(device)
    s2_cross = torch.tensor(embeddings_data["s2_cross"]).to(device)
    text_attention_mask = torch.tensor(embeddings_data["text_attention_mask"]).to(device)
    num_tokens = embeddings_data["num_tokens"]

    # Parse comma-separated indices, keeping only in-range non-negative ints
    parts = [p.strip() for p in swap_indices.split(",")]
    parsed = [int(p) for p in parts if p.isdigit()]

    if len(parsed) == 0:
        return "No valid indices provided! Use format: 0,2,5", None

    valid_indices = [i for i in parsed if 0 <= i < num_tokens]
    if len(valid_indices) == 0:
        return f"All indices out of range! Valid range: 0-{num_tokens-1}", None

    # Clone so the stored voice-1 embeddings stay untouched
    s_swapped = s1_cross.clone()
    for idx in valid_indices:
        s_swapped[:, idx:idx+1, :] = s2_cross[:, idx:idx+1, :]

    swapped_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s_swapped,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, swapped_audio, SAMPLE_RATE)
        swapped_path = f.name

    return f"Successfully swapped embeddings at token indices: {valid_indices}", swapped_path
def create_gradio_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="TextSyncMimi Demo") as interface:
        gr.Markdown("# TextSyncMimi - Standalone Demo")
        gr.Markdown("Generate two voice renditions using OpenAI TTS, then swap speech embeddings at token positions.")
        gr.Markdown("**This demo uses only the self-contained TextSyncMimi-v1 model code.**")

        with gr.Accordion("Style Instruction Examples", open=False):
            gr.Markdown("""
            **Example Instructions:**
            - *Emotional:* "Speak with excitement and joy", "Sound sad and melancholy"
            - *Pace:* "Speak slowly and deliberately", "Talk quickly and energetically"
            - *Character:* "Sound like a wise professor", "Speak like an excited child"
            """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Text-to-Speech Configuration")
                transcript_text = gr.Textbox(
                    label="Transcript Text",
                    placeholder="Enter text to synthesize...",
                    lines=3,
                )
                with gr.Row():
                    voice1 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 1",
                        value="alloy",
                    )
                    voice2 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 2",
                        value="echo",
                    )
                instructions1 = gr.Textbox(
                    label="Style Instructions for Voice 1",
                    placeholder="e.g., Speak slowly and calmly",
                    lines=2,
                )
                instructions2 = gr.Textbox(
                    label="Style Instructions for Voice 2",
                    placeholder="e.g., Speak quickly with excitement",
                    lines=2,
                )
                process_btn = gr.Button("Generate & Process", variant="primary")
                process_status = gr.Textbox(label="Status", interactive=False)

            with gr.Column():
                gr.Markdown("## Tokenization")
                tokens_display = gr.Textbox(
                    label="Tokens",
                    lines=16,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Generated TTS Audio")
                generated_audio1 = gr.Audio(label="Generated Audio 1")
                generated_audio2 = gr.Audio(label="Generated Audio 2")

            with gr.Column():
                gr.Markdown("## Model Output")
                baseline_audio = gr.Audio(label="Baseline Reconstruction")

                gr.Markdown("### Embedding Swap")
                swap_indices_input = gr.Textbox(
                    label="Token Indices to Swap",
                    placeholder="e.g., 0,2,5",
                )
                swap_btn = gr.Button("Perform Swap")
                swap_status = gr.Textbox(label="Swap Status", interactive=False)
                swapped_audio = gr.Audio(label="Swapped Result")

        # Hidden state carried between the two GPU calls
        embeddings_state = gr.State()
        audio1_state = gr.State()
        audio2_state = gr.State()

        process_btn.click(
            fn=process_inputs,
            inputs=[transcript_text, voice1, voice2, instructions1, instructions2],
            outputs=[process_status, tokens_display, generated_audio1, generated_audio2,
                     baseline_audio, embeddings_state, audio1_state, audio2_state],
        )

        swap_btn.click(
            fn=swap_embeddings,
            inputs=[embeddings_state, swap_indices_input],
            outputs=[swap_status, swapped_audio],
        )

    return interface


def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="HuggingFace Space Demo for TextSyncMimi")
    parser.add_argument(
        "--model_id",
        type=str,
        default="potsawee/TextSyncMimi-v1",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="HuggingFace tokenizer ID",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face token (or set HF_TOKEN env var)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port for Gradio app",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    args = parser.parse_args()

    if not os.getenv("OPENAI_API_KEY"):
        print("Error: OPENAI_API_KEY environment variable is required!")
        print("Set it: export OPENAI_API_KEY=your_key_here")
        return

    hf_token = args.hf_token or os.getenv("HF_TOKEN")

    print(f"Initializing TextSyncMimi from HuggingFace Hub: {args.model_id}...")
    initialize_models(args.model_id, args.tokenizer_id, hf_token)
    print("Launching Gradio interface...")

    interface = create_gradio_interface()
    interface.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()