from itertools import chain

import torch

from fairseq import optim, utils

from .dynamic_loss_scaler import DynamicLossScaler


class _FP16OptimizerMixin(object):

    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in the method resolution order
        super().__init__(*args, **kwargs)

    @property
    def has_flat_params(self):
        return torch.is_tensor(self.fp32_params)

    @classmethod
    def build_fp32_params(cls, params, flatten=True):
        # create an FP32 copy of the model parameters (and allocate grad buffers)
        if flatten:
            total_param_size = sum(p.data.numel() for p in params)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=params[0].device)
            offset = 0
            for p in params:
                numel = p.data.numel()
                fp32_params[offset:offset+numel].copy_(p.data.view(-1))
                offset += numel
            fp32_params = torch.nn.Parameter(fp32_params)
            # flat gradient buffer, filled later by _sync_fp16_grads_to_fp32()
            fp32_params.grad = fp32_params.data.new(total_param_size)
            return fp32_params
        else:
            fp32_params = []
            for p in params:
                p32 = torch.nn.Parameter(p.data.float())
                p32.grad = torch.zeros_like(p32.data)
                fp32_params.append(p32)
            return fp32_params

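    # Illustrative sketch (added for clarity, not part of the original source):
    # with two FP16 tensors of sizes 3 and 2, the flattened FP32 copy is a
    # single 5-element vector whose slices [0:3] and [3:5] mirror the two
    # parameters, e.g.:
    #
    #   a = torch.ones(3, dtype=torch.half)
    #   b = torch.ones(2, dtype=torch.half)
    #   flat = FP16Optimizer.build_fp32_params([a, b], flatten=True)
    #   assert flat.numel() == 5 and flat.dtype == torch.float32
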
    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scale'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if 'loss_scale' in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict['loss_scale']
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        self._needs_sync = True

    def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
        if self._needs_sync:
            if self.scaler is not None:
                # correct for dynamic loss scaler
                multiply_grads /= self.scaler.loss_scale

            # copy FP16 grads into the FP32 buffers
            if self.has_flat_params:
                offset = 0
                for p in self.fp16_params:
                    if not p.requires_grad:
                        continue
                    grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
                    numel = grad_data.numel()
                    self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
                    offset += numel
                self.fp32_params.grad.data.mul_(multiply_grads)
            else:
                for p, p32 in zip(self.fp16_params, self.fp32_params):
                    if not p.requires_grad:
                        continue
                    if p.grad is not None:
                        p32.grad.data.copy_(p.grad.data)
                        p32.grad.data.mul_(multiply_grads)
                    else:
                        p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self._needs_sync = False

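    # Worked example (added for clarity, not from the original source): with a
    # loss scale of 128, backward() computes d(128 * loss)/dw, so every FP16
    # grad is 128x the true gradient. _sync_fp16_grads_to_fp32() then copies
    # those grads into the FP32 buffers and multiplies by 1/128 (folded into
    # ``multiply_grads``), so the FP32 optimizer always sees unscaled gradients.
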
    def _sync_fp32_grads_to_fp16(self):
        # copy the updated FP32 params back into the FP16 model
        if self.has_flat_params:
            offset = 0
            for p in self.fp16_params:
                if not p.requires_grad:
                    continue
                numel = p.data.numel()
                p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
                offset += numel
        else:
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue
                p.data.copy_(p32.data)

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        if self._needs_sync:
            self._sync_fp16_grads_to_fp32(c)
        elif self.has_flat_params:
            self.fp32_params.grad.data.mul_(c)
        else:
            for p32 in self.fp32_params:
                p32.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            self.scaler.check_overflow(grad_norm)

        return grad_norm

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()
        self.fp32_optimizer.step(closure)

        if self.scaler is not None:
            self.scaler.update()

        self._sync_fp32_grads_to_fp16()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        if self.has_flat_params:
            self.fp32_params.grad.zero_()
        else:
            for p32 in self.fp32_params:
                p32.grad.zero_()
        self._needs_sync = False


class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """

    def __init__(self, args, params, fp32_optimizer, fp32_params):
        super().__init__(args)
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params

        if getattr(args, 'fp16_scale_window', None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    '--fp16-scale-window must be given explicitly when using a '
                    'custom --update-freq schedule'
                )
            data_parallel_size = int(args.distributed_world_size / args.model_parallel_size)
            scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window

        if not getattr(args, 'bf16', False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale
            )
        else:
            # loss scaling is not needed for bfloat16, which has the same
            # dynamic range as FP32
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(args, 'fp16_no_flatten_grads', False)
        if getattr(args, 'bf16', False):
            flatten = False
        fp32_params = cls.build_fp32_params(params, flatten=flatten)
        if flatten:
            fp32_optimizer = optim.build_optimizer(args, [fp32_params])
        else:
            fp32_optimizer = optim.build_optimizer(args, fp32_params)
        if flatten and not fp32_optimizer.supports_flat_params:
            raise RuntimeError(
                'chosen optimizer does not support flat params, '
                'please set --fp16-no-flatten-grads'
            )
        return cls(args, params, fp32_optimizer, fp32_params)

    @property
    def optimizer(self):
        return self.fp32_optimizer.optimizer

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)

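# Typical wiring (illustrative sketch only, inferred from the methods above and
# the fairseq CLI flags they reference; not a documented recipe):
#
#   optimizer = FP16Optimizer.build_optimizer(args, fp16_model_params)
#   loss = model(batch)                       # model runs in FP16
#   optimizer.backward(loss)                  # loss is scaled before backward()
#   optimizer.clip_grad_norm(args.clip_norm)  # grads synced to FP32, unscaled
#   optimizer.step()                          # FP32 update, copied back to FP16
#   optimizer.zero_grad()

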
class _MemoryEfficientFP16OptimizerMixin(object):

    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in the method resolution order
        super().__init__(*args, **kwargs)

    @property
    def has_flat_params(self):
        return False

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.wrapped_optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scale'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if 'loss_scale' in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict['loss_scale']

        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)

        # Hack: PyTorch automatically casts the optimizer state to match the
        # type of the current parameters. But with --memory-efficient-fp16 the
        # params are FP16 while the optimizer state is FP32, and we don't want
        # to cast. A workaround is to manually copy back the original (FP32)
        # state after the optimizer has been loaded.
        groups = self.optimizer.param_groups
        saved_groups = state_dict['param_groups']
        id_map = {
            old_id: p
            for old_id, p in zip(
                chain(*(g['params'] for g in saved_groups)),
                chain(*(g['params'] for g in groups))
            )
        }
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                self.optimizer.state[param] = v

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()

    def _unscale_grads(self):
        if self._multiply_factor != 1.:
            self.wrapped_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        max_norm = float(max_norm)
        grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(0, aggregate_norm_fn)

        if self.scaler is not None:
            grad_norm_cpu = float(grad_norm)
            if grad_norm_cpu > max_norm > 0.:
                self._multiply_factor *= max_norm / grad_norm_cpu

            # detect overflow and adjust loss scale
            self.scaler.check_overflow(grad_norm_cpu)
        elif max_norm > 0.:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

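    # How the deferred scaling works (explanatory note, not from the original
    # source): zero_grad() sets _multiply_factor to 1/loss_scale,
    # multiply_grads() and clip_grad_norm() fold further factors into it, and
    # the combined factor is applied to the raw FP16 grads exactly once, either
    # via the wrapped optimizer's step(..., scale=...) or via _unscale_grads().
    # This avoids touching every gradient tensor more than once per update.
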
    def step(self, closure=None):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            # the wrapped optimizer divides the grads by ``scale``, so pass the
            # inverse of the accumulated multiplication factor
            self.wrapped_optimizer.step(closure, scale=(1. / self._multiply_factor))
        else:
            self._unscale_grads()
            self.wrapped_optimizer.step(closure)

        if self.scaler is not None:
            self.scaler.update()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
        if self.scaler is not None:
            self._multiply_factor = 1. / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.


class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
    maintain an FP32 copy of the model. We instead expect the optimizer to
    convert the gradients to FP32 internally and sync the results back to the
    FP16 model params. This significantly reduces memory usage but slightly
    increases the time spent in the optimizer.

    Since this wrapper depends on specific functionality in the wrapped
    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
    optimizers can be wrapped. This is determined by the
    *supports_memory_efficient_fp16* property.
    """

    def __init__(self, args, params, optimizer):
        if not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
            )

        super().__init__(args)
        self.wrapped_optimizer = optimizer

        if getattr(args, 'fp16_scale_window', None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    '--fp16-scale-window must be given explicitly when using a '
                    'custom --update-freq schedule'
                )
            data_parallel_size = int(args.distributed_world_size / args.model_parallel_size)
            scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window

        if not getattr(args, 'bf16', False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale
            )
        else:
            # loss scaling is not needed for bfloat16, which has the same
            # dynamic range as FP32
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        fp16_optimizer = optim.build_optimizer(args, params)
        return cls(args, params, fp16_optimizer)

    @property
    def optimizer(self):
        return self.wrapped_optimizer.optimizer

    @property
    def optimizer_config(self):
        return self.wrapped_optimizer.optimizer_config

    def get_lr(self):
        return self.wrapped_optimizer.get_lr()

    def set_lr(self, lr):
        self.wrapped_optimizer.set_lr(lr)
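

# Typical wiring (illustrative sketch only; it mirrors the FP16Optimizer flow
# above rather than documenting an official recipe). The wrapped optimizer must
# report *supports_memory_efficient_fp16* (e.g. fairseq's Adam variants do):
#
#   optimizer = MemoryEfficientFP16Optimizer.build_optimizer(args, fp16_model_params)
#   optimizer.zero_grad()                     # also resets the deferred scale factor
#   loss = model(batch)                       # model runs in FP16
#   optimizer.backward(loss)                  # scaled loss, raw FP16 grads
#   optimizer.clip_grad_norm(args.clip_norm)  # folds clipping into the scale factor
#   optimizer.step()                          # unscales once, then updates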