# Source: https://github.com/iperov/DeepFaceLive.git
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class FRNorm2D(nn.Module):
    """Filter Response Normalization over the spatial dimensions.

    Each channel is divided by the root mean square of its activations
    (no mean subtraction), then a learnable per-channel affine is applied.
    The epsilon is itself a learnable scalar, kept positive via abs().
    """

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch
        # Per-channel affine: scale starts at 1, shift at 0.
        self.weight = nn.parameter.Parameter(torch.ones(1, in_ch, 1, 1), requires_grad=True)
        self.bias = nn.parameter.Parameter(torch.zeros(1, in_ch, 1, 1), requires_grad=True)
        # Learnable epsilon, initialized small.
        self.eps = nn.parameter.Parameter(torch.full((1,), 1e-6), requires_grad=True)

    def forward(self, x):
        # Mean of squares over H and W, one value per (batch, channel).
        mean_sq = x.pow(2).mean(dim=(2, 3), keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps.abs())
        return self.weight * normed + self.bias
class TLU(nn.Module):
    """Thresholded Linear Unit: elementwise max(x, tau).

    tau is a learnable per-channel threshold, broadcast over batch and
    spatial dims; initialized to zero, so the layer starts out as ReLU.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch
        # One learnable threshold per channel.
        self.tau = nn.parameter.Parameter(torch.zeros(1, in_ch, 1, 1), requires_grad=True)

    def forward(self, x):
        return torch.max(x, self.tau)
class BlurPool(nn.Module):
    """Anti-aliased downsampling ("blur pool").

    Zero-pads the input, then applies a fixed (non-learnable) square
    binomial low-pass filter as a depthwise convolution with the given
    stride, so downsampling does not alias.

    Args:
        in_ch:     number of input channels (filter applied per channel).
        filt_size: side length of the binomial filter, must be >= 1.
        stride:    downsampling stride.
        pad_off:   extra padding added on every side.

    Raises:
        ValueError: if filt_size < 1.
    """

    def __init__(self, in_ch, filt_size=3, stride=2, pad_off=0):
        super().__init__()
        if filt_size < 1:
            raise ValueError(f'filt_size must be >= 1, got {filt_size}')
        self.in_ch = in_ch
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Asymmetric padding for even filter sizes: floor on left/top, ceil on right/bottom.
        self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
        self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride-1)/2.)

        # Row (filt_size-1) of Pascal's triangle, e.g. [1, 2, 1] for size 3.
        # Replaces the original hard-coded table, which left `a` undefined
        # (NameError) for any filt_size outside 2..7; this supports any
        # filt_size >= 1 and is identical for 2..7.
        a = np.ones(1)
        for _ in range(filt_size - 1):
            a = np.convolve(a, [1., 1.])

        # Outer product gives the 2D binomial kernel; normalize to sum 1.
        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        # Registered as a buffer: moves with .to()/.cuda() but is not trained.
        self.register_buffer('filt', filt[None, None, :, :].repeat(in_ch, 1, 1, 1))

        self.pad = nn.ZeroPad2d(self.pad_sizes)

    def forward(self, inp):
        # Depthwise (groups=in_ch) conv with the fixed blur kernel.
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=self.in_ch)
class ConvBlock(nn.Module):
    """3x3 'same'-padded convolution followed by FRN normalization and TLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.frn = FRNorm2D(out_ch)
        self.tlu = TLU(out_ch)

    def forward(self, x):
        # conv -> normalize -> activate
        return self.tlu(self.frn(self.conv(x)))
class UpConvBlock(nn.Module):
    """2x spatial upsampling via transposed 3x3 conv, then FRN + TLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # stride=2 with output_padding=1 exactly doubles H and W.
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.frn = FRNorm2D(out_ch)
        self.tlu = TLU(out_ch)

    def forward(self, x):
        # upsample -> normalize -> activate
        return self.tlu(self.frn(self.conv(x)))
class XSegNet(nn.Module):
    def __init__(self, in_ch, out_ch, base_ch=32):
        """
        U-Net-style segmentation network built from ConvBlock (conv+FRN+TLU)
        stages, anti-aliased BlurPool downsampling, a fully-connected
        bottleneck, and UpConvBlock decoding with skip connections.

        in_ch   : number of input image channels
        out_ch  : number of output channels (raw logits, no activation)
        base_ch : channel width multiplier for every stage (default 32)

        NOTE(review): the dense bottleneck is sized for a 4x4 feature map
        after six stride-2 BlurPools, which implies a fixed 256x256 input
        resolution — assumed from the layer sizes; confirm against callers.
        """
        super().__init__()
        self.base_ch = base_ch

        # --- Encoder ---
        # Stage 0: in_ch -> base_ch, then 2x downsample.
        self.conv01 = ConvBlock(in_ch, base_ch)
        self.conv02 = ConvBlock(base_ch, base_ch)
        self.bp0 = BlurPool (base_ch, filt_size=4)

        # Stage 1: base_ch -> base_ch*2.
        self.conv11 = ConvBlock(base_ch, base_ch*2)
        self.conv12 = ConvBlock(base_ch*2, base_ch*2)
        self.bp1 = BlurPool (base_ch*2, filt_size=3)

        # Stage 2: base_ch*2 -> base_ch*4.
        self.conv21 = ConvBlock(base_ch*2, base_ch*4)
        self.conv22 = ConvBlock(base_ch*4, base_ch*4)
        self.bp2 = BlurPool (base_ch*4, filt_size=2)

        # Stage 3: base_ch*4 -> base_ch*8 (three convs per stage from here down).
        self.conv31 = ConvBlock(base_ch*4, base_ch*8)
        self.conv32 = ConvBlock(base_ch*8, base_ch*8)
        self.conv33 = ConvBlock(base_ch*8, base_ch*8)
        self.bp3 = BlurPool (base_ch*8, filt_size=2)

        # Stage 4: base_ch*8 -> base_ch*8.
        self.conv41 = ConvBlock(base_ch*8, base_ch*8)
        self.conv42 = ConvBlock(base_ch*8, base_ch*8)
        self.conv43 = ConvBlock(base_ch*8, base_ch*8)
        self.bp4 = BlurPool (base_ch*8, filt_size=2)

        # Stage 5: base_ch*8 -> base_ch*8.
        self.conv51 = ConvBlock(base_ch*8, base_ch*8)
        self.conv52 = ConvBlock(base_ch*8, base_ch*8)
        self.conv53 = ConvBlock(base_ch*8, base_ch*8)
        self.bp5 = BlurPool (base_ch*8, filt_size=2)

        # --- Dense bottleneck over the flattened 4x4 deepest feature map ---
        self.dense1 = nn.Linear ( 4*4* base_ch*8, 512)
        self.dense2 = nn.Linear ( 512, 4*4* base_ch*8)

        # --- Decoder ---
        # up5/up4/up3 emit base_ch*4 channels; concatenating the base_ch*8
        # skip activation yields the base_ch*12 input expected by uconv*3.
        self.up5 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv51 = ConvBlock(base_ch*8, base_ch*8)

        self.up4 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv42 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv41 = ConvBlock(base_ch*8, base_ch*8)

        self.up3 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv33 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv32 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv31 = ConvBlock(base_ch*8, base_ch*8)

        self.up2 = UpConvBlock (base_ch*8, base_ch*4)
        self.uconv22 = ConvBlock(base_ch*8, base_ch*4)
        self.uconv21 = ConvBlock(base_ch*4, base_ch*4)

        self.up1 = UpConvBlock (base_ch*4, base_ch*2)
        self.uconv12 = ConvBlock(base_ch*4, base_ch*2)
        self.uconv11 = ConvBlock(base_ch*2, base_ch*2)

        self.up0 = UpConvBlock (base_ch*2, base_ch)
        self.uconv02 = ConvBlock(base_ch*2, base_ch)
        self.uconv01 = ConvBlock(base_ch, base_ch)

        # Wide 7x7 output conv producing raw (un-activated) logits.
        self.out_conv = nn.Conv2d (base_ch, out_ch, kernel_size=7, padding=3)

    def forward(self, inp):
        """Run the encoder/bottleneck/decoder; returns raw logits with
        out_ch channels at the input spatial resolution."""
        x = inp

        # Encoder; x0..x5 hold the pre-downsample skip activations.
        x = self.conv01(x)
        x = x0 = self.conv02(x)
        x = self.bp0(x)

        x = self.conv11(x)
        x = x1 = self.conv12(x)
        x = self.bp1(x)

        x = self.conv21(x)
        x = x2 = self.conv22(x)
        x = self.bp2(x)

        x = self.conv31(x)
        x = self.conv32(x)
        x = x3 = self.conv33(x)
        x = self.bp3(x)

        x = self.conv41(x)
        x = self.conv42(x)
        x = x4 = self.conv43(x)
        x = self.bp4(x)

        x = self.conv51(x)
        x = self.conv52(x)
        x = x5 = self.conv53(x)
        x = self.bp5(x)

        # Dense bottleneck: flatten the 4x4 map, squeeze to 512, restore.
        x = x.view(x.shape[0], -1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = x.view (-1, self.base_ch*8, 4, 4)

        # Decoder: upsample, concat the matching skip on the channel axis,
        # then refine with conv blocks.
        x = self.up5(x)

        x = self.uconv53(torch.cat([x,x5],axis=1))
        x = self.uconv52(x)
        x = self.uconv51(x)

        x = self.up4(x)
        x = self.uconv43(torch.cat([x,x4],axis=1))
        x = self.uconv42(x)
        x = self.uconv41(x)

        x = self.up3(x)
        x = self.uconv33(torch.cat([x,x3],axis=1))
        x = self.uconv32(x)
        x = self.uconv31(x)

        x = self.up2(x)
        x = self.uconv22(torch.cat([x,x2],axis=1))
        x = self.uconv21(x)

        x = self.up1(x)
        x = self.uconv12(torch.cat([x,x1],axis=1))
        x = self.uconv11(x)

        x = self.up0(x)
        x = self.uconv02(torch.cat([x,x0],axis=1))
        x = self.uconv01(x)

        x = self.out_conv(x)

        return x