| import cv2 |
| import torch |
| from torch import nn |
| from einops.layers.torch import Rearrange |
| from .DCT import Learnable_DCT2D |
| |
|
|
class Block(nn.Module):
    """ConvNeXtV2 block with an extra spatial-attention gate.

    Pipeline: 7x7 depthwise conv -> channels-last LayerNorm -> Linear(dim, 4*dim) -> GELU
    -> GRN -> Linear(4*dim, dim) -> back to channels-first -> multiplied by a spatial
    attention map resized to the feature size -> added to the input through stochastic depth.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attention = Spatial_Attention()

    def forward(self, x):
        # x is (B, dim, H, W) as required by the grouped dwconv; the permutes move
        # channels last for LayerNorm/Linear/GRN and back again.
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)

        # Spatial_Attention returns a single-channel 7x7 map; resize it to (H, W).
        # F.interpolate with align_corners=True is exactly what the deprecated
        # nn.UpsamplingBilinear2d does, without constructing a module per call.
        attention = self.attention(x)
        attention = F.interpolate(
            attention, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=True
        )
        x = x * attention
        return shortcut + self.drop_path(x)
|
|
class Spatial_Attention(nn.Module):
    """Single-channel spatial attention map computed on a fixed 7x7 grid.

    The channel-wise mean and max maps are stacked, adaptively average-pooled to 7x7,
    fused into one channel by a 7x7 conv and then refined by a one-channel
    ``TransformerBlock``. No sigmoid is applied to the result.

    NOTE(review): with a single channel, the TransformerBlock's LayerNorm(1) pre-norms
    always output their bias (a 1-element vector has zero variance), so its attention/MLP
    branches see input-independent tokens — confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.attention = TransformerBlock(1, 1, heads=1, dim_head=1, img_size=[7, 7])

    def forward(self, x):
        # assumes x is (B, C, H, W); both descriptors keep a singleton channel dim.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptors = torch.cat([channel_mean, channel_max], dim=1)
        fused = self.conv(self.avgpool(descriptors))
        return self.attention(fused)
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block applied to a (B, C, H, W) feature map.

    The map is flattened to ``H*W`` tokens for the attention and MLP sub-layers and
    reshaped back to ``img_size``. With ``downsample=True`` the shortcut and the attention
    input are each max-pooled by 2 and the shortcut is projected ``inp -> oup`` by a 1x1
    conv; ``img_size`` must then be the pooled grid size.

    NOTE(review): without downsampling the residual ``x + attn(x)`` only type-checks when
    ``inp == oup`` (the attention output has ``oup`` channels).

    Args:
        inp (int): Input channels.
        oup (int): Output channels.
        heads (int): Attention heads. Default: 8
        dim_head (int): Channels per head. Default: 32
        img_size (pair of ints): Token grid ``(H, W)`` used when reshaping back. Required.
        downsample (bool): Pool by 2 before attention. Default: False
        dropout (float): Dropout used inside attention and the MLP. Default: 0.
    """

    def __init__(self, inp, oup, heads=8, dim_head=32, img_size=None, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)
        self.downsample = downsample
        self.ih, self.iw = img_size

        if downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        attention = Attention(inp, oup, heads, dim_head, dropout)
        feed_forward = FeedForward(oup, hidden_dim, dropout)

        to_tokens = 'b c ih iw -> b (ih iw) c'
        to_map = 'b (ih iw) c -> b c ih iw'
        self.attn = nn.Sequential(
            Rearrange(to_tokens),
            PreNorm(inp, attention, nn.LayerNorm),
            Rearrange(to_map, ih=self.ih, iw=self.iw),
        )
        self.ff = nn.Sequential(
            Rearrange(to_tokens),
            PreNorm(oup, feed_forward, nn.LayerNorm),
            Rearrange(to_map, ih=self.ih, iw=self.iw),
        )

    def forward(self, x):
        if not self.downsample:
            x = x + self.attn(x)
        else:
            x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
        return x + self.ff(x)
|
|
|
|
class CSATv2(nn.Module):
    """CSATv2 image classifier.

    A ``Learnable_DCT2D(8)`` stem feeds four stages of ConvNeXtV2-style ``Block``s; stages 3
    and 4 additionally end with two ``TransformerBlock``s each. Stages 1-3 finish with a
    channels-first LayerNorm and a 2x2 stride-2 conv that widens to the next stage's
    channel count. The head is global average pooling -> LayerNorm -> Linear.

    Args:
        img_size (int): Input side length. Required: the transformer blocks reshape their
            tokens to ``int(img_size / 32)`` (stage 3) and ``int(img_size / 64)`` (stage 4)
            per side.
        num_classes (int): Number of output logits. Default: 1000
        drop_path_rate (float): Stochastic-depth rate of the last ``Block``; rates grow
            linearly from 0 over all ``sum(depths)`` (= 14) Blocks. Default: 0
        head_init_scale (float): Factor applied to the initial head weight and bias. Default: 1
    """

    def __init__(self, img_size=None, num_classes=1000, drop_path_rate=0, head_init_scale=1):
        super().__init__()
        if img_size is None:
            # Same exception type the arithmetic below would raise, but with a clear message.
            raise TypeError("CSATv2 requires img_size (the input side length)")
        # NOTE(review): 386 is an unusual width (384?); kept as-is for checkpoint compatibility.
        dims = [32, 72, 168, 386]
        depths = [2, 2, 6, 4]
        channel_order = "channels_first"
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        grid3 = [int(img_size / 32), int(img_size / 32)]
        grid4 = [int(img_size / 64), int(img_size / 64)]

        def blocks(stage):
            # Index dp_rates from this stage's global offset so the rate keeps increasing
            # across stages (previously every stage restarted at dp_rates[0]).
            start = sum(depths[:stage])
            return [Block(dim=dims[stage], drop_path=dp_rates[start + i]) for i in range(depths[stage])]

        def downsample(stage):
            # Stage-transition layers: channels-first LayerNorm, then 2x2 stride-2 conv.
            return [
                LayerNorm(dims[stage], eps=1e-6, data_format=channel_order),
                nn.Conv2d(dims[stage], dims[stage + 1], kernel_size=2, stride=2),
            ]

        self.stages1 = nn.Sequential(*blocks(0), *downsample(0))
        self.stages2 = nn.Sequential(*blocks(1), *downsample(1))
        self.stages3 = nn.Sequential(
            *blocks(2),
            TransformerBlock(inp=dims[2], oup=dims[2], img_size=grid3),
            TransformerBlock(inp=dims[2], oup=dims[2], img_size=grid3),
            *downsample(2),
        )
        self.stages4 = nn.Sequential(
            *blocks(3),
            TransformerBlock(inp=dims[3], oup=dims[3], img_size=grid4),
            TransformerBlock(inp=dims[3], oup=dims[3], img_size=grid4),
        )

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)
        # Attached after apply(), so _init_weights does not touch the DCT stem's submodules.
        self.dct = Learnable_DCT2D(8)

    def load_checkpoint(self, checkpoint):
        """Load backbone weights from a checkpoint, keeping the current classifier head.

        The loaded object must be a dict with a ``'state_dict'`` or ``'model'`` entry.
        Occurrences of ``'module.backbone.'`` and ``'resnet.'`` are removed from key names,
        only keys that exist in this model are copied, and ``head.weight`` / ``head.bias``
        are excluded so the model's current head parameters are kept.

        Args:
            checkpoint: path or file-like object accepted by ``torch.load``.

        Raises:
            KeyError: if neither ``'state_dict'`` nor ``'model'`` is present.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(checkpoint, map_location='cpu')
        try:
            state_dict = state['state_dict']
        except KeyError:
            state_dict = state['model']
        for key in list(state_dict.keys()):
            state_dict[key.replace('module.backbone.', '').replace('resnet.', '')] = state_dict.pop(key)

        model_dict = self.state_dict()
        model_dict.update({k: v for k, v in state_dict.items() if k in model_dict})
        del model_dict['head.bias']
        del model_dict['head.weight']
        # strict=False: the removed head entries are reported as missing, not errors.
        self.load_state_dict(model_dict, strict=False)

    def preprocess(self, x):
        """Convert a BGR image to YCrCb with OpenCV (``cv2.COLOR_BGR2YCR_CB``).

        NOTE(review): not called by ``forward``; presumably ``x`` is a numpy image as
        accepted by ``cv2.cvtColor`` — confirm with callers.
        """
        return cv2.cvtColor(x, cv2.COLOR_BGR2YCR_CB)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for every Conv2d / Linear."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Several layers are built with bias=False; skip them explicitly instead of
            # swallowing every exception.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits (``num_classes`` per sample)."""
        # NOTE(review): the expected input layout is defined by Learnable_DCT2D — confirm.
        x = self.dct(x)
        x = self.stages1(x)
        x = self.stages2(x)
        x = self.stages3(x)
        x = self.stages4(x)
        x = self.norm(x.mean([-2, -1]))  # global average pool over the last two dims
        return self.head(x)
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| import math |
| import warnings |
|
|
class LayerNorm(nn.Module):
    """LayerNorm over the channel dimension for either memory layout.

    ``data_format="channels_last"`` expects ``(batch, height, width, channels)`` and
    normalises the last dimension with ``F.layer_norm``; ``"channels_first"`` expects
    ``(batch, channels, height, width)`` and normalises dimension 1 explicitly (biased
    variance). Any other ``data_format`` raises ``NotImplementedError``.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centred = x - mean
            variance = centred.pow(2).mean(1, keepdim=True)
            normed = centred / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel affine parameters over H and W.
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
|
|
class GRN(nn.Module):
    """Global Response Normalization layer (ConvNeXt V2).

    Works on channels-last 4-D input: ``gamma`` and ``beta`` have shape ``(1, 1, 1, dim)``
    and the L2 norm is taken over dimensions 1 and 2. Both parameters start at zero, so
    the residual form makes a freshly built layer return its input.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel spatial L2 norm, relative to its mean across channels.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative) + self.beta + x
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample, for the main path of residual blocks.

    Each sample along dimension 0 has its whole branch zeroed with probability
    ``drop_prob``; surviving samples are scaled by ``1 / (1 - drop_prob)`` so the expected
    value is unchanged. This matches the DropConnect-style implementation used for
    EfficientNet-family networks; it is called "drop path" because "DropConnect" names a
    different technique (see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).

    Args:
        x: input tensor; dimension 0 is the batch.
        drop_prob: probability of dropping a sample. ``None`` or ``0`` disables dropping.
        training: dropping only happens when True.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the same shape.
        With ``drop_prob == 1`` every sample is zeroed (instead of NaN from dividing by 0).
    """
    # `not drop_prob` covers both 0 and None (DropPath's default), which previously
    # crashed with a TypeError in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1. - drop_prob
    # One Bernoulli(keep_prob) value per sample, broadcast over the remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    if keep_prob > 0.:
        return x.div(keep_prob) * random_tensor
    # keep_prob == 0: random_tensor is all zeros; skip the division by zero.
    return x * random_tensor
|
|
|
|
class DropPath(nn.Module):
    """Stochastic depth as a module: forwards to ``drop_path`` with this module's training flag.

    Args:
        drop_prob: per-sample probability of dropping the wrapped residual branch.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
|
|
class FeedForward(nn.Module):
    """Token-wise MLP: Linear(dim -> hidden_dim), GELU, Dropout, Linear(hidden_dim -> dim), Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # A single Sequential named `net` keeps parameter names net.0.* / net.3.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
class PreNorm(nn.Module):
    """Normalise the input with ``norm(dim)``, then apply ``fn`` (extra kwargs go to ``fn``)."""

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
|
|
class Attention(nn.Module):
    """Multi-head self-attention over tokens, preceded by a convolutional position encoding.

    Args:
        inp (int): Channels of the input tokens (``x`` is ``(B, N, inp)``; PosCNN reshapes
            it to a square grid).
        oup (int): Output channels when the output projection is used.
        heads (int): Number of heads. Default: 8
        dim_head (int): Channels per head. Default: 32
        dropout (float): Dropout after the output projection. Default: 0.
    """

    def __init__(self, inp, oup, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        # One head of width `inp` already yields `inp` channels, so the output projection
        # is skipped (note: the output then has `inp`, not `oup`, channels).
        needs_projection = not (heads == 1 and dim_head == inp)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, oup), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()
        self.pos_embed = PosCNN(in_chans=inp)

    def forward(self, x):
        x = self.pos_embed(x)
        q, k, v = (
            rearrange(chunk, 'b n (h d) -> b h n d', h=self.heads)
            for chunk in self.to_qkv(x).chunk(3, dim=-1)
        )
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        mixed = torch.matmul(weights, v)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
|
|
| |
class PosCNN(nn.Module):
    """Convolutional positional encoding for a square grid of tokens.

    Tokens ``(B, N, C)`` are reshaped to a ``sqrt(N) x sqrt(N)`` map, passed through a 3x3
    depthwise convolution that is added back as a residual, then flattened to ``(B, N, C)``.

    Args:
        in_chans (int): Number of token channels ``C``.
    """

    def __init__(self, in_chans):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=1, bias=True, groups=in_chans)

    def forward(self, x):
        """Add the depthwise-conv positional term to ``x``.

        Raises:
            ValueError: if the token count ``N`` is not a perfect square.
        """
        B, N, C = x.shape
        # Exact integer square root instead of int(N ** 0.5), plus an explicit check so
        # non-square token counts fail with a clear message rather than inside .view().
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(f"PosCNN expects a square number of tokens, got N={N}")
        cnn_feat = x.transpose(1, 2).view(B, C, side, side)
        x = self.proj(cnn_feat) + cnn_feat
        return x.flatten(2).transpose(1, 2)
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. They are generated by the inverse-CDF method (a uniform draw mapped
    through ``erfinv``) and finally clamped to the bounds. Works best when
    :math:`a \leq \text{mean} \leq b`; a warning is issued when ``mean`` is more than
    2 ``std`` outside ``[a, b]``.

    Args:
        tensor: an n-dimensional ``torch.Tensor`` to fill.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff.
        b: upper cutoff.

    Returns:
        ``tensor`` itself.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w, std=.02)
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
|
| def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
| |
| |
| def norm_cdf(x): |
| |
| return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
| if (mean < a - 2 * std) or (mean > b + 2 * std): |
| warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
| "The distribution of values may be incorrect.", |
| stacklevel=2) |
|
|
| with torch.no_grad(): |
| |
| |
| |
| l = norm_cdf((a - mean) / std) |
| u = norm_cdf((b - mean) / std) |
|
|
| |
| |
| tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
| |
| |
| tensor.erfinv_() |
|
|
| |
| tensor.mul_(std * math.sqrt(2.)) |
| tensor.add_(mean) |
|
|
| |
| tensor.clamp_(min=a, max=b) |
| return tensor |