| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module (one of nn.Linear, nn.Embedding, nn.Conv2d)
        - p: amount of Quantization Noise (probability of dropping a block, 0 <= p < 1)
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The same ``module``, with a forward pre-hook registered that randomly
        zeroes weight blocks (training mode only) and rescales survivors by
        ``1 / (1 - p)`` so the expected weight magnitude is unchanged.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """
    # no-op when no noise is requested
    if p <= 0:
        return module
    # p == 1 would make the 1 / (1 - p) rescaling divide by zero inside the hook
    assert p < 1, "quantization noise probability must be < 1"

    # only these module types are supported
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # 4D weight => convolution; 2D weight => Linear / Embedding
    is_conv = module.weight.ndim == 4

    # validate that the weight can be tiled exactly into blocks of block_size
    if not is_conv:
        assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
    else:
        if module.kernel_size == (1, 1):
            # 1x1 convs are blocked along the input-channel dimension
            assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
        else:
            # regular convs are blocked over the spatial kernel
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # noise is only applied at training time; eval leaves weights untouched
        if not mod.training:
            return

        weight = mod.weight
        if not is_conv:
            # Linear / Embedding: drop contiguous runs of block_size input features
            in_features = weight.size(1)
            out_features = weight.size(0)

            # one Bernoulli draw per block, expanded back to per-weight resolution
            mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
        else:
            in_channels = mod.in_channels
            out_channels = mod.out_channels

            if mod.kernel_size == (1, 1):
                # block along input channels, then reshape to (out, in, 1, 1):
                # without the trailing unsqueezes, masked_fill broadcasts the
                # (out, in) mask against the (out, in, 1, 1) weight into a
                # wrong-shaped (out, in, out, in) result (upstream bug)
                mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                mask = mask.unsqueeze(-1).unsqueeze(-1)
            else:
                # drop whole spatial kernels: one draw per (out, in) pair,
                # repeated over the kernel's spatial extent
                mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                mask.bernoulli_(p)
                mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

        # zero the dropped blocks and rescale survivors to keep E[w] unchanged
        mask = mask.to(torch.bool)
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
| |
|