# Visualization code from https://github.com/Tsingularity/dift/blob/main/src/utils/visualization.py
import io
from pathlib import Path
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image

FONT_SIZE = 40


@torch.no_grad()
def plot_feats(
    image,
    target,
    pred,
    legend=("Image", "HR Features", "Pred Features"),
    save_path=None,
    return_array=False,
    show_legend=True,
    font_size=FONT_SIZE,
):
    """
    Plot an image next to PCA-reduced target and predicted feature maps.

    `target` and each entry of `pred` are (C, H, W) feature tensors; a joint
    PCA projects them to 3 channels so they can be displayed as RGB images.
    """
    # Ensure pred is a list so single tensors and lists are handled uniformly
    if not isinstance(pred, list):
        pred = [pred]
    # Prepare inputs for PCA
    feats_for_pca = [target.unsqueeze(0)] + [p.unsqueeze(0) for p in pred]
    reduced_feats, _ = pca(feats_for_pca)  # pca returns a list of reduced tensors
    target_imgs = reduced_feats[0]
    pred_imgs = reduced_feats[1:]

    # --- Plot ---
    # Determine number of columns based on whether image is provided
    n_cols = (1 if image is not None else 0) + 1 + len(pred_imgs)
    fig, ax = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))
    # Reduce space between images
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    # Handle single subplot case (plt.subplots returns a bare Axes object)
    if n_cols == 1:
        ax = [ax]
    # Current axis index
    ax_idx = 0
    # Plot original image if provided
    if image is not None:
        if image.dim() == 3:
            ax[ax_idx].imshow(image.permute(1, 2, 0).detach().cpu())
        elif image.dim() == 2:
            ax[ax_idx].imshow(image.detach().cpu(), cmap="inferno")
        if show_legend:
            ax[ax_idx].set_title(legend[0], fontsize=font_size)
        ax_idx += 1
    # Plot the PCA-reduced target features (or segmentation mask)
    ax[ax_idx].imshow(target_imgs[0].permute(1, 2, 0).detach().cpu())
    if show_legend:
        legend_idx = 1 if image is not None else 0
        ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
    ax_idx += 1
    # Plot the PCA-reduced predicted features (or segmentation masks)
    for idx, pred_img in enumerate(pred_imgs):
        ax[ax_idx].imshow(pred_img[0].permute(1, 2, 0).detach().cpu())
        if show_legend:
            legend_idx = (2 if image is not None else 1) + idx
            if len(legend) > legend_idx:
                ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
            else:
                ax[ax_idx].set_title(f"Pred Features {idx}", fontsize=font_size)
        ax_idx += 1
    remove_axes(ax)

    # Handle return_array case
    if return_array:
        # Turn off interactive mode temporarily
        was_interactive = plt.isinteractive()
        plt.ioff()
        # Convert figure to numpy array via an in-memory PNG
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        # Convert to PIL Image, then to numpy array
        pil_img = Image.open(buf)
        img_array = np.array(pil_img)
        # Close the figure and buffer
        plt.close(fig)
        buf.close()
        # Restore interactive mode if it was on
        if was_interactive:
            plt.ion()
        return img_array

    # Standard behavior: save and/or show
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.show()
    return None
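
# A minimal usage sketch (hedged; the tensors and panel names below are
# illustrative assumptions, not part of the original API). Several predicted
# feature maps can share one figure, with one legend entry per panel:
#
#   preds = [pred_a, pred_b]  # hypothetical (C, H, W) feature tensors
#   plot_feats(image, target, preds,
#              legend=["Image", "HR Features", "Pred A", "Pred B"],
#              save_path="comparison.png")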


def remove_axes(axes):
    """Hide ticks and tick labels on every Axes in `axes`."""

    def _remove_axes(ax):
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_xticks([])
        ax.set_yticks([])

    # np.asarray(...).ravel() accepts plain lists as well as 1-D and 2-D
    # subplot arrays, so the wrapped single-Axes case above works here too
    for ax in np.asarray(axes).ravel():
        _remove_axes(ax)


def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
    target_size = None
    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]

    def flatten(tensor, target_size=None):
        B, C, H, W = tensor.shape
        assert B == 1, "Batch size should be 1 for PCA flattening"
        if target_size is not None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear", align_corners=False)
        return rearrange(tensor, "b c h w -> (b h w) c").detach().cpu()

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    # Subsample the data if max_samples is set and the number of samples exceeds it
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]

    if fit_pca is None:
        fit_pca = TorchPCA(n_components=dim).fit(x)
    reduced_feats = []
    for feats in image_feats_list:
        B, C, H, W = feats.shape
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        # Min-max normalize each channel to [0, 1]; the clamp guards against
        # division by zero when a projected channel is constant
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values.clamp(min=1e-12)
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2))

    return reduced_feats, fit_pca
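
# The returned `fit_pca` can be reused so two feature sets share one projection
# (and therefore one color space). A sketch, assuming `feats_a` and `feats_b`
# are hypothetical (1, C, H, W) feature tensors:
#
#   [red_a], fitted = pca([feats_a])
#   [red_b], _ = pca([feats_b], fit_pca=fitted)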


class TorchPCA:
    """Minimal PCA with a scikit-learn-style fit/transform interface."""

    def __init__(self, n_components, skip=0):
        self.n_components = n_components
        self.skip = skip

    def fit(self, X):
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_
        _, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=20)
        # Optionally drop the first `skip` principal components
        self.components_ = V[:, self.skip:]
        self.singular_values_ = S
        return self

    def transform(self, X):
        # Center with the training mean, then project onto the components
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_
        return projected
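

if __name__ == "__main__":
    # Smoke test: a minimal sketch with random tensors standing in for real
    # model features (all shapes here are illustrative assumptions)
    image = torch.rand(3, 128, 128)
    target = torch.randn(64, 32, 32)
    pred = torch.randn(64, 32, 32)
    arr = plot_feats(image, target, pred, return_array=True)
    print("rendered figure array:", arr.shape)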