| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from fairseq.modules import Fp32GroupNorm |
| |
|
| |
|
class KmeansVectorQuantizer(nn.Module):
    """Grouped k-means vector quantizer with a straight-through gradient.

    Input features are passed through a grouped 1x1 convolution followed by a
    group norm, split into ``groups`` slices of size ``vq_dim // groups``, and
    each slice is replaced by its nearest (L2) codeword. The forward value is
    the quantized tensor, while gradients flow to the projected features as if
    quantization were the identity. An auxiliary k-means loss pulls codewords
    and projected features toward each other.
    """

    def __init__(
        self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25
    ):
        """Build the quantizer.

        Args:
            dim: number of input channels.
            num_vars: number of codewords per group.
            groups: number of slices each feature vector is split into.
            combine_groups: if true, all groups share one codebook.
            vq_dim: size of the quantized vector; must be divisible by
                ``groups``. ``forward`` reshapes the ``dim``-channel projection
                into ``groups x (vq_dim // groups)``, so ``dim`` must equal
                ``vq_dim`` for ``forward`` to succeed.
            time_first: if true, inputs are BxTxC, otherwise BxCxT.
            gamma: weight of the commitment term in ``kmeans_loss``.
        """
        super().__init__()

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        self.input_dim = dim
        self.vq_dim = vq_dim
        self.groups = groups
        self.combine_groups = combine_groups
        self.num_vars = num_vars
        self.time_first = time_first
        self.gamma = gamma
        self.var_dim = vq_dim // groups

        # A shared codebook stores a single group slot; `expand_embedding`
        # broadcasts it over all groups at lookup time.
        codebook_groups = 1 if combine_groups else groups

        # Keep the codebook created before the projection: parameter
        # registration order and the sequence of random draws depend on it.
        self.embedding = nn.Parameter(
            0.01 * torch.randn(num_vars, codebook_groups, self.var_dim)
        )
        self.projection = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=1, groups=groups, bias=False),
            Fp32GroupNorm(groups, dim),
        )
        self.mse_mean = nn.MSELoss(reduction="mean")

    def _pass_grad(self, x, y):
        """Straight-through estimator: the value of ``y``, the gradient of ``x``.

        The result is numerically equal to ``y``. In the backward pass the
        incoming gradient is routed to ``x`` unchanged, as if the result were
        ``x`` itself. ``y`` receives no gradient through this path.
        """
        identity_path = x - x.detach()  # zero-valued, but d(identity_path)/dx = 1
        return y.detach() + identity_path

    @property
    def expand_embedding(self):
        """Codebook viewed as a (num_vars, groups, var_dim) tensor.

        With ``combine_groups`` the single stored slot is expanded (a view, no
        copy) across all groups. Otherwise the parameter is returned unchanged.
        """
        codebook = self.embedding
        if not self.combine_groups:
            return codebook
        return codebook.expand(self.num_vars, self.groups, self.var_dim)

    def forward_idx(self, x):
        """Quantize ``x`` and return ``(quantized_x, codeword_indices)``."""
        out = self.forward(x, produce_targets=True)
        return out["x"], out["targets"]

    def forward(self, x, produce_targets=False):
        """Quantize a batch of feature sequences.

        Args:
            x: BxTxC tensor if ``time_first`` else BxCxT.
            produce_targets: if true, also return the selected codeword indices.

        Returns:
            dict with these keys, in insertion order:
              - "num_vars": codebook size per group.
              - "code_perplexity": sum over groups of exp(entropy) of the
                batch's hard codeword-usage distribution.
              - "targets" (only when ``produce_targets``): argmin codeword
                indices of shape (B, T, groups).
              - "x": quantized features in the same layout as the input, with
                straight-through gradients to the projected features.
              - "kmeans_loss": codebook MSE + gamma * commitment MSE, both
                computed in fp32.
        """
        result = {"num_vars": self.num_vars}

        # Everything below works in (batch, channels, time) layout.
        feats = x.transpose(1, 2) if self.time_first else x
        bsz, _, tsz = feats.shape

        ze = self.projection(feats)
        # (B, C, T) -> (B, T, groups, var_dim); requires C == groups * var_dim.
        ze_grouped = ze.view(bsz, self.groups, self.var_dim, tsz).permute(0, 3, 1, 2)

        # L2 distance from every group slice to every codeword:
        # (num_vars, B, T, groups, var_dim) -> (num_vars, B, T, groups).
        residual = ze_grouped.unsqueeze(0) - self.expand_embedding[:, None, None]
        dist = residual.view(self.num_vars, bsz, tsz, self.groups, -1).norm(
            dim=-1, p=2
        )
        idx = dist.argmin(dim=0)  # (B, T, groups)

        # Look up the winning codeword per group, then flatten back to
        # (B, groups * var_dim, T) to match `ze`.
        chosen = []
        for g in range(self.groups):
            chosen.append(self.expand_embedding[idx[..., g], g])
        zq = (
            torch.stack(chosen, dim=-2)
            .view(bsz, tsz, self.groups * self.var_dim)
            .permute(0, 2, 1)
        )
        assert ze.shape == zq.shape, (ze.shape, zq.shape)
        quantized = self._pass_grad(ze, zq)

        # Codebook usage from hard (one-hot) assignments, averaged over B*T.
        assignments = nn.functional.one_hot(idx.reshape(-1), self.num_vars)
        usage = assignments.view(bsz * tsz, self.groups, -1).float().mean(dim=0)
        entropy = -(usage * torch.log(usage + 1e-7)).sum(dim=-1)
        result["code_perplexity"] = entropy.exp().sum()

        if produce_targets:
            result["targets"] = idx

        result["x"] = quantized.transpose(1, 2) if self.time_first else quantized

        # The codebook term moves codewords toward the (frozen) features; the
        # commitment term moves features toward the (frozen) codewords.
        ze_fp32 = ze.float()
        zq_fp32 = zq.float()
        latent_loss = self.mse_mean(zq_fp32, ze_fp32.detach())
        commitment_loss = self.mse_mean(ze_fp32, zq_fp32.detach())
        result["kmeans_loss"] = latent_loss + self.gamma * commitment_loss

        return result
| |
|