Shortcuts

Source code for torch.optim.radam

import math
import torch
from torch import Tensor

from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
                        _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["RAdam", "radam"]


class RAdam(Optimizer):
    # NOTE: the user-facing class docstring is assigned to ``RAdam.__doc__``
    # right after the class body in this module, so none is declared inline.

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        *,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
    ):
        """Validate the hyper-parameters and register them as group defaults.

        Args:
            params: iterable of parameters or dicts defining parameter groups
                (forwarded unchanged to ``Optimizer.__init__``).
            lr: learning rate; must satisfy ``0.0 <= lr``.
            betas: pair of coefficients; each must lie in ``[0.0, 1.0)``.
            eps: term added to the denominator; must satisfy ``0.0 <= eps``.
            weight_decay: L2 penalty; must satisfy ``0.0 <= weight_decay``.
            foreach: stored in the defaults and forwarded to :func:`radam`.
            differentiable: stored in the defaults and forwarded to
                :func:`radam`.

        Raises:
            ValueError: if any of the range checks above fails.
        """
        # The checks are written as ``not 0.0 <= x`` (rather than ``x < 0.0``)
        # so that NaN values are rejected as well.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state and upgrade entries saved in older layouts.

        Missing ``foreach`` / ``differentiable`` group options are filled in
        with their defaults, and every per-parameter ``step`` that is not
        already a tensor is converted to a float tensor.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)

        # Convert entry by entry instead of deciding from the first entry
        # only: a mixed state (some steps tensors, some plain numbers) would
        # otherwise keep its plain numbers, which ``radam`` rejects.
        for param_state in self.state.values():
            if "step" in param_state and not torch.is_tensor(param_state["step"]):
                param_state["step"] = torch.tensor(float(param_state["step"]))

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps):
        """Collect the tensors of ``group`` needed for one update.

        For every parameter of the group that has a gradient, appends the
        parameter, its gradient and its (lazily created) state tensors to the
        output lists passed in by the caller. Parameters without a gradient
        are skipped.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        for p in group["params"]:
            if p.grad is None:
                continue

            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("RAdam does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]
            # Lazy state initialization
            if len(state) == 0:
                state["step"] = torch.tensor(0.0)
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            state_steps.append(state["step"])

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # The closure is evaluated with gradients enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]

            self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps)

            radam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                foreach=group["foreach"],
                differentiable=group["differentiable"],
            )

        return loss
RAdam.__doc__ = r"""Implements RAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \beta_1, \beta_2
                \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
                \lambda \text{ (weightdecay)},                                                   \\
            &\hspace{13mm} \epsilon \text{ (epsilon)}                                            \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)},                                       \\
            &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1                      \\[-1.ex]
            &\rule{110mm}{0.4pt}  \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{6mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{6mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{6mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{6mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
                2 t \beta^t_2 /\big(1-\beta_2^t \big)                                    \\[0.1.ex]
            &\hspace{6mm}\textbf{if} \: \rho_t > 5                                               \\
            &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon  } \\
            &\hspace{12mm} r_t \leftarrow
      \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t        \\
            &\hspace{6mm}\textbf{else}                                                           \\
            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}                \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.

    This implementation uses the same weight_decay implementation as Adam (where the weight_decay is applied
    to the gradient) and not the one from AdamW (where weight_decay is applied to the update). This
    is different from the `author's implementation`_.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {foreach}
        {differentiable}

    .. _On the variance of the adaptive learning rate and beyond:
        https://arxiv.org/abs/1908.03265
    .. _author's implementation:
        https://github.com/LiyuanLucasLiu/RAdam

    """.format(foreach=_foreach_doc, differentiable=_differentiable_doc)
# NOTE: ``.format`` binds to the second literal only, which is what keeps the
# LaTeX braces of the first literal from being interpreted as format fields.


def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.

    The five tensor lists are parallel: index ``i`` of each list belongs to
    the same parameter. ``params``, ``exp_avgs``, ``exp_avg_sqs`` and
    ``state_steps`` are updated in place; nothing is returned.

    Args:
        foreach: ``True`` selects the multi-tensor implementation, ``False``
            the per-tensor one; ``None`` defers the choice to
            ``_default_to_fused_or_foreach``.
        differentiable: forwarded to the selected implementation.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor, or if
            ``foreach`` is requested while running under ``torch.jit.script``.
    """

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # NOTE(review): the second ``torch.jit.is_scripting()`` test looks
    # redundant after the raise above; presumably it is kept so that
    # TorchScript never has to compile the foreach implementation. Confirm
    # before simplifying.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        differentiable=differentiable,
    )


def _single_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    differentiable: bool,
):
    """Apply one RAdam update, one parameter at a time.

    Each parameter, its two moment buffers and its step counter are modified
    in place. The gradient tensors themselves are not modified: weight decay
    is folded into a new tensor.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]
        # update step
        step_t += 1
        step = _get_value(step_t)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            # Out-of-place add: the caller's gradient stays untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # correcting bias for the first moving moment
        bias_corrected_exp_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2

        if rho_t > 5.0:
            # Compute the variance rectification term and update parameters accordingly
            rect = math.sqrt(
                (rho_t - 4)
                * (rho_t - 2)
                * rho_inf
                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
            )
            # ``sqrt()`` returns a fresh tensor, so adding eps in place below
            # does not touch ``exp_avg_sq``; the out-of-place variant is used
            # when ``differentiable`` is set.
            exp_avg_sq_sqrt = exp_avg_sq.sqrt()
            if differentiable:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps)
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
            adaptive_lr = math.sqrt(bias_correction2) / exp_avg_sq_sqrt
            param.add_(bias_corrected_exp_avg * lr * adaptive_lr * rect, alpha=-1.0)
        else:
            # Unrectified step: bias-corrected first moment scaled by lr only.
            param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)


def _multi_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    differentiable: bool,
):
    """Apply one RAdam update using ``torch._foreach_*`` operations.

    Parameters, moment buffers and step counters are modified in place.
    Returns immediately when ``params`` is empty. ``differentiable`` must be
    false (enforced with an assertion).
    """

    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # Each value of the returned mapping is unpacked below as five parallel
    # lists (grouping by device and dtype, per the helper's name).
    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
    for (
        grouped_params,
        grouped_grads,
        grouped_exp_avgs,
        grouped_exp_avg_sqs,
        grouped_state_steps,
    ) in grouped_tensors.values():
        # Update steps
        torch._foreach_add_(grouped_state_steps, 1)

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1

        # Read every step value once and reuse it for all derived quantities.
        step_values = [_get_value(step) for step in grouped_state_steps]
        bias_correction1 = [1 - beta1 ** step for step in step_values]
        bias_correction2 = [1 - beta2 ** step for step in step_values]
        # compute the length of the approximated SMA
        rho_t_list = [
            rho_inf - 2 * step * (beta2 ** step) / bc2
            for step, bc2 in zip(step_values, bias_correction2)
        ]

        if weight_decay != 0:
            # Out-of-place: produces new gradient tensors for this group.
            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_mul_(grouped_exp_avgs, beta1)
        torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)

        # Rectification factor per tensor; 0 marks "do not rectify".
        rect = [
            _dispatch_sqrt(
                (rho_t - 4)
                * (rho_t - 2)
                * rho_inf
                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
            )
            if rho_t > 5
            else 0
            for rho_t in rho_t_list
        ]
        # Complementary mask: 1.0 exactly where no rectification applies.
        unrectified = [0 if r > 0 else 1.0 for r in rect]

        # Pass 1: rectified update. Tensors with rect == 0 get a zero step
        # size here.
        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
        torch._foreach_add_(exp_avg_sq_sqrt, eps)
        bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
        denom = torch._foreach_div(exp_avg_sq_sqrt, bias_correction_sqrt)
        step_size = _stack_if_compiling([(lr * r / bc) * -1 for r, bc in zip(rect, bias_correction1)])
        torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size)

        # Pass 2: unrectified update with a denominator of ones. Tensors that
        # were rectified in pass 1 get a zero step size here.
        denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in grouped_exp_avgs]
        step_size = _stack_if_compiling([(lr * r / bc) * -1 for r, bc in zip(unrectified, bias_correction1)])
        torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources