File size: 5,805 Bytes
e4c8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# Visualization code from https://github.com/Tsingularity/dift/blob/main/src/utils/visualization.py

import io
from pathlib import Path

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image

FONT_SIZE = 40


@torch.no_grad()
def plot_feats(
    image,
    target,
    pred,
    legend=("Image", "HR Features", "Pred Features"),
    save_path=None,
    return_array=False,
    show_legend=True,
    font_size=FONT_SIZE,
):
    """
    Plot an optional input image next to PCA-reduced target and predicted
    feature maps, one column each.

    Args:
        image: optional image tensor, (C, H, W) or (H, W); 2-D input is shown
            with the "inferno" colormap. Pass None to omit the column.
        target: feature tensor, (C, H, W) — assumed channel-first; TODO confirm.
        pred: feature tensor or list of feature tensors, same layout as target.
        legend: per-column titles; prediction columns beyond the legend length
            fall back to "HR Features {idx}".
        save_path: if given, the figure is saved to this path.
        return_array: if True, return the rendered figure as a numpy array
            instead of saving/showing it.
        show_legend: whether to draw column titles.
        font_size: title font size.

    Returns:
        numpy.ndarray of the rendered figure when return_array is True,
        otherwise None.
    """
    # NOTE(fix): the legend default is now an immutable tuple — the previous
    # mutable list default was shared across calls (classic Python pitfall).
    if not isinstance(pred, list):
        pred = [pred]

    # PCA-reduce target and predictions jointly so their colors are comparable.
    feats_for_pca = [target.unsqueeze(0)] + [p.unsqueeze(0) for p in pred]
    reduced_feats, _ = pca(feats_for_pca)  # pca outputs a list of reduced tensors

    target_imgs = reduced_feats[0]
    pred_imgs = reduced_feats[1:]

    # --- Plot ---
    # Determine number of columns based on whether image is provided
    n_cols = (1 if image is not None else 0) + 1 + len(pred_imgs)
    fig, ax = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))
    # Reduce space between images
    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    # NOTE(fix): plt.subplots returns a bare Axes when n_cols == 1. Wrapping it
    # in a plain Python list (as before) made remove_axes crash on `.shape`;
    # np.atleast_1d yields an ndarray, which works in both cases.
    ax = np.atleast_1d(ax)

    # Current axis index
    ax_idx = 0

    # Plot original image if provided
    if image is not None:
        if image.dim() == 3:
            ax[ax_idx].imshow(image.permute(1, 2, 0).detach().cpu())
        elif image.dim() == 2:
            ax[ax_idx].imshow(image.detach().cpu(), cmap="inferno")
        if show_legend:
            ax[ax_idx].set_title(legend[0], fontsize=font_size)
        ax_idx += 1

    # Plot the low-resolution features or segmentation mask
    ax[ax_idx].imshow(target_imgs[0].permute(1, 2, 0).detach().cpu())
    if show_legend:
        legend_idx = 1 if image is not None else 0
        ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
    ax_idx += 1

    # Plot HR features or segmentation masks
    for idx, pred_img in enumerate(pred_imgs):
        ax[ax_idx].imshow(pred_img[0].permute(1, 2, 0).detach().cpu())
        if show_legend:
            legend_idx = (2 if image is not None else 1) + idx
            if len(legend) > legend_idx:
                ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
            else:
                ax[ax_idx].set_title(f"HR Features {idx}", fontsize=font_size)
        ax_idx += 1

    remove_axes(ax)

    # Handle return_array case
    if return_array:
        # Turn off interactive mode temporarily so nothing pops up on screen.
        was_interactive = plt.isinteractive()
        plt.ioff()

        # Render the figure into an in-memory PNG buffer.
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)

        # Convert to PIL Image then to numpy array
        pil_img = Image.open(buf)
        img_array = np.array(pil_img)

        # Close the figure and buffer
        plt.close(fig)
        buf.close()

        # Restore interactive mode if it was on
        if was_interactive:
            plt.ion()

        return img_array

    # Standard behavior: save and/or show
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.show()

    return None


def remove_axes(axes):
    """Hide ticks and tick labels on every axis in *axes*.

    Accepts a single Axes, a flat list/tuple of Axes, or a 1-D/2-D ndarray as
    returned by ``plt.subplots``. (The previous version indexed ``axes.shape``
    directly and raised AttributeError on a plain Python list or a bare Axes.)
    """

    def _remove_axes(ax):
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_xticks([])
        ax.set_yticks([])

    # Normalize any container shape to a flat 1-D iterable of Axes objects.
    for ax in np.asarray(axes, dtype=object).ravel():
        _remove_axes(ax)


def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
    """
    Jointly PCA-reduce a list of (1, C, H, W) feature maps to ``dim`` channels.

    The PCA basis is fit on all maps together (resampled to the first map's
    spatial size) so the reduced colors are comparable across maps; each map is
    then transformed at its original resolution and min-max normalized per
    channel to [0, 1] for display.

    Args:
        image_feats_list: list of feature tensors, each with batch size 1.
        dim: number of principal components to keep.
        fit_pca: optional pre-fitted PCA object; when given, no new fit is done.
        max_samples: optional cap on the number of pixels used for fitting.

    Returns:
        (reduced_feats, fit_pca): list of (1, dim, H, W) tensors in [0, 1],
        and the (possibly newly fitted) PCA object for reuse.
    """
    # When fitting jointly across several maps, resample all of them to the
    # first map's spatial size so every map contributes comparably.
    target_size = None
    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]

    def flatten(tensor, target_size=None):
        # (1, C, H, W) -> (H*W, C) on CPU, optionally resized first.
        B, C, H, W = tensor.shape
        assert B == 1, "Batch size should be 1 for PCA flattening"
        if target_size is not None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear", align_corners=False)
        return rearrange(tensor, "b c h w -> (b h w) c").detach().cpu()

    flattened_feats = [flatten(feats, target_size) for feats in image_feats_list]
    x = torch.cat(flattened_feats, dim=0)

    # Subsample the fitting data if max_samples is set and exceeded, to bound cost.
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]

    if fit_pca is None:
        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        B, C, H, W = feats.shape
        # Transform at the ORIGINAL resolution (no target_size) so the output
        # reshapes back to the input map's H x W.
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        # Per-channel min-max normalization to [0, 1] for visualization.
        x_red -= x_red.min(dim=0, keepdim=True).values
        # NOTE(fix): clamp the denominator so a constant channel yields zeros
        # instead of NaNs from a 0/0 division.
        x_red /= x_red.max(dim=0, keepdim=True).values.clamp_min(1e-12)
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2))

    return reduced_feats, fit_pca


class TorchPCA:
    """Minimal PCA on torch tensors, backed by ``torch.pca_lowrank``.

    Mirrors the fit/transform half of the sklearn PCA interface. ``skip``
    drops the first ``skip`` principal components at transform time.
    """

    def __init__(self, n_components, skip=0):
        self.n_components = n_components
        self.skip = skip

    def fit(self, X):
        """Estimate the mean and principal directions of X (rows = samples)."""
        self.mean_ = X.mean(dim=0)
        centered = X - self.mean_
        _, singular_values, directions = torch.pca_lowrank(
            centered, q=self.n_components, center=False, niter=20
        )
        # Optionally discard the leading `skip` directions.
        self.components_ = directions[:, self.skip :]
        self.singular_values_ = singular_values
        return self

    def transform(self, X):
        """Center X with the fitted mean and project onto the components."""
        return (X - self.mean_.unsqueeze(0)) @ self.components_