| from daam import trace, set_seed
|
| from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
| from matplotlib import pyplot as plt
|
| import torch
|
| import os
|
|
|
|
|
# Fail fast before any model download: the rest of this script loads float16
# weights and targets the 'cuda' device, which needs a working CUDA build.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please ensure a GPU is available and PyTorch is installed with CUDA support.")
|
|
|
|
|
# Single-file SDXL checkpoint on the Hugging Face Hub, and the target device name.
model_url = 'https://huggingface.co/ApacheOne/local-checkpoints/blob/main/SDXL(PONY)/creapromptLightning_creapromtHypersdxlV1.safetensors'
device = 'cuda'

# Folder that receives the generated image and the per-word heatmaps;
# an existing folder is reused rather than treated as an error.
output_dir = 'sdxl-creaprompt'
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
# Load the single-file SDXL checkpoint in half precision.
# `variant` is deliberately not passed: it selects "*.fp16.*" weight files in
# multi-file repos and has no meaning for a single .safetensors file.
pipe = StableDiffusionXLPipeline.from_single_file(
    model_url,
    torch_dtype=torch.float16,
    use_safetensors=True,
)

# Replace the checkpoint's default scheduler with DPMSolverMultistepScheduler,
# reusing the existing scheduler config. The config key is
# `use_karras_sigmas`; the former `use_karras` is not a recognised key and was
# silently ignored (with a warning) by `from_config`.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=False,
)

# diffusers memory savers: model CPU offload moves each sub-model onto the GPU
# only while it runs, and VAE slicing reduces peak memory during decoding.
# NOTE: do not follow this with `pipe.to(device)`. Manually moving an
# offloaded pipeline to CUDA conflicts with the offload hooks and throws away
# the memory savings (diffusers logs a warning for exactly this).
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
|
|
|
|
|
prompt = 'realism woman, wearing dark black low waist jeans, white shoes and red crop top, hands by side, full body shot, Lake Tahoe, (masterpiece best quality ultra-detailed best shadow amazing realistic picture)'

gen = set_seed(42)

# Characters stripped from both ends of each whitespace-separated prompt word
# before it is used as a heat-map query and in a file name. Without this,
# "woman," or "(masterpiece" would also match the punctuation token and yield
# file names like "heatmap_woman,.png".
heatmap_strip_chars = '()[]{},.;:!?"\''

experiment_dir = 'sdxl-creaprompt-experiment-gpu'

with torch.no_grad():
    # Everything that reads `tc` (heat maps, experiment export) stays inside
    # the trace context, matching DAAM's documented usage.
    with trace(pipe) as tc:
        out = pipe(
            prompt,
            num_inference_steps=13,
            generator=gen,
            # Legacy per-step callback API; presumably DAAM's time_callback
            # uses it to track the current denoising step. Confirm before
            # migrating to `callback_on_step_end`.
            callback=tc.time_callback,
            callback_steps=1,
            guidance_scale=1.9,
            height=1024,
            width=1024,
        )
        image = out.images[0]

        generated_image_path = os.path.join(output_dir, 'generated_image.png')
        image.save(generated_image_path)

        heat_map = tc.compute_global_heat_map()

        seen_words = set()
        for raw_word in prompt.split():
            word = raw_word.strip(heatmap_strip_chars)
            # Skip pure punctuation and repeated words (e.g. "best" occurs
            # twice and would otherwise be recomputed and overwrite its file).
            key = word.lower()
            if not word or key in seen_words:
                continue
            seen_words.add(key)

            try:
                word_heat_map = heat_map.compute_word_heat_map(word)
            except ValueError as err:
                # NOTE(review): assumed to be DAAM's "word not found in prompt"
                # error. Skip this word instead of aborting after the
                # expensive generation step.
                print(f"Skipping heatmap for '{word}': {err}")
                continue

            fig = plt.figure()
            word_heat_map.plot_overlay(image)
            plt.title(f"Heatmap for '{word}'")

            heatmap_path = os.path.join(output_dir, f'heatmap_{word}.png')
            plt.savefig(heatmap_path, bbox_inches='tight')
            plt.close(fig)

        exp = tc.to_experiment(experiment_dir)
        exp.save()

print(f"Generation complete! Images saved in '{output_dir}' folder:")
print(f"- Generated image: {generated_image_path}")
print(f"- Heatmaps: {output_dir}/heatmap_<word>.png")
print(f"Experiment saved in '{experiment_dir}'.")