| | |
| | """ |
| | Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow. |
| | |
| | This script converts LoRA weights saved by SimpleTuner into a format that can be |
| | directly loaded by diffusers' load_lora_weights() method. |
| | |
| | Usage: |
| | python convert_simpletuner_lora.py <input_lora.safetensors> <output_lora.safetensors> |
| | |
| | Example: |
| | python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors |
| | """ |
| |
|
| | import argparse |
| | import sys |
| | from pathlib import Path |
| | from typing import Dict |
| |
|
| | import safetensors.torch |
| | import torch |
| |
|
| |
|
def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str:
    """
    Classify a LoRA state dict by the naming scheme of its keys.

    Only the key names are inspected; tensor values are never touched.

    Returns:
        "peft" if keys carry the transformer. prefix and only lora_A/lora_B
        "mixed" if transformer.-prefixed with lora_A/B plus lora_down/up or lora.down/up
        "simpletuner_transformer" if transformer.-prefixed with lora_down/up or lora.down/up only
        "simpletuner_auraflow" if keys start with lora_transformer_ and use lora_down/up
        "kohya" if keys start with lora_unet_ and use lora_down/up
        "unknown" otherwise
    """
    names = tuple(state_dict)

    def any_contains(*fragments):
        # True when at least one key contains at least one of the fragments.
        return any(fragment in name for name in names for fragment in fragments)

    def any_startswith(prefix):
        return any(name.startswith(prefix) for name in names)

    peft_naming = any_contains(".lora_A.", ".lora_B.")
    underscore_naming = any_contains(".lora_down.", ".lora_up.")
    dotted_naming = any_contains(".lora.down.", ".lora.up.")
    legacy_naming = underscore_naming or dotted_naming

    if any_startswith("transformer."):
        if peft_naming:
            # PEFT names present: pure PEFT unless legacy names are mixed in.
            return "mixed" if legacy_naming else "peft"
        if legacy_naming:
            return "simpletuner_transformer"

    # The prefix-style formats are only recognised with the underscore spelling.
    if underscore_naming:
        if any_startswith("lora_transformer_"):
            return "simpletuner_auraflow"
        if any_startswith("lora_unet_"):
            return "kohya"

    return "unknown"
| |
|
| |
|
def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a mixed-format LoRA state dict to pure lora_A/lora_B naming.

    Keys that already contain .lora_A. / .lora_B. are copied through unchanged.
    Keys using either legacy spelling (.lora.down/.lora.up or
    .lora_down/.lora_up) are renamed to .lora_A.weight / .lora_B.weight.
    Keys containing ".alpha" are dropped, and any other key is copied through
    with a warning.

    Args:
        state_dict: Mapping of parameter name to tensor. It is not modified.

    Returns:
        A new dict holding the same tensor objects under PEFT-style names.
    """
    # Both legacy spellings are listed because detect_lora_format() reports
    # "mixed" for either of them; handling only the dotted one would leave
    # underscore-style keys unconverted in the output.
    suffix_map = (
        (".lora.down.weight", ".lora_A.weight"),
        (".lora.up.weight", ".lora_B.weight"),
        (".lora_down.weight", ".lora_A.weight"),
        (".lora_up.weight", ".lora_B.weight"),
    )

    new_state_dict = {}
    renames = []  # (layer name, legacy suffix, PEFT suffix)
    kept_count = 0
    skipped_count = 0

    print("\nProcessing keys:")
    print("-" * 80)

    # Sorted so the printed report is deterministic.
    for key in sorted(state_dict):
        if ".lora_A." in key or ".lora_B." in key:
            new_state_dict[key] = state_dict[key]
            kept_count += 1
            continue

        legacy = next(((old, new) for old, new in suffix_map if old in key), None)
        if legacy is not None:
            old_suffix, new_suffix = legacy
            new_state_dict[key.replace(old_suffix, new_suffix)] = state_dict[key]
            renames.append((key.replace(old_suffix, ""), old_suffix, new_suffix))
            continue

        if ".alpha" in key:
            # NOTE(review): alpha is discarded without rescaling the weights,
            # unlike the other converters in this file. Confirm that alpha
            # equals the rank for mixed-format checkpoints.
            skipped_count += 1
            continue

        new_state_dict[key] = state_dict[key]
        print(f"⚠ Warning: Unexpected key format: {key}")

    print("\nSummary:")
    print(f" ✓ Kept {kept_count} keys already in correct format (lora_A/lora_B)")
    print(f" ✓ Converted {len(renames)} keys from .lora.down/.lora.up to lora_A/lora_B")
    print(f" ✓ Skipped {skipped_count} alpha keys")

    if renames:
        print(f"\nRenames applied ({len(renames)} conversions):")
        print("-" * 80)
        for layer, old_suffix, new_suffix in renames:
            print(f" {layer}")
            print(f" {old_suffix} → {new_suffix}")

    return new_state_dict
| |
|
| |
|
def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Rename legacy down/up LoRA weights to lora_A/lora_B, keeping the key prefix.

    Both legacy spellings are accepted (.lora_down/.lora_up and
    .lora.down/.lora.up), because detect_lora_format() routes either of them
    to this converter. When a layer has an accompanying "<layer>.alpha" entry,
    alpha / rank is folded into the two weight matrices so that no separate
    alpha value is needed afterwards.

    Args:
        state_dict: Mapping of parameter name to tensor. It is not modified.

    Returns:
        A new dict containing only the converted lora_A/lora_B weights.
        Layers missing either half, and any other leftover keys, are reported
        on stdout and omitted.
    """
    # Shallow copy: entries are popped as they are consumed so leftovers can
    # be reported, without emptying the caller's mapping.
    pending = dict(state_dict)

    spellings = (
        (".lora_down.weight", ".lora_up.weight"),
        (".lora.down.weight", ".lora.up.weight"),
    )

    # layer name -> (down suffix, up suffix). Spellings are walked in a fixed
    # order so the underscore form wins if a layer somehow has both.
    layers = {}
    for down_suffix, up_suffix in spellings:
        for key in pending:
            if down_suffix in key:
                layers.setdefault(key.replace(down_suffix, ""), (down_suffix, up_suffix))

    print(f"\nFound {len(layers)} LoRA layers to convert")
    print("-" * 80)

    new_state_dict = {}
    converted = []  # (layer, down suffix, up suffix, has_alpha)

    for base_key in sorted(layers):
        down_suffix, up_suffix = layers[base_key]
        down_key = f"{base_key}{down_suffix}"
        up_key = f"{base_key}{up_suffix}"

        if down_key not in pending or up_key not in pending:
            print(f"⚠ Warning: Missing weights for {base_key}")
            continue

        down_weight = pending.pop(down_key)
        up_weight = pending.pop(up_key)

        alpha = pending.pop(f"{base_key}.alpha", None)
        has_alpha = alpha is not None
        if has_alpha:
            # The rank is the first dimension of the down projection.
            # float() turns a single-element tensor into a Python scalar so
            # the multiplications below cannot change the weights' dtype.
            scale = float(alpha) / down_weight.shape[0]

            # Split the scale between both matrices; the product
            # scale_down * scale_up always equals scale.
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2

            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

        new_state_dict[f"{base_key}.lora_A.weight"] = down_weight
        new_state_dict[f"{base_key}.lora_B.weight"] = up_weight
        converted.append((base_key, down_suffix, up_suffix, has_alpha))

    # Text-encoder entries are expected leftovers and are not reported.
    remaining = [k for k in pending if not k.startswith("text_encoder")]
    if remaining:
        print(f"⚠ Warning: {len(remaining)} keys were not converted: {remaining[:5]}")

    # Each converted layer accounts for two renamed tensors.
    print(f"\nRenames applied ({2 * len(converted)} conversions):")
    print("-" * 80)

    for layer, down_suffix, up_suffix, has_alpha in converted:
        alpha_str = " (alpha scaled)" if has_alpha else ""
        print(f"\n {layer}{alpha_str}")
        print(f" {down_suffix} → .lora_A.weight")
        print(f" {up_suffix} → .lora_B.weight")

    return new_state_dict
| |
|
| |
|
def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert flattened "lora_transformer_*" LoRA keys to transformer.-prefixed
    lora_A/lora_B keys.

    Layer names of the form
    "lora_transformer_[single_]transformer_blocks_<idx>_<module>" are mapped to
    "transformer.[single_]transformer_blocks.<idx>.<module path>". When a layer
    has a "<layer>.alpha" entry, alpha / rank is folded into the weights.
    Text-encoder layers (lora_te_ / lora_te1_) are reported and skipped.
    Unrecognised or incomplete layers are reported and skipped rather than
    aborting the whole conversion.

    Args:
        state_dict: Mapping of parameter name to tensor. It is not modified.

    Returns:
        A new dict of converted weights, every key prefixed with "transformer.".
    """
    # Shallow copy: entries are popped as they are consumed so leftovers can
    # be reported, without emptying the caller's mapping.
    pending = dict(state_dict)

    single_prefix = "lora_transformer_single_transformer_blocks_"
    double_prefix = "lora_transformer_transformer_blocks_"

    # (substring to look for, module path to append). Order matters: the first
    # matching substring wins.
    # NOTE(review): confirm these module paths against the target diffusers
    # transformer's actual module names.
    single_targets = (
        ("attn_to_q", ".attn.to_q"),
        ("attn_to_k", ".attn.to_k"),
        ("attn_to_v", ".attn.to_v"),
        ("proj_out", ".proj_out"),
        ("proj_mlp", ".proj_mlp"),
        ("norm_linear", ".norm.linear"),
    )
    double_targets = (
        ("attn_to_out_0", ".attn.to_out.0"),
        ("attn_to_add_out", ".attn.to_add_out"),
        ("attn_to_q", ".attn.to_q"),
        ("attn_to_k", ".attn.to_k"),
        ("attn_to_v", ".attn.to_v"),
        ("attn_add_q_proj", ".attn.add_q_proj"),
        ("attn_add_k_proj", ".attn.add_k_proj"),
        ("attn_add_v_proj", ".attn.add_v_proj"),
        ("ff_net_0_proj", ".ff.net.0.proj"),
        ("ff_net_2", ".ff.net.2"),
        ("ff_context_net_0_proj", ".ff_context.net.0.proj"),
        ("ff_context_net_2", ".ff_context.net.2"),
        ("norm1_linear", ".norm1.linear"),
        ("norm1_context_linear", ".norm1_context.linear"),
    )

    # One entry per layer, regardless of which of its keys are present.
    layer_names = {
        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
        for k in pending
        if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k
    }

    converted = {}

    for original_key in sorted(layer_names):
        if original_key.startswith(single_prefix):
            prefix, targets = single_prefix, single_targets
            root, label = "single_transformer_blocks", "single"
        elif original_key.startswith(double_prefix):
            prefix, targets = double_prefix, double_targets
            root, label = "transformer_blocks", "double"
        elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"):
            print(f"Found text encoder key: {original_key}")
            continue
        else:
            print(f"Warning: Unknown key pattern: {original_key}")
            continue

        # After the prefix the name reads "<block index>_<module tokens...>".
        parts = original_key.split(prefix)[-1].split("_")
        try:
            block_idx = int(parts[0])
        except ValueError:
            print(f"Warning: Could not parse block index from key: {original_key}")
            continue

        remaining = "_".join(parts[1:])
        module_path = next((path for pattern, path in targets if pattern in remaining), None)
        if module_path is None:
            print(f"Warning: Unhandled {label} block key pattern: {original_key}")
            continue

        down_key = f"{original_key}.lora_down.weight"
        up_key = f"{original_key}.lora_up.weight"
        if down_key not in pending:
            # Only an up weight and/or alpha exists; it shows up in the
            # leftover report below.
            continue
        if up_key not in pending:
            # Leave the down weight in place so it is listed as unconverted
            # instead of raising KeyError halfway through the layer.
            print(f"Warning: Missing up weight for {original_key}; layer skipped")
            continue

        down_weight = pending.pop(down_key)
        up_weight = pending.pop(up_key)

        alpha = pending.pop(f"{original_key}.alpha", None)
        if alpha is not None:
            # The rank is the first dimension of the down projection.
            # float() turns a single-element tensor into a Python scalar so
            # the multiplications below cannot change the weights' dtype.
            scale = float(alpha) / down_weight.shape[0]

            # Split the scale between both matrices; the product
            # scale_down * scale_up always equals scale.
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2

            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

        target = f"transformer.{root}.{block_idx}{module_path}"
        converted[f"{target}.lora_A.weight"] = down_weight
        converted[f"{target}.lora_B.weight"] = up_weight

    # Text-encoder entries are expected leftovers and are not reported.
    remaining_keys = [k for k in pending if not k.startswith("lora_te")]
    if remaining_keys:
        print(f"Warning: Some keys were not converted: {remaining_keys[:10]}")

    return converted
| |
|
| |
|
def convert_lora(input_path: str, output_path: str) -> None:
    """
    Load a LoRA file, detect its key format, convert it and save the result.

    Files already in PEFT format are copied to the output path unchanged.
    The process exits with status 1 when the format is not recognised or when
    the conversion yields no tensors, so an empty LoRA file is never written.

    Args:
        input_path: Path to input LoRA safetensors file
        output_path: Path to output diffusers-compatible safetensors file
    """
    print(f"Loading LoRA from: {input_path}")
    state_dict = safetensors.torch.load_file(input_path)

    print("Detecting LoRA format...")
    format_type = detect_lora_format(state_dict)
    print(f"Detected format: {format_type}")

    if format_type == "peft":
        print("LoRA is already in diffusers-compatible PEFT format!")
        print("No conversion needed. Copying file...")
        # shutil.copy raises SameFileError when both paths name the same
        # file; in that case there is nothing to do.
        if Path(input_path).resolve() == Path(output_path).resolve():
            return
        import shutil
        shutil.copy(input_path, output_path)
        return

    # Each converter is handed a shallow copy so the key count reported at
    # the end still reflects the file as loaded.
    if format_type == "mixed":
        print("Converting MIXED format LoRA to pure PEFT format...")
        print("(Some layers use lora_A/B, others use .lora.down/.lora.up)")
        converted_state_dict = convert_mixed_lora_to_diffusers(state_dict.copy())

    elif format_type == "simpletuner_transformer":
        print("Converting SimpleTuner transformer format to diffusers...")
        print("(has transformer. prefix but uses lora_down/lora_up naming)")
        converted_state_dict = convert_simpletuner_transformer_to_diffusers(state_dict.copy())

    elif format_type == "simpletuner_auraflow":
        print("Converting SimpleTuner AuraFlow format to diffusers...")
        converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())

    elif format_type == "kohya":
        print("Note: Detected Kohya format. This converter is optimized for AuraFlow.")
        print("For other models, diffusers has built-in conversion.")
        converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())

    else:
        print("Error: Unknown LoRA format!")
        print("Sample keys from the state dict:")
        for key in list(state_dict)[:20]:
            print(f" {key}")
        sys.exit(1)

    # A converter that recognised none of the keys returns an empty dict;
    # saving that would produce a file that looks valid but holds no weights.
    if not converted_state_dict:
        print("Error: Conversion produced no tensors; refusing to write an empty LoRA file.")
        sys.exit(1)

    print(f"Saving converted LoRA to: {output_path}")
    safetensors.torch.save_file(converted_state_dict, output_path)

    print("\nConversion complete!")
    print(f"Original keys: {len(state_dict)}")
    print(f"Converted keys: {len(converted_state_dict)}")
| |
|
def main():
    """
    Command-line entry point.

    Parses the arguments, checks that the input file exists and then either
    reports the detected format (--dry-run) or runs the full conversion.
    The output path is only required when a conversion is actually performed.
    """
    parser = argparse.ArgumentParser(
        description="Convert SimpleTuner LoRA to diffusers-compatible format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Convert a SimpleTuner LoRA for AuraFlow
python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors

# Check format without converting
python convert_simpletuner_lora.py my_lora.safetensors --dry-run
"""
    )

    parser.add_argument(
        "input",
        type=str,
        help="Input LoRA file (SimpleTuner format)"
    )

    # Optional so that --dry-run can be used without naming a throwaway file.
    parser.add_argument(
        "output",
        type=str,
        nargs="?",
        default=None,
        help="Output LoRA file (diffusers-compatible format); not needed with --dry-run"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only detect format, don't convert"
    )

    args = parser.parse_args()

    if args.output is None and not args.dry_run:
        # parser.error prints the usage text and exits with status 2, the
        # same way argparse reports a missing required argument.
        parser.error("the following arguments are required: output (unless --dry-run is given)")

    if not Path(args.input).exists():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    if args.dry_run:
        print(f"Loading LoRA from: {args.input}")
        state_dict = safetensors.torch.load_file(args.input)
        format_type = detect_lora_format(state_dict)
        print(f"Detected format: {format_type}")
        print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):")
        for key in list(state_dict)[:10]:
            print(f" {key}")
        return

    convert_lora(args.input, args.output)
| |
|
| |
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 unless main() itself
    # has already called sys.exit().
    raise SystemExit(main())
| |
|
| |
|