Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| ACE-Step 1.5 Custom Edition - Main Application | |
| A comprehensive music generation system with three main interfaces: | |
| 1. Standard ACE-Step GUI | |
| 2. Custom Timeline-based Workflow | |
| 3. LoRA Training Studio | |
| """ | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from pathlib import Path | |
| import json | |
| from typing import Optional, List, Tuple | |
| import spaces | |
| from src.ace_step_engine import ACEStepEngine | |
| from src.timeline_manager import TimelineManager | |
| from src.lora_trainer import LoRATrainer | |
| from src.audio_processor import AudioProcessor | |
| from src.utils import setup_logging, load_config | |
# Setup
logger = setup_logging()  # shared module-level logger used by every handler below
config = load_config()    # global configuration passed to all engine components

# Lazy initialize components (will be initialized on first use)
# NOTE: kept as module globals and populated by the get_*() accessors below so
# heavy model loading is deferred until a tab's feature is actually used.
ace_engine = None
timeline_manager = None
lora_trainer = None
audio_processor = None
def get_ace_engine():
    """Return the process-wide ACEStepEngine, constructing it on first call."""
    global ace_engine
    if ace_engine is not None:
        return ace_engine
    ace_engine = ACEStepEngine(config)
    return ace_engine
def get_timeline_manager():
    """Return the shared TimelineManager, creating it lazily."""
    global timeline_manager
    if timeline_manager is not None:
        return timeline_manager
    timeline_manager = TimelineManager(config)
    return timeline_manager
def get_lora_trainer():
    """Return the shared LoRATrainer, creating it lazily."""
    global lora_trainer
    if lora_trainer is not None:
        return lora_trainer
    lora_trainer = LoRATrainer(config)
    return lora_trainer
def get_audio_processor():
    """Return the shared AudioProcessor, creating it lazily."""
    global audio_processor
    if audio_processor is not None:
        return audio_processor
    audio_processor = AudioProcessor(config)
    return audio_processor
| # ==================== TAB 1: STANDARD ACE-STEP GUI ==================== | |
def standard_generate(
    prompt: str,
    lyrics: str,
    duration: int,
    temperature: float,
    top_p: float,
    seed: int,
    style: str,
    use_lora: bool,
    lora_path: Optional[str] = None
) -> Tuple[str, str]:
    """Run a one-shot ACE-Step generation.

    Returns:
        (audio_path, status_message); audio_path is None on failure.
    """
    try:
        logger.info(f"Standard generation: {prompt[:50]}...")
        # LoRA is only applied when the checkbox is ticked in the UI.
        effective_lora = lora_path if use_lora else None
        audio_path = get_ace_engine().generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            top_p=top_p,
            seed=seed,
            style=style,
            lora_path=effective_lora
        )
        return audio_path, f"✅ Generated {duration}s audio successfully"
    except Exception as e:
        logger.error(f"Standard generation failed: {e}")
        return None, f"❌ Error: {str(e)}"
def standard_variation(audio_path: str, variation_strength: float) -> Tuple[str, str]:
    """Produce a variation of previously generated audio."""
    try:
        variant = get_ace_engine().generate_variation(audio_path, variation_strength)
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
    return variant, "✅ Variation generated"
def standard_repaint(
    audio_path: str,
    start_time: float,
    end_time: float,
    new_prompt: str
) -> Tuple[str, str]:
    """Regenerate the [start_time, end_time] slice of the audio from a new prompt."""
    try:
        repainted = get_ace_engine().repaint(audio_path, start_time, end_time, new_prompt)
        return repainted, f"✅ Repainted {start_time}s-{end_time}s"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
def standard_lyric_edit(
    audio_path: str,
    new_lyrics: str
) -> Tuple[str, str]:
    """Swap the lyrics while keeping the underlying music."""
    try:
        edited = get_ace_engine().edit_lyrics(audio_path, new_lyrics)
        return edited, "✅ Lyrics edited"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
| # ==================== TAB 2: CUSTOM TIMELINE WORKFLOW ==================== | |
def timeline_generate(
    prompt: str,
    lyrics: str,
    context_length: int,
    style: str,
    temperature: float,
    seed: int,
    session_state: dict
) -> Tuple[str, str, str, dict, str]:
    """
    Generate 32-second clip with 2s lead-in, 28s main, 2s lead-out.
    Blends with previous clips based on context_length.

    Returns:
        (latest_clip, full_timeline_audio, timeline_viz, session_state, status).
        On failure the first three slots are None and status carries the error.
        (Fix: the previous annotation declared a 4-tuple for a 5-value return.)
    """
    try:
        # Initialize session state if None (first call of a fresh browser session)
        if session_state is None:
            session_state = {"timeline_id": None, "total_clips": 0}
        logger.info(f"Timeline generation with {context_length}s context")
        # Lazily-created shared components
        tm = get_timeline_manager()
        engine = get_ace_engine()
        ap = get_audio_processor()
        # Pull trailing timeline audio to condition the new clip on what came before
        context_audio = tm.get_context(
            session_state.get("timeline_id"),
            context_length
        )
        # Generate the fixed-length 32s clip
        clip = engine.generate_clip(
            prompt=prompt,
            lyrics=lyrics,
            duration=32,
            context_audio=context_audio,
            style=style,
            temperature=temperature,
            seed=seed
        )
        # Crossfade against the previous clip (2s lead-in and lead-out)
        blended_clip = ap.blend_clip(
            clip,
            tm.get_last_clip(session_state.get("timeline_id")),
            lead_in=2.0,
            lead_out=2.0
        )
        # Append to the timeline; add_clip issues the timeline_id on first add
        timeline_id = tm.add_clip(
            session_state.get("timeline_id"),
            blended_clip,
            metadata={
                "prompt": prompt,
                "lyrics": lyrics,
                "context_length": context_length
            }
        )
        # Update session
        session_state["timeline_id"] = timeline_id
        session_state["total_clips"] = session_state.get("total_clips", 0) + 1
        # Export the stitched timeline and its visualization for the UI
        full_audio = tm.export_timeline(timeline_id)
        timeline_viz = tm.visualize_timeline(timeline_id)
        info = f"✅ Clip {session_state['total_clips']} added • Total: {tm.get_duration(timeline_id):.1f}s"
        return blended_clip, full_audio, timeline_viz, session_state, info
    except Exception as e:
        logger.error(f"Timeline generation failed: {e}")
        return None, None, None, session_state, f"❌ Error: {str(e)}"
def timeline_extend(
    prompt: str,
    lyrics: str,
    context_length: int,
    session_state: dict
) -> Tuple[str, str, str, dict, str]:
    """Extend the current timeline with defaults: style="auto", temperature 0.7,
    random seed (-1).

    Thin wrapper over timeline_generate; returns its 5-tuple.
    (Fix: the previous annotation declared a 4-tuple for a 5-value return.)
    """
    return timeline_generate(
        prompt, lyrics, context_length, "auto", 0.7, -1, session_state
    )
def timeline_inpaint(
    start_time: float,
    end_time: float,
    new_prompt: str,
    session_state: dict
) -> Tuple[str, str, dict, str]:
    """Regenerate the [start_time, end_time] region of the timeline.

    Returns:
        (full_timeline_audio, timeline_viz, session_state, status); the audio
        and viz slots are None on failure.
        (Fix: the previous annotation declared a 3-tuple for a 4-value return.)
    """
    try:
        # Initialize session state if None
        if session_state is None:
            session_state = {"timeline_id": None, "total_clips": 0}
        tm = get_timeline_manager()
        timeline_id = session_state.get("timeline_id")
        # Regenerate the region in place; the export below reflects the edit.
        tm.inpaint_region(
            timeline_id,
            start_time,
            end_time,
            new_prompt
        )
        full_audio = tm.export_timeline(timeline_id)
        timeline_viz = tm.visualize_timeline(timeline_id)
        info = f"✅ Inpainted {start_time:.1f}s-{end_time:.1f}s"
        return full_audio, timeline_viz, session_state, info
    except Exception as e:
        return None, None, session_state, f"❌ Error: {str(e)}"
def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
    """Discard the current timeline (if any) and hand back a fresh session state."""
    # Only an existing timeline needs server-side cleanup.
    if session_state is not None and session_state.get("timeline_id"):
        get_timeline_manager().delete_timeline(session_state["timeline_id"])
    fresh_state = {"timeline_id": None, "total_clips": 0}
    return None, None, "Timeline cleared", fresh_state
| # ==================== TAB 3: LORA TRAINING ==================== | |
def lora_upload_files(files: List[str]) -> str:
    """Stage uploaded audio files as a LoRA training dataset."""
    try:
        prepared_files = get_lora_trainer().prepare_dataset(files)
    except Exception as e:
        return f"❌ Error: {str(e)}"
    return f"✅ Prepared {len(prepared_files)} files for training"
def lora_train(
    dataset_path: str,
    model_name: str,
    learning_rate: float,
    batch_size: int,
    num_epochs: int,
    rank: int,
    alpha: int,
    use_existing_lora: bool,
    existing_lora_path: Optional[str] = None,
    progress=gr.Progress()
) -> Tuple[Optional[str], str]:
    """Train a LoRA model on the prepared dataset.

    Either resumes from an existing LoRA (when requested and a path is given)
    or initializes a fresh adapter with the given rank/alpha.

    Returns:
        (saved_model_path, status_message); path is None on failure.
    """
    try:
        logger.info(f"Starting LoRA training: {model_name}")
        # BUG FIX: previously used the bare module global `lora_trainer`,
        # which is None until the lazy accessor has run — training crashed
        # with AttributeError unless something else initialized it first.
        trainer = get_lora_trainer()
        # Initialize or load LoRA
        if use_existing_lora and existing_lora_path:
            trainer.load_lora(existing_lora_path)
        else:
            trainer.initialize_lora(rank=rank, alpha=alpha)

        # Forward step-wise progress to the Gradio progress bar.
        def progress_callback(step, total_steps, loss):
            progress((step, total_steps), desc=f"Training (loss: {loss:.4f})")

        result_path = trainer.train(
            dataset_path=dataset_path,
            model_name=model_name,
            learning_rate=learning_rate,
            batch_size=batch_size,
            num_epochs=num_epochs,
            progress_callback=progress_callback
        )
        info = f"✅ Training complete! Model saved to {result_path}"
        return result_path, info
    except Exception as e:
        logger.error(f"LoRA training failed: {e}")
        return None, f"❌ Error: {str(e)}"
def lora_download(lora_path: str) -> str:
    """Return the model path for download if it exists on disk, else None."""
    candidate = Path(lora_path)
    if candidate.exists():
        return lora_path
    return None
| # ==================== GRADIO UI ==================== | |
def create_ui():
    """Create the three-tab Gradio interface.

    Tabs:
        1. Standard ACE-Step  - one-shot generation plus variation/repaint/lyric edit.
        2. Timeline Workflow  - incremental 32s-clip composition on a master timeline.
        3. LoRA Training      - dataset prep, training, and model download.

    Returns:
        gr.Blocks: the assembled (not yet launched) application.
    """
    with gr.Blocks(title="ACE-Step 1.5 Custom Edition", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # 🎵 ACE-Step 1.5 Custom Edition
        **Three powerful interfaces for music generation and training**
        Models will download automatically on first use (~7GB from HuggingFace)
        """)
        with gr.Tabs():
            # ============ TAB 1: STANDARD ACE-STEP ============
            with gr.Tab("🎼 Standard ACE-Step"):
                gr.Markdown("### Full-featured standard ACE-Step 1.5 interface")
                with gr.Row():
                    # Left column: all generation inputs
                    with gr.Column():
                        std_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe the music style, mood, instruments...",
                            lines=3
                        )
                        std_lyrics = gr.Textbox(
                            label="Lyrics (optional)",
                            placeholder="Enter lyrics here...",
                            lines=5
                        )
                        with gr.Row():
                            std_duration = gr.Slider(
                                minimum=10, maximum=240, value=30, step=10,
                                label="Duration (seconds)"
                            )
                            std_style = gr.Dropdown(
                                choices=["auto", "pop", "rock", "jazz", "classical", "electronic", "hip-hop"],
                                value="auto",
                                label="Style"
                            )
                        with gr.Row():
                            std_temperature = gr.Slider(
                                minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                                label="Temperature"
                            )
                            std_top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                                label="Top P"
                            )
                        std_seed = gr.Number(label="Seed (-1 for random)", value=-1)
                        with gr.Row():
                            std_use_lora = gr.Checkbox(label="Use LoRA", value=False)
                            std_lora_path = gr.Textbox(
                                label="LoRA Path",
                                placeholder="Path to LoRA model (if using)"
                            )
                        std_generate_btn = gr.Button("🎵 Generate", variant="primary", size="lg")
                    # Right column: optional reference input, outputs, advanced edits
                    with gr.Column():
                        gr.Markdown("### Audio Input (Optional)")
                        gr.Markdown("*Upload audio file or record to use as style guidance*")
                        # NOTE(review): std_audio_input is not wired into any
                        # handler below — the reference audio is currently
                        # unused; confirm whether it should feed standard_generate.
                        std_audio_input = gr.Audio(
                            label="Style Reference Audio",
                            type="filepath"
                        )
                        gr.Markdown("### Generated Output")
                        std_audio_out = gr.Audio(label="Generated Audio")
                        std_info = gr.Textbox(label="Status", lines=2)
                        gr.Markdown("### Advanced Controls")
                        with gr.Accordion("🔄 Generate Variation", open=False):
                            std_var_strength = gr.Slider(0.1, 1.0, 0.5, label="Variation Strength")
                            std_var_btn = gr.Button("Generate Variation")
                        with gr.Accordion("🎨 Repaint Section", open=False):
                            std_repaint_start = gr.Number(label="Start Time (s)", value=0)
                            std_repaint_end = gr.Number(label="End Time (s)", value=10)
                            std_repaint_prompt = gr.Textbox(label="New Prompt", lines=2)
                            std_repaint_btn = gr.Button("Repaint")
                        with gr.Accordion("✏️ Edit Lyrics", open=False):
                            std_edit_lyrics = gr.Textbox(label="New Lyrics", lines=4)
                            std_edit_btn = gr.Button("Edit Lyrics")
                # Event handlers — the edit operations read from and write back
                # into std_audio_out, so each edit chains on the previous result.
                std_generate_btn.click(
                    fn=standard_generate,
                    inputs=[std_prompt, std_lyrics, std_duration, std_temperature,
                            std_top_p, std_seed, std_style, std_use_lora, std_lora_path],
                    outputs=[std_audio_out, std_info]
                )
                std_var_btn.click(
                    fn=standard_variation,
                    inputs=[std_audio_out, std_var_strength],
                    outputs=[std_audio_out, std_info]
                )
                std_repaint_btn.click(
                    fn=standard_repaint,
                    inputs=[std_audio_out, std_repaint_start, std_repaint_end, std_repaint_prompt],
                    outputs=[std_audio_out, std_info]
                )
                std_edit_btn.click(
                    fn=standard_lyric_edit,
                    inputs=[std_audio_out, std_edit_lyrics],
                    outputs=[std_audio_out, std_info]
                )
            # ============ TAB 2: CUSTOM TIMELINE ============
            with gr.Tab("⏱️ Timeline Workflow"):
                gr.Markdown("""
                ### Custom Timeline-based Generation
                Generate 32-second clips that seamlessly blend together on a master timeline.
                """)
                # Session state for timeline: per-browser-session dict holding
                # {"timeline_id": ..., "total_clips": ...}; starts as None and is
                # initialized by the handlers on first use.
                timeline_state = gr.State(value=None)
                with gr.Row():
                    with gr.Column():
                        tl_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe this section...",
                            lines=3
                        )
                        tl_lyrics = gr.Textbox(
                            label="Lyrics for this clip",
                            placeholder="Enter lyrics for this 32s section...",
                            lines=4
                        )
                        gr.Markdown("*How far back to reference for style guidance*")
                        tl_context_length = gr.Slider(
                            minimum=0, maximum=120, value=30, step=10,
                            label="Context Length (seconds)"
                        )
                        with gr.Row():
                            tl_style = gr.Dropdown(
                                choices=["auto", "pop", "rock", "jazz", "electronic"],
                                value="auto",
                                label="Style"
                            )
                            tl_temperature = gr.Slider(
                                minimum=0.5, maximum=1.0, value=0.7, step=0.05,
                                label="Temperature"
                            )
                        tl_seed = gr.Number(label="Seed (-1 for random)", value=-1)
                        with gr.Row():
                            tl_generate_btn = gr.Button("🎵 Generate Clip", variant="primary", size="lg")
                            tl_extend_btn = gr.Button("➕ Extend", size="lg")
                        tl_reset_btn = gr.Button("🔄 Reset Timeline", variant="secondary")
                        tl_info = gr.Textbox(label="Status", lines=2)
                    with gr.Column():
                        tl_clip_audio = gr.Audio(label="Latest Clip")
                        tl_full_audio = gr.Audio(label="Full Timeline")
                        tl_timeline_viz = gr.Image(label="Timeline Visualization")
                with gr.Accordion("🎨 Inpaint Timeline Region", open=False):
                    tl_inpaint_start = gr.Number(label="Start Time (s)", value=0)
                    tl_inpaint_end = gr.Number(label="End Time (s)", value=10)
                    tl_inpaint_prompt = gr.Textbox(label="New Prompt", lines=2)
                    tl_inpaint_btn = gr.Button("Inpaint Region")
                # Event handlers — timeline_state is threaded through every
                # handler as both input and output to persist the session.
                tl_generate_btn.click(
                    fn=timeline_generate,
                    inputs=[tl_prompt, tl_lyrics, tl_context_length, tl_style,
                            tl_temperature, tl_seed, timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
                )
                tl_extend_btn.click(
                    fn=timeline_extend,
                    inputs=[tl_prompt, tl_lyrics, tl_context_length, timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
                )
                tl_reset_btn.click(
                    fn=timeline_reset,
                    inputs=[timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_info, timeline_state]
                )
                tl_inpaint_btn.click(
                    fn=timeline_inpaint,
                    inputs=[tl_inpaint_start, tl_inpaint_end, tl_inpaint_prompt, timeline_state],
                    outputs=[tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
                )
            # ============ TAB 3: LORA TRAINING ============
            with gr.Tab("🎓 LoRA Training Studio"):
                gr.Markdown("""
                ### Train Custom LoRA Models
                Upload audio files to train specialized models for voice cloning, style adaptation, etc.
                """)
                with gr.Row():
                    # Left column: dataset upload and training hyperparameters
                    with gr.Column():
                        gr.Markdown("#### 1. Upload Training Data")
                        lora_files = gr.File(
                            label="Audio Files",
                            file_count="multiple",
                            file_types=["audio"]
                        )
                        lora_upload_btn = gr.Button("📤 Upload & Prepare Dataset")
                        lora_upload_status = gr.Textbox(label="Upload Status", lines=2)
                        gr.Markdown("#### 2. Training Configuration")
                        lora_dataset_path = gr.Textbox(
                            label="Dataset Path",
                            placeholder="Path to prepared dataset"
                        )
                        lora_model_name = gr.Textbox(
                            label="Model Name",
                            placeholder="my_custom_lora"
                        )
                        with gr.Row():
                            lora_learning_rate = gr.Number(
                                label="Learning Rate",
                                value=1e-4
                            )
                            lora_batch_size = gr.Slider(
                                minimum=1, maximum=16, value=4, step=1,
                                label="Batch Size"
                            )
                        with gr.Row():
                            lora_num_epochs = gr.Slider(
                                minimum=1, maximum=100, value=10, step=1,
                                label="Epochs"
                            )
                            lora_rank = gr.Slider(
                                minimum=4, maximum=128, value=16, step=4,
                                label="LoRA Rank"
                            )
                        lora_alpha = gr.Slider(
                            minimum=4, maximum=128, value=32, step=4,
                            label="LoRA Alpha"
                        )
                        lora_use_existing = gr.Checkbox(
                            label="Continue training from existing LoRA",
                            value=False
                        )
                        lora_existing_path = gr.Textbox(
                            label="Existing LoRA Path",
                            placeholder="Path to existing LoRA model"
                        )
                        lora_train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
                    # Right column: training status, result path, download
                    with gr.Column():
                        lora_train_status = gr.Textbox(label="Training Status", lines=3)
                        lora_model_path = gr.Textbox(label="Trained Model Path", lines=1)
                        lora_download_btn = gr.Button("💾 Download Model")
                        lora_download_file = gr.File(label="Download")
                        gr.Markdown("""
                        #### Training Tips
                        - Upload 10+ audio samples for best results
                        - Keep samples consistent in style/quality
                        - Higher rank = more capacity but slower training
                        - Start with 10-20 epochs and adjust
                        - Use existing LoRA to continue training
                        """)
                # Event handlers
                lora_upload_btn.click(
                    fn=lora_upload_files,
                    inputs=[lora_files],
                    outputs=[lora_upload_status]
                )
                lora_train_btn.click(
                    fn=lora_train,
                    inputs=[lora_dataset_path, lora_model_name, lora_learning_rate,
                            lora_batch_size, lora_num_epochs, lora_rank, lora_alpha,
                            lora_use_existing, lora_existing_path],
                    outputs=[lora_model_path, lora_train_status]
                )
                lora_download_btn.click(
                    fn=lora_download,
                    inputs=[lora_model_path],
                    outputs=[lora_download_file]
                )
        gr.Markdown("""
        ---
        ### About
        ACE-Step 1.5 Custom Edition by Gamahea | Based on [ACE-Step](https://ace-step.github.io/)
        """)
    return app
| # ==================== MAIN ==================== | |
# ==================== MAIN ====================
if __name__ == "__main__":
    logger.info("Starting ACE-Step 1.5 Custom Edition...")
    try:
        # Create and launch app
        app = create_ui()
        # Monkey patch the get_api_info method to prevent JSON schema errors
        # NOTE(review): presumably some component type hints make gradio's API
        # schema generation raise; the fallback keeps the app serving with an
        # empty endpoint map — confirm against the gradio version in use.
        original_get_api_info = app.get_api_info
        def safe_get_api_info(*args, **kwargs):
            """Patched get_api_info that returns minimal info to avoid schema errors"""
            try:
                return original_get_api_info(*args, **kwargs)
            except (TypeError, AttributeError, KeyError) as e:
                logger.warning(f"API info generation failed, returning minimal info: {e}")
                return {
                    "named_endpoints": {},
                    "unnamed_endpoints": {}
                }
        app.get_api_info = safe_get_api_info
        logger.info("✓ Patched get_api_info method")
        # Launch the app
        app.launch(
            server_name="0.0.0.0",  # listen on all interfaces (required on Spaces)
            server_port=7860,       # standard HF Spaces port
            share=False,
            show_error=True
        )
    except Exception as e:
        # Log, print the traceback for the Space logs, and re-raise so the
        # process exits non-zero on a failed launch.
        logger.error(f"Failed to launch app: {e}")
        import traceback
        traceback.print_exc()
        raise