| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from fairseq import utils |
| | from .unfold import unfold1d |
| | from fairseq.incremental_decoding_utils import with_incremental_state |
| | from fairseq.modules.fairseq_dropout import FairseqDropout |
| |
|
| |
|
def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
                weight_dropout=0., weight_softmax=False,
                renorm_padding=False, bias=False, conv_bias=False,
                query_size=None, in_proj=False):
    '''Build a dynamic convolution layer, preferring the fused CUDA module.

    When CUDA is available, tries to import ``DynamicconvLayer`` and return an
    instance of it. If the import fails the error is printed and the function
    falls back to the pure PyTorch ``DynamicConv1dTBC``; the same fallback is
    used when CUDA is not available.

    Args:
        input_size: # of channels of the input
        kernel_size: width of the convolution kernel
        padding_l: padding to the left
        num_heads: number of heads
        weight_dropout: drop rate applied to the predicted filters
        weight_softmax: normalize the filters with softmax
        renorm_padding: forwarded to ``DynamicConv1dTBC``
        bias: use bias in the filter-predicting linear layer
        conv_bias: forwarded to ``DynamicConv1dTBC``
        query_size: forwarded to ``DynamicConv1dTBC``
        in_proj: forwarded to ``DynamicConv1dTBC``

    Returns:
        A ``DynamicconvLayer`` or a ``DynamicConv1dTBC`` module.
    '''
    if torch.cuda.is_available():
        try:
            from fairseq.modules.dynamicconv_layer import DynamicconvLayer
            # NOTE(review): the fused layer is still built with the original
            # subset of options; its signature is not visible here, so confirm
            # it before forwarding renorm_padding/conv_bias/query_size/in_proj.
            return DynamicconvLayer(input_size, kernel_size=kernel_size,
                                    padding_l=padding_l, num_heads=num_heads,
                                    weight_dropout=weight_dropout,
                                    weight_softmax=weight_softmax, bias=bias)
        except ImportError as e:
            # Best-effort: report the problem and use the PyTorch fallback.
            print(e)
    # Forward every option the fallback understands; previously
    # renorm_padding, conv_bias, query_size and in_proj were silently dropped.
    return DynamicConv1dTBC(input_size, kernel_size=kernel_size,
                            padding_l=padding_l, num_heads=num_heads,
                            weight_dropout=weight_dropout,
                            weight_softmax=weight_softmax,
                            renorm_padding=renorm_padding, bias=bias,
                            conv_bias=conv_bias, query_size=query_size,
                            in_proj=in_proj)
| |
|
| |
|
def Linear(in_features, out_features, bias=True):
    '''Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: whether the layer has a bias term; when set, it is filled with 0

    Returns:
        The initialized ``nn.Linear`` module.
    '''
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        # No bias parameter exists, so there is nothing left to initialize.
        return layer
    nn.init.constant_(layer.bias, 0.)
    return layer
| |
|
| |
|
@with_incremental_state
class DynamicConv1dTBC(nn.Module):
    '''Dynamic lightweight convolution taking T x B x C inputs

    The convolution filters are not stored as a parameter. They are predicted
    for every time step and head by ``weight_linear`` from the query (the
    input itself when no query is given, or when ``in_proj`` is set).

    Args:
        input_size: # of channels of the input
        kernel_size: width of the convolution kernel (number of filter taps)
        padding_l: padding to the left when using "same" padding
        num_heads: number of heads used. One filter of ``kernel_size`` taps is
            predicted per head and shared by ``input_size // num_heads`` channels
        weight_dropout: the drop rate of the DropConnect to drop the weight
        weight_softmax: normalize the weight with softmax before the convolution
        renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
        bias: use bias in the filter-predicting linear layer (only applies
            when ``in_proj`` is False; the ``in_proj`` layer always has a bias)
        conv_bias: bias of the convolution
        query_size: specified when feeding a different input as the query
        in_proj: project the input and generate the filter together

    Shape:
        Input: TxBxC, i.e. (timesteps, batch_size, input_size)
        Output: TxBxC, i.e. (timesteps, batch_size, input_size)

    Attributes:
        weight_linear: linear layer predicting the filters; its output has
            ``num_heads * kernel_size`` features, plus ``input_size`` more
            when ``in_proj`` is set
        conv_bias: the learnable bias of the module of shape `(input_size)`,
            or None when ``conv_bias`` is False
    '''
    def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
                 weight_dropout=0., weight_softmax=False,
                 renorm_padding=False, bias=False, conv_bias=False,
                 query_size=None, in_proj=False):
        super().__init__()
        self.input_size = input_size
        # The query defaults to the input, so it then has the same width.
        self.query_size = input_size if query_size is None else query_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_dropout_module = FairseqDropout(weight_dropout, module_name=self.__class__.__name__)
        self.weight_softmax = weight_softmax
        self.renorm_padding = renorm_padding

        if in_proj:
            # One projection yields both the transformed input (first
            # input_size features) and the filters (remaining features).
            self.weight_linear = Linear(self.input_size, self.input_size + num_heads * kernel_size * 1)
        else:
            self.weight_linear = Linear(self.query_size, num_heads * kernel_size * 1, bias=bias)
        if conv_bias:
            # Allocated uninitialized; zeroed in reset_parameters().
            self.conv_bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.conv_bias = None
        self.reset_parameters()

    @property
    def in_proj(self):
        '''Whether ``weight_linear`` also projects the input.

        Derived from the layer's output width rather than stored separately.
        '''
        return self.weight_linear.out_features == self.input_size + self.num_heads * self.kernel_size

    def reset_parameters(self):
        '''Re-initialize ``weight_linear`` via its own ``reset_parameters`` and zero ``conv_bias``.'''
        self.weight_linear.reset_parameters()
        if self.conv_bias is not None:
            nn.init.constant_(self.conv_bias, 0.)

    def forward(self, x, incremental_state=None, query=None, unfold=None):
        '''Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
        args:
            x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
            incremental_state: A dict to keep the state
            unfold: unfold the input or not. If not, we use the matrix trick instead.
                Defaults to unfolding only when there are more than 512 time steps;
                unfolding is always used when incremental_state is given
            query: use the specified query to predict the conv filters
                (must be None when in_proj is set)
        '''
        unfold = x.size(0) > 512 if unfold is None else unfold  # use unfold mode as default for long sequence to save memory
        unfold = unfold or (incremental_state is not None)
        # With in_proj the filters are always predicted from x itself.
        assert query is None or not self.in_proj

        if query is None:
            query = x
        if unfold:
            output = self._forward_unfolded(x, incremental_state, query)
        else:
            output = self._forward_expanded(x, incremental_state, query)

        if self.conv_bias is not None:
            # Broadcast the per-channel bias over time and batch.
            output = output + self.conv_bias.view(1, 1, -1)
        return output

    def _forward_unfolded(self, x, incremental_state, query):
        '''The conventional implementation of convolutions.
        Unfolding the input by having a window shifting to the right.'''
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H  # channels per head
        assert R * H == C == self.input_size

        if self.in_proj:
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
        else:
            weight = self.weight_linear(query).view(T*B*H, -1)

        # renorm_padding is only implemented for incremental decoding in this path
        assert not self.renorm_padding or incremental_state is not None

        # Without renorm_padding the softmax is taken over the full kernel, as
        # in _forward_expanded and in the incremental branch below. It must
        # therefore run before the K > T shortcut crops the filters; softmaxing
        # the cropped filters would renormalize over the surviving taps only
        # and make this path disagree with the expanded one.
        if self.weight_softmax and not self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        if incremental_state is not None:
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            # Append the current step along a new trailing dim: T x B x C x (buffered steps + 1)
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                # Only the last K-1 steps are needed for the next call.
                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
            x_unfold = x_unfold.view(T*B*H, R, -1)
        else:
            padding_l = self.padding_l
            if K > T and padding_l == K-1:
                # With padding_l == K-1 the first K-T taps only ever see
                # padding, so drop them and use a kernel of width T.
                weight = weight.narrow(1, K-T, T)
                K, padding_l = T, T-1
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, K, padding_l, 0)
            x_unfold = x_unfold.view(T*B*H, R, K)

        weight = weight.narrow(1, 0, K)

        if incremental_state is not None:
            # Keep only as many (trailing) taps as there are buffered steps.
            weight = weight[:, -x_unfold.size(2):]
            K = weight.size(1)

        if self.weight_softmax and self.renorm_padding:
            # Normalize over the taps that are actually used.
            weight = F.softmax(weight, dim=1)

        weight = self.weight_dropout_module(weight, inplace=False)

        output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T*B*H x R x 1
        output = output.view(T, B, C)
        return output

    def _forward_expanded(self, x, incremental_stat, query):
        '''Turn the convolution filters into band matrices and do matrix multiplication.
        This is faster when the sequence is short, but less memory efficient.
        This is not used in the decoder during inference.

        The second positional argument (incremental_stat) is not used here.
        '''
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H  # channels per head
        assert R * H == C == self.input_size
        if self.in_proj:
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = proj.narrow(2, self.input_size, H*K).contiguous().view(T*B*H, -1)
        else:
            weight = self.weight_linear(query).view(T*B*H, -1)

        if not self.renorm_padding:
            if self.weight_softmax:
                weight = F.softmax(weight, dim=1)
            weight = self.weight_dropout_module(weight, inplace=False)
        weight = weight.narrow(1, 0, K).contiguous()
        weight = weight.view(T, B*H, K).transpose(0, 1)

        x = x.view(T, B*H, R).transpose(0, 1)
        if self.weight_softmax and self.renorm_padding:
            # turn the convolution filters into band matrices;
            # -inf entries vanish in the softmax below
            weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf'))
            # Row stride T+K (row length + 1) places row t's taps at columns t..t+K-1.
            weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
            weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
            # normalize the weight over valid positions like self-attention
            weight_expanded = F.softmax(weight_expanded, dim=2)
            weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False)
        else:
            P = self.padding_l
            # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length
            if K > T and P == K-1:
                weight = weight.narrow(2, K-T, T)
                K, P = T, T-1
            # turn the convolution filters into band matrices
            weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
            # Row stride T+K (row length + 1) places row t's taps at columns t..t+K-1.
            weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
            weight_expanded = weight_expanded.narrow(2, P, T)  # B*H x T x T
        output = torch.bmm(weight_expanded, x)
        output = output.transpose(0, 1).contiguous().view(T, B, C)
        return output

    def reorder_incremental_state(self, incremental_state, new_order):
        '''Reorder the cached input buffer along dim 1 (the batch dim of T x B x C x K) by new_order.'''
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        '''Fetch this module's 'input_buffer' entry from incremental_state.'''
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')

    def _set_input_buffer(self, incremental_state, new_buffer):
        '''Store new_buffer as this module's 'input_buffer' entry in incremental_state.'''
        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

    def extra_repr(self):
        '''Return the configuration summary appended to the module's repr.'''
        s = '{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}'.format(
            self.input_size, self.kernel_size, self.padding_l,
            self.num_heads, self.weight_softmax, self.conv_bias is not None, self.renorm_padding,
            self.in_proj,
        )

        # Only mention the optional settings when they differ from the defaults.
        if self.query_size != self.input_size:
            s += ', query_size={}'.format(self.query_size)
        if self.weight_dropout_module.p > 0.:
            s += ', weight_dropout={}'.format(self.weight_dropout_module.p)
        return s
| |
|