DeepFaceLive/xlib/torch/optim/AdaBelief.py
2021-12-20 17:52:52 +04:00

54 lines
2.0 KiB
Python

import torch
class AdaBelief(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
weight_decay=0, lr_dropout = 1.0):
defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
super(AdaBelief, self).__init__(params, defaults)
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
state['step'] = 0
state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
def set_lr(self, lr):
for group in self.param_groups:
group['lr'] = lr
def step(self):
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
beta1, beta2 = group['betas']
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
state['step'] += 1
m_t, v_t = state['m_t'], state['v_t']
m_t.mul_(beta1).add_( grad , alpha=1 - beta1)
v_t.mul_(beta2).add_( (grad - m_t)**2 , alpha=1 - beta2)
v_diff = (-group['lr'] * m_t).div_( v_t.sqrt().add_(group['eps']) )
if group['lr_dropout'] < 1.0:
lrd_rand = torch.ones_like(p.data)
v_diff *= torch.bernoulli(lrd_rand * group['lr_dropout'] )
p.data.add_(v_diff)