"""
ACE-Step 1.5 Custom Edition - Main Application
A comprehensive music generation system with three main interfaces:
1. Standard ACE-Step GUI
2. Custom Timeline-based Workflow
3. LoRA Training Studio
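
Run locally with: python app.py (serves on http://0.0.0.0:7860)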
"""
import gradio as gr
from pathlib import Path
from typing import Optional, List, Tuple
import spaces
from src.ace_step_engine import ACEStepEngine
from src.timeline_manager import TimelineManager
from src.lora_trainer import LoRATrainer
from src.audio_processor import AudioProcessor
from src.utils import setup_logging, load_config
# Setup
logger = setup_logging()
config = load_config()
# Lazy initialize components (will be initialized on first use)
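# On ZeroGPU Spaces the models must not be loaded at import time: startup
# runs on a CPU-only worker, and the heavyweight weights (~7GB download on
# first use) should only be constructed inside a @spaces.GPU-decorated call.
# The get_* helpers below defer construction until then.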
ace_engine = None
timeline_manager = None
lora_trainer = None
audio_processor = None
def get_ace_engine():
"""Lazy-load ACE-Step engine."""
global ace_engine
if ace_engine is None:
ace_engine = ACEStepEngine(config)
return ace_engine
def get_timeline_manager():
"""Lazy-load timeline manager."""
global timeline_manager
if timeline_manager is None:
timeline_manager = TimelineManager(config)
return timeline_manager
def get_lora_trainer():
"""Lazy-load LoRA trainer."""
global lora_trainer
if lora_trainer is None:
lora_trainer = LoRATrainer(config)
return lora_trainer
def get_audio_processor():
"""Lazy-load audio processor."""
global audio_processor
if audio_processor is None:
audio_processor = AudioProcessor(config)
return audio_processor
# ==================== TAB 1: STANDARD ACE-STEP GUI ====================
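# @spaces.GPU requests a ZeroGPU worker for each decorated call; `duration`
# is the upper bound, in seconds, on how long the GPU is held.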
@spaces.GPU(duration=300)
def standard_generate(
prompt: str,
lyrics: str,
duration: int,
temperature: float,
top_p: float,
seed: int,
style: str,
use_lora: bool,
lora_path: Optional[str] = None
) -> Tuple[Optional[str], str]:
"""Standard ACE-Step generation with all original features."""
try:
logger.info(f"Standard generation: {prompt[:50]}...")
# Get engine instance
engine = get_ace_engine()
# Generate audio
audio_path = engine.generate(
prompt=prompt,
lyrics=lyrics,
duration=duration,
temperature=temperature,
top_p=top_p,
seed=seed,
style=style,
lora_path=lora_path if use_lora else None
)
info = f"✅ Generated {duration}s audio successfully"
return audio_path, info
except Exception as e:
logger.error(f"Standard generation failed: {e}")
return None, f"❌ Error: {str(e)}"
@spaces.GPU(duration=180)
def standard_variation(audio_path: str, variation_strength: float) -> Tuple[Optional[str], str]:
"""Generate variation of existing audio."""
try:
result = get_ace_engine().generate_variation(audio_path, variation_strength)
return result, "✅ Variation generated"
except Exception as e:
return None, f"❌ Error: {str(e)}"
@spaces.GPU(duration=180)
def standard_repaint(
audio_path: str,
start_time: float,
end_time: float,
new_prompt: str
) -> Tuple[Optional[str], str]:
"""Repaint specific section of audio."""
try:
result = get_ace_engine().repaint(audio_path, start_time, end_time, new_prompt)
return result, f"✅ Repainted {start_time}s-{end_time}s"
except Exception as e:
return None, f"❌ Error: {str(e)}"
@spaces.GPU(duration=180)
def standard_lyric_edit(
audio_path: str,
new_lyrics: str
) -> Tuple[Optional[str], str]:
"""Edit lyrics while maintaining music."""
try:
result = get_ace_engine().edit_lyrics(audio_path, new_lyrics)
return result, "✅ Lyrics edited"
except Exception as e:
return None, f"❌ Error: {str(e)}"
# ==================== TAB 2: CUSTOM TIMELINE WORKFLOW ====================
@spaces.GPU(duration=300)
def timeline_generate(
prompt: str,
lyrics: str,
context_length: int,
style: str,
temperature: float,
seed: int,
session_state: dict
) -> Tuple[Optional[str], Optional[str], Optional[str], dict, str]:
"""
    Generate a 32-second clip (2s lead-in, 28s main, 2s lead-out) and blend
    it into the timeline, using up to `context_length` seconds of prior
    audio as style context.
"""
try:
# Initialize session state if None
if session_state is None:
session_state = {"timeline_id": None, "total_clips": 0}
logger.info(f"Timeline generation with {context_length}s context")
# Get managers
tm = get_timeline_manager()
engine = get_ace_engine()
ap = get_audio_processor()
# Get context from timeline
context_audio = tm.get_context(
session_state.get("timeline_id"),
context_length
)
# Generate 32s clip
clip = engine.generate_clip(
prompt=prompt,
lyrics=lyrics,
duration=32,
context_audio=context_audio,
style=style,
temperature=temperature,
seed=seed
)
# Blend with timeline (2s lead-in and lead-out)
blended_clip = ap.blend_clip(
clip,
tm.get_last_clip(session_state.get("timeline_id")),
lead_in=2.0,
lead_out=2.0
)
# Add to timeline
timeline_id = tm.add_clip(
session_state.get("timeline_id"),
blended_clip,
metadata={
"prompt": prompt,
"lyrics": lyrics,
"context_length": context_length
}
)
# Update session
session_state["timeline_id"] = timeline_id
session_state["total_clips"] = session_state.get("total_clips", 0) + 1
# Get full timeline audio
full_audio = tm.export_timeline(timeline_id)
# Get timeline visualization
timeline_viz = tm.visualize_timeline(timeline_id)
info = f"✅ Clip {session_state['total_clips']} added • Total: {tm.get_duration(timeline_id):.1f}s"
return blended_clip, full_audio, timeline_viz, session_state, info
except Exception as e:
logger.error(f"Timeline generation failed: {e}")
return None, None, None, session_state, f"❌ Error: {str(e)}"
def timeline_extend(
prompt: str,
lyrics: str,
context_length: int,
session_state: dict
) -> Tuple[Optional[str], Optional[str], Optional[str], dict, str]:
"""Extend current timeline with new generation."""
return timeline_generate(
prompt, lyrics, context_length, "auto", 0.7, -1, session_state
)
@spaces.GPU(duration=240)
def timeline_inpaint(
start_time: float,
end_time: float,
new_prompt: str,
session_state: dict
) -> Tuple[Optional[str], Optional[str], dict, str]:
"""Inpaint specific region in timeline."""
try:
# Initialize session state if None
if session_state is None:
session_state = {"timeline_id": None, "total_clips": 0}
tm = get_timeline_manager()
timeline_id = session_state.get("timeline_id")
        # inpaint_region updates the stored timeline; the refreshed audio is
        # re-exported below rather than returned directly
        tm.inpaint_region(
timeline_id,
start_time,
end_time,
new_prompt
)
full_audio = tm.export_timeline(timeline_id)
timeline_viz = tm.visualize_timeline(timeline_id)
info = f"✅ Inpainted {start_time:.1f}s-{end_time:.1f}s"
return full_audio, timeline_viz, session_state, info
except Exception as e:
return None, None, session_state, f"❌ Error: {str(e)}"
def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
"""Reset timeline to start fresh."""
# Initialize session state if None
if session_state is None:
session_state = {"timeline_id": None, "total_clips": 0}
elif session_state.get("timeline_id"):
get_timeline_manager().delete_timeline(session_state["timeline_id"])
session_state = {"timeline_id": None, "total_clips": 0}
return None, None, "Timeline cleared", session_state
# ==================== TAB 3: LORA TRAINING ====================
def lora_upload_files(files: List[str]) -> str:
"""Upload and prepare audio files for LoRA training."""
try:
prepared_files = get_lora_trainer().prepare_dataset(files)
return f"✅ Prepared {len(prepared_files)} files for training"
except Exception as e:
return f"❌ Error: {str(e)}"
@spaces.GPU(duration=300)
def lora_train(
dataset_path: str,
model_name: str,
learning_rate: float,
batch_size: int,
num_epochs: int,
rank: int,
alpha: int,
use_existing_lora: bool,
existing_lora_path: Optional[str] = None,
progress=gr.Progress()
) -> Tuple[Optional[str], str]:
"""Train LoRA model on uploaded dataset."""
try:
logger.info(f"Starting LoRA training: {model_name}")
        # Use the lazily-loaded trainer; the module-level lora_trainer may
        # still be None at this point
        trainer = get_lora_trainer()
        # Initialize or load LoRA
        if use_existing_lora and existing_lora_path:
            trainer.load_lora(existing_lora_path)
        else:
            trainer.initialize_lora(rank=rank, alpha=alpha)
# Train
def progress_callback(step, total_steps, loss):
progress((step, total_steps), desc=f"Training (loss: {loss:.4f})")
        result_path = trainer.train(
dataset_path=dataset_path,
model_name=model_name,
learning_rate=learning_rate,
batch_size=batch_size,
num_epochs=num_epochs,
progress_callback=progress_callback
)
info = f"✅ Training complete! Model saved to {result_path}"
return result_path, info
except Exception as e:
logger.error(f"LoRA training failed: {e}")
return None, f"❌ Error: {str(e)}"
def lora_download(lora_path: str) -> Optional[str]:
    """Provide the trained LoRA model file for download, if it exists."""
    return lora_path if lora_path and Path(lora_path).exists() else None
# ==================== GRADIO UI ====================
def create_ui():
"""Create the three-tab Gradio interface."""
with gr.Blocks(title="ACE-Step 1.5 Custom Edition", theme=gr.themes.Soft()) as app:
gr.Markdown("""
# 🎵 ACE-Step 1.5 Custom Edition
**Three powerful interfaces for music generation and training**
Models will download automatically on first use (~7GB from HuggingFace)
""")
with gr.Tabs():
# ============ TAB 1: STANDARD ACE-STEP ============
with gr.Tab("🎼 Standard ACE-Step"):
gr.Markdown("### Full-featured standard ACE-Step 1.5 interface")
with gr.Row():
with gr.Column():
std_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the music style, mood, instruments...",
lines=3
)
std_lyrics = gr.Textbox(
label="Lyrics (optional)",
placeholder="Enter lyrics here...",
lines=5
)
with gr.Row():
std_duration = gr.Slider(
minimum=10, maximum=240, value=30, step=10,
label="Duration (seconds)"
)
std_style = gr.Dropdown(
choices=["auto", "pop", "rock", "jazz", "classical", "electronic", "hip-hop"],
value="auto",
label="Style"
)
with gr.Row():
std_temperature = gr.Slider(
minimum=0.1, maximum=1.5, value=0.7, step=0.1,
label="Temperature"
)
std_top_p = gr.Slider(
minimum=0.1, maximum=1.0, value=0.9, step=0.05,
label="Top P"
)
std_seed = gr.Number(label="Seed (-1 for random)", value=-1)
with gr.Row():
std_use_lora = gr.Checkbox(label="Use LoRA", value=False)
std_lora_path = gr.Textbox(
label="LoRA Path",
placeholder="Path to LoRA model (if using)"
)
std_generate_btn = gr.Button("🎵 Generate", variant="primary", size="lg")
with gr.Column():
gr.Markdown("### Audio Input (Optional)")
gr.Markdown("*Upload audio file or record to use as style guidance*")
std_audio_input = gr.Audio(
label="Style Reference Audio",
type="filepath"
)
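                        # NOTE: this style-reference input is not currently
                        # passed to standard_generate; wiring it up would
                        # require an extra parameter on that function.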
gr.Markdown("### Generated Output")
std_audio_out = gr.Audio(label="Generated Audio")
std_info = gr.Textbox(label="Status", lines=2)
gr.Markdown("### Advanced Controls")
with gr.Accordion("🔄 Generate Variation", open=False):
std_var_strength = gr.Slider(0.1, 1.0, 0.5, label="Variation Strength")
std_var_btn = gr.Button("Generate Variation")
with gr.Accordion("🎨 Repaint Section", open=False):
std_repaint_start = gr.Number(label="Start Time (s)", value=0)
std_repaint_end = gr.Number(label="End Time (s)", value=10)
std_repaint_prompt = gr.Textbox(label="New Prompt", lines=2)
std_repaint_btn = gr.Button("Repaint")
with gr.Accordion("✏️ Edit Lyrics", open=False):
std_edit_lyrics = gr.Textbox(label="New Lyrics", lines=4)
std_edit_btn = gr.Button("Edit Lyrics")
# Event handlers
std_generate_btn.click(
fn=standard_generate,
inputs=[std_prompt, std_lyrics, std_duration, std_temperature,
std_top_p, std_seed, std_style, std_use_lora, std_lora_path],
outputs=[std_audio_out, std_info]
)
std_var_btn.click(
fn=standard_variation,
inputs=[std_audio_out, std_var_strength],
outputs=[std_audio_out, std_info]
)
std_repaint_btn.click(
fn=standard_repaint,
inputs=[std_audio_out, std_repaint_start, std_repaint_end, std_repaint_prompt],
outputs=[std_audio_out, std_info]
)
std_edit_btn.click(
fn=standard_lyric_edit,
inputs=[std_audio_out, std_edit_lyrics],
outputs=[std_audio_out, std_info]
)
# ============ TAB 2: CUSTOM TIMELINE ============
with gr.Tab("⏱️ Timeline Workflow"):
gr.Markdown("""
### Custom Timeline-based Generation
Generate 32-second clips that seamlessly blend together on a master timeline.
""")
# Session state for timeline
timeline_state = gr.State(value=None)
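                # gr.State is scoped per browser session, so concurrent users
                # each get an independent timeline.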
with gr.Row():
with gr.Column():
tl_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe this section...",
lines=3
)
tl_lyrics = gr.Textbox(
label="Lyrics for this clip",
placeholder="Enter lyrics for this 32s section...",
lines=4
)
gr.Markdown("*How far back to reference for style guidance*")
tl_context_length = gr.Slider(
minimum=0, maximum=120, value=30, step=10,
label="Context Length (seconds)"
)
with gr.Row():
tl_style = gr.Dropdown(
choices=["auto", "pop", "rock", "jazz", "electronic"],
value="auto",
label="Style"
)
tl_temperature = gr.Slider(
minimum=0.5, maximum=1.0, value=0.7, step=0.05,
label="Temperature"
)
tl_seed = gr.Number(label="Seed (-1 for random)", value=-1)
with gr.Row():
tl_generate_btn = gr.Button("🎵 Generate Clip", variant="primary", size="lg")
tl_extend_btn = gr.Button("➕ Extend", size="lg")
tl_reset_btn = gr.Button("🔄 Reset Timeline", variant="secondary")
tl_info = gr.Textbox(label="Status", lines=2)
with gr.Column():
tl_clip_audio = gr.Audio(label="Latest Clip")
tl_full_audio = gr.Audio(label="Full Timeline")
tl_timeline_viz = gr.Image(label="Timeline Visualization")
with gr.Accordion("🎨 Inpaint Timeline Region", open=False):
tl_inpaint_start = gr.Number(label="Start Time (s)", value=0)
tl_inpaint_end = gr.Number(label="End Time (s)", value=10)
tl_inpaint_prompt = gr.Textbox(label="New Prompt", lines=2)
tl_inpaint_btn = gr.Button("Inpaint Region")
# Event handlers
tl_generate_btn.click(
fn=timeline_generate,
inputs=[tl_prompt, tl_lyrics, tl_context_length, tl_style,
tl_temperature, tl_seed, timeline_state],
outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
)
tl_extend_btn.click(
fn=timeline_extend,
inputs=[tl_prompt, tl_lyrics, tl_context_length, timeline_state],
outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
)
tl_reset_btn.click(
fn=timeline_reset,
inputs=[timeline_state],
outputs=[tl_clip_audio, tl_full_audio, tl_info, timeline_state]
)
tl_inpaint_btn.click(
fn=timeline_inpaint,
inputs=[tl_inpaint_start, tl_inpaint_end, tl_inpaint_prompt, timeline_state],
outputs=[tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
)
# ============ TAB 3: LORA TRAINING ============
with gr.Tab("🎓 LoRA Training Studio"):
gr.Markdown("""
### Train Custom LoRA Models
Upload audio files to train specialized models for voice cloning, style adaptation, etc.
""")
with gr.Row():
with gr.Column():
gr.Markdown("#### 1. Upload Training Data")
lora_files = gr.File(
label="Audio Files",
file_count="multiple",
file_types=["audio"]
)
lora_upload_btn = gr.Button("📤 Upload & Prepare Dataset")
lora_upload_status = gr.Textbox(label="Upload Status", lines=2)
gr.Markdown("#### 2. Training Configuration")
lora_dataset_path = gr.Textbox(
label="Dataset Path",
placeholder="Path to prepared dataset"
)
lora_model_name = gr.Textbox(
label="Model Name",
placeholder="my_custom_lora"
)
with gr.Row():
lora_learning_rate = gr.Number(
label="Learning Rate",
value=1e-4
)
lora_batch_size = gr.Slider(
minimum=1, maximum=16, value=4, step=1,
label="Batch Size"
)
with gr.Row():
lora_num_epochs = gr.Slider(
minimum=1, maximum=100, value=10, step=1,
label="Epochs"
)
lora_rank = gr.Slider(
minimum=4, maximum=128, value=16, step=4,
label="LoRA Rank"
)
lora_alpha = gr.Slider(
minimum=4, maximum=128, value=32, step=4,
label="LoRA Alpha"
)
lora_use_existing = gr.Checkbox(
label="Continue training from existing LoRA",
value=False
)
lora_existing_path = gr.Textbox(
label="Existing LoRA Path",
placeholder="Path to existing LoRA model"
)
lora_train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
with gr.Column():
lora_train_status = gr.Textbox(label="Training Status", lines=3)
lora_model_path = gr.Textbox(label="Trained Model Path", lines=1)
lora_download_btn = gr.Button("💾 Download Model")
lora_download_file = gr.File(label="Download")
gr.Markdown("""
#### Training Tips
- Upload 10+ audio samples for best results
- Keep samples consistent in style/quality
- Higher rank = more capacity but slower training
- Start with 10-20 epochs and adjust
- Use existing LoRA to continue training
""")
# Event handlers
lora_upload_btn.click(
fn=lora_upload_files,
inputs=[lora_files],
outputs=[lora_upload_status]
)
lora_train_btn.click(
fn=lora_train,
inputs=[lora_dataset_path, lora_model_name, lora_learning_rate,
lora_batch_size, lora_num_epochs, lora_rank, lora_alpha,
lora_use_existing, lora_existing_path],
outputs=[lora_model_path, lora_train_status]
)
lora_download_btn.click(
fn=lora_download,
inputs=[lora_model_path],
outputs=[lora_download_file]
)
gr.Markdown("""
---
### About
ACE-Step 1.5 Custom Edition by Gamahea | Based on [ACE-Step](https://ace-step.github.io/)
""")
return app
# ==================== MAIN ====================
if __name__ == "__main__":
logger.info("Starting ACE-Step 1.5 Custom Edition...")
try:
# Create and launch app
app = create_ui()
# Monkey patch the get_api_info method to prevent JSON schema errors
original_get_api_info = app.get_api_info
def safe_get_api_info(*args, **kwargs):
"""Patched get_api_info that returns minimal info to avoid schema errors"""
try:
return original_get_api_info(*args, **kwargs)
except (TypeError, AttributeError, KeyError) as e:
logger.warning(f"API info generation failed, returning minimal info: {e}")
return {
"named_endpoints": {},
"unnamed_endpoints": {}
}
app.get_api_info = safe_get_api_info
logger.info("✓ Patched get_api_info method")
# Launch the app
app.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)
except Exception as e:
logger.error(f"Failed to launch app: {e}")
import traceback
traceback.print_exc()
raise