mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-25 15:31:13 -08:00
54 lines
2.0 KiB
Python
54 lines
2.0 KiB
Python
import torch
|
|
|
|
class AdaBelief(torch.optim.Optimizer):
|
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
|
|
weight_decay=0, lr_dropout = 1.0):
|
|
|
|
defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
|
|
super(AdaBelief, self).__init__(params, defaults)
|
|
|
|
def reset(self):
|
|
for group in self.param_groups:
|
|
for p in group['params']:
|
|
state = self.state[p]
|
|
state['step'] = 0
|
|
state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
|
|
state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
|
|
|
|
def set_lr(self, lr):
|
|
for group in self.param_groups:
|
|
group['lr'] = lr
|
|
|
|
def step(self):
|
|
for group in self.param_groups:
|
|
for p in group['params']:
|
|
if p.grad is None:
|
|
continue
|
|
|
|
grad = p.grad.data
|
|
|
|
beta1, beta2 = group['betas']
|
|
|
|
state = self.state[p]
|
|
if len(state) == 0:
|
|
state['step'] = 0
|
|
state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
|
|
state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
|
|
|
|
if group['weight_decay'] != 0:
|
|
grad.add_(p.data, alpha=group['weight_decay'])
|
|
|
|
state['step'] += 1
|
|
|
|
m_t, v_t = state['m_t'], state['v_t']
|
|
m_t.mul_(beta1).add_( grad , alpha=1 - beta1)
|
|
v_t.mul_(beta2).add_( (grad - m_t)**2 , alpha=1 - beta2)
|
|
|
|
v_diff = (-group['lr'] * m_t).div_( v_t.sqrt().add_(group['eps']) )
|
|
|
|
if group['lr_dropout'] < 1.0:
|
|
lrd_rand = torch.ones_like(p.data)
|
|
v_diff *= torch.bernoulli(lrd_rand * group['lr_dropout'] )
|
|
|
|
p.data.add_(v_diff)
|