| import torch |
| import torch.nn.functional as F |
| import math |
|
|
|
|
| """ |
| This scheduler has 3 main responsibilities: |
| |
| 1. Setup (init) - Pre-compute noise schedule |
| 2. Training (q_sample) - Add noise to images |
| 3. Generation (p_sample_text + sample_text) - Remove noise |
| step-by-step |
| |
| """ |
|
|
|
|
class SimpleDDPMScheduler:
    """DDPM noise scheduler with a linear beta schedule and text-conditioned sampling.

    Responsibilities:
        * ``__init__``: pre-compute the noise schedule and derived coefficients.
        * ``q_sample``: forward process, i.e. add noise to clean inputs (training).
        * ``p_sample_text`` / ``sample_text``: reverse process, i.e. remove noise
          step by step with optional classifier-free guidance (generation).

    Args:
        num_timesteps: Number of diffusion steps.
        beta_start: Beta value at the first timestep.
        beta_end: Beta value at the last timestep.
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps

        # Linear beta schedule, alphas = 1 - beta, and their running product.
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted by one step; the value before step 0 is
        # defined as 1 so that index 0 is valid.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Coefficients of the closed-form forward process used by q_sample.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # 1 / sqrt(alpha_t) for the reverse-step mean; computed once here rather
        # than on every call to p_sample_text.
        self.sqrt_recip_alphas = 1.0 / torch.sqrt(self.alphas)

        # Posterior variance used when adding noise in the reverse step. Because
        # alphas_cumprod_prev[0] == 1, entry 0 is exactly zero.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def q_sample(self, x_start, t, noise=None):
        """Add noise to clean inputs according to the noise schedule.

        Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise`` in
        closed form, so training examples can be produced at any timestep of the
        forward process without iterating.

        Args:
            x_start: Clean inputs; the first dimension indexes the batch.
            t: 1-D tensor of timestep indices, one per batch element.
            noise: Noise to add. When omitted, standard normal noise is drawn with
                ``torch.randn_like(x_start)``.

        Returns:
            The noised tensor, with the same shape as ``x_start``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_sample_text(self, model, x, t, text_embeddings, guidance_scale=1.0):
        """Sample x_{t-1} from x_t using the model with text conditioning and CFG.

        Args:
            model: Noise-prediction model, called as ``model(x, t, text_embeddings)``.
            x: Current noisy batch x_t.
            t: 1-D tensor of timestep indices, one per batch element. Entries may
                differ across the batch.
            text_embeddings: Text embeddings for conditioning.
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance,
                higher = stronger). Guidance is only applied when > 1.0.

        Returns:
            The denoised batch x_{t-1}, with the same shape as ``x``. Samples whose
            timestep is 0 receive the predicted mean without added noise.
        """
        predicted_noise = model(x, t, text_embeddings)

        if guidance_scale > 1.0:
            # All-zero embeddings act as the "null" condition for the
            # unconditional prediction.
            # NOTE(review): this must match how the model was trained for CFG.
            uncond_embeddings = torch.zeros_like(text_embeddings)
            uncond_noise = model(x, t, uncond_embeddings)
            # Extrapolate from the unconditional prediction toward the conditional one.
            predicted_noise = uncond_noise + guidance_scale * (predicted_noise - uncond_noise)

        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Mean of the reverse step, derived from the predicted noise.
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        # Every sample is at its final step: return the mean without drawing noise.
        if bool((t == 0).all()):
            return model_mean

        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Per-sample mask, so that in a batch with mixed timesteps any entry
        # already at t == 0 gets no noise while the others still do.
        nonzero_mask = (t != 0).to(x.dtype).reshape(-1, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def sample_text(self, model, shape, text_embeddings, device="cuda", guidance_scale=1.0):
        """Generate samples using DDPM sampling with text conditioning and CFG.

        Runs under ``torch.no_grad()``. Each step's output feeds the next, so
        tracking gradients would keep the autograd graph of every one of the
        ``num_timesteps`` model calls alive in memory.

        Args:
            model: Noise-prediction model, called as ``model(x, t, text_embeddings)``.
            shape: Output shape; ``shape[0]`` is the batch size.
            text_embeddings: Text embeddings for conditioning.
            device: Device on which the initial noise and timesteps are created.
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance,
                3.0-7.0 typical).

        Returns:
            The generated batch, clamped to [-2.0, 2.0].
        """
        b = shape[0]
        img = torch.randn(shape, device=device)

        # Walk the reverse process from the last timestep down to 0.
        for i in reversed(range(0, self.num_timesteps)):
            t = torch.full((b,), i, device=device, dtype=torch.long)
            img = self.p_sample_text(model, img, t, text_embeddings, guidance_scale)

        # NOTE(review): [-2, 2] is wider than the common [-1, 1] data range.
        # Confirm it matches the normalization used in training.
        return torch.clamp(img, -2.0, 2.0)
|
|
|
|
def extract(a, t, x_shape):
    """Gather per-sample coefficients from ``a`` and reshape them to broadcast over x.

    Args:
        a: 1-D tensor of per-timestep coefficients (e.g. a noise-schedule tensor).
        t: 1-D tensor of timestep indices, one per batch element. Any integer
            dtype is accepted.
        x_shape: Shape of the tensor the result will be combined with. Only its
            length (number of dimensions) is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        where entry ``i`` is ``a[t[i]]``, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # gather requires an int64 index on the same device as `a`. Converting here
    # lets callers pass int32 timesteps and keep `a` on either CPU or GPU.
    out = a.gather(-1, t.to(device=a.device, dtype=torch.long))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
|