DeepFaceLive/xlib/torch/model/XsegNet.py
2021-12-19 18:14:05 +04:00

226 lines
7.1 KiB
Python

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class FRNorm2D(nn.Module):
    """Filter Response Normalization over the spatial dims of an NCHW tensor.

    Computes ``weight * x / sqrt(mean_hw(x^2) + |eps|) + bias`` where
    ``weight``/``bias`` are per-channel learned tensors of shape
    (1, in_ch, 1, 1) and ``eps`` is a single learned scalar (stored as a
    1-element tensor, initialised to 1e-6; its absolute value is used so it
    can never make the denominator smaller than ``mean_hw(x^2)``).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch
        channel_shape = (1, in_ch, 1, 1)
        # Registered in the order weight, bias, eps to keep state_dict keys stable.
        self.weight = nn.Parameter(torch.ones(channel_shape))
        self.bias = nn.Parameter(torch.zeros(channel_shape))
        self.eps = nn.Parameter(torch.full((1,), 1e-6))

    def forward(self, x):
        # Mean of squares over H and W (dims 2 and 3), kept broadcastable.
        mean_sq = x.square().mean(dim=[2, 3], keepdim=True)
        normalized = x * torch.rsqrt(mean_sq + self.eps.abs())
        return self.weight * normalized + self.bias
class TLU(nn.Module):
    """Thresholded Linear Unit: ``y = max(x, tau)`` elementwise.

    ``tau`` is a learned per-channel threshold of shape (1, in_ch, 1, 1),
    initialised to zero (so the module starts out as a ReLU).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch
        self.tau = nn.Parameter(torch.zeros(1, in_ch, 1, 1))

    def forward(self, x):
        # tau broadcasts over batch and spatial dims.
        return torch.maximum(x, self.tau)
class BlurPool(nn.Module):
    """Anti-aliased downsampling: depthwise binomial blur followed by striding.

    Each of the ``in_ch`` channels is convolved independently
    (``groups=in_ch``) with a normalised ``filt_size x filt_size`` binomial
    kernel (outer product of a row of Pascal's triangle, summing to 1) using
    the given ``stride``. The input is zero-padded by ``filt_size - 1`` pixels
    per axis (the extra pixel for even sizes goes right/bottom) plus
    ``pad_off`` on every side, so with ``pad_off=0`` an H x W input yields a
    ``ceil(H / stride) x ceil(W / stride)`` output.

    Args:
        in_ch: number of input (and output) channels.
        filt_size: kernel width in taps; any positive integer (1 means no blur,
            plain subsampling). Previously only 2..7 were supported.
        stride: subsampling stride.
        pad_off: extra zero padding added to every side.

    Raises:
        ValueError: if ``filt_size`` is not a positive integer.
    """

    def __init__(self, in_ch, filt_size=3, stride=2, pad_off=0):
        super().__init__()
        taps = int(filt_size)
        if taps != filt_size or taps < 1:
            raise ValueError(f'filt_size must be a positive integer, got {filt_size!r}')

        self.in_ch = in_ch
        self.filt_size = filt_size
        self.pad_off = pad_off

        # Split the (taps - 1) pixels of padding as evenly as possible;
        # ZeroPad2d order is (left, right, top, bottom).
        pad_lo = (taps - 1) // 2
        pad_hi = taps // 2
        self.pad_sizes = [pad + pad_off for pad in (pad_lo, pad_hi, pad_lo, pad_hi)]

        self.stride = stride
        # Kept as a public attribute for compatibility; not used by this class.
        self.off = int((self.stride - 1) / 2.)

        # Row (taps - 1) of Pascal's triangle, e.g. [1, 2, 1] for taps == 3.
        # Repeated convolution with [1, 1] reproduces the old lookup table
        # exactly for sizes 2..7 and extends it to any size.
        a = np.ones(1)
        for _ in range(taps - 1):
            a = np.convolve(a, [1., 1.])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        # One copy of the kernel per channel: weight shape (in_ch, 1, k, k).
        self.register_buffer('filt', filt[None, None, :, :].repeat(in_ch, 1, 1, 1))
        self.pad = nn.ZeroPad2d(self.pad_sizes)

    def forward(self, inp):
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=self.in_ch)
class ConvBlock(nn.Module):
    """3x3 convolution (stride 1, padding 1) -> FRNorm2D -> TLU.

    Keeps the spatial size and maps ``in_ch`` channels to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Registration order (conv, frn, tlu) fixes the state_dict key order.
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                              kernel_size=3, stride=1, padding=1)
        self.frn = FRNorm2D(out_ch)
        self.tlu = TLU(out_ch)

    def forward(self, x):
        return self.tlu(self.frn(self.conv(x)))
class UpConvBlock(nn.Module):
    """3x3 transposed convolution (stride 2) -> FRNorm2D -> TLU.

    With padding=1 and output_padding=1 the transposed conv maps an H x W
    input to exactly 2H x 2W, and ``in_ch`` channels to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Registration order (conv, frn, tlu) fixes the state_dict key order.
        self.conv = nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                                       kernel_size=3, stride=2,
                                       padding=1, output_padding=1)
        self.frn = FRNorm2D(out_ch)
        self.tlu = TLU(out_ch)

    def forward(self, x):
        return self.tlu(self.frn(self.conv(x)))
class XSegNet(nn.Module):
    """XSeg segmentation network: U-Net style encoder/decoder with a dense bottleneck.

    Encoder: six stages of ConvBlocks, each followed by a stride-2 BlurPool,
    reducing a 256x256 input to a 4x4 map with ``base_ch*8`` channels.
    Bottleneck: the 4x4 map is flattened and passed through two Linear layers
    (4*4*base_ch*8 -> 512 -> 4*4*base_ch*8, no nonlinearity between them),
    then reshaped back to 4x4.
    Decoder: six UpConvBlocks, each doubling the spatial size; the result is
    concatenated along channels with the matching encoder feature map (skip
    connection) and refined by ConvBlocks. A final 7x7 convolution produces
    ``out_ch`` channels.

    Input must be (N, in_ch, 256, 256): the Linear layers are sized for a 4x4
    bottleneck and the decoder doubles exactly back to 256, so any other size
    raises a shape error (in ``dense1`` or at a skip-connection concat).

    Submodule attribute names (conv01, bp0, dense1, up5, uconv53, ...) define
    the state_dict keys and must not be renamed.
    """

    def __init__(self, in_ch, out_ch, base_ch=32):
        """Build the network.

        Args:
            in_ch: channels of the input tensor.
            out_ch: channels produced by the final 7x7 convolution.
            base_ch: width of the first stage; deeper stages use 2x, 4x and
                8x this value.
        """
        super().__init__()
        self.base_ch = base_ch

        # Encoder: conv stages, each followed by an anti-aliased stride-2 downsample.
        self.conv01 = ConvBlock(in_ch, base_ch)
        self.conv02 = ConvBlock(base_ch, base_ch)
        self.bp0 = BlurPool(base_ch, filt_size=4)

        self.conv11 = ConvBlock(base_ch, base_ch*2)
        self.conv12 = ConvBlock(base_ch*2, base_ch*2)
        self.bp1 = BlurPool(base_ch*2, filt_size=3)

        self.conv21 = ConvBlock(base_ch*2, base_ch*4)
        self.conv22 = ConvBlock(base_ch*4, base_ch*4)
        self.bp2 = BlurPool(base_ch*4, filt_size=2)

        self.conv31 = ConvBlock(base_ch*4, base_ch*8)
        self.conv32 = ConvBlock(base_ch*8, base_ch*8)
        self.conv33 = ConvBlock(base_ch*8, base_ch*8)
        self.bp3 = BlurPool(base_ch*8, filt_size=2)

        self.conv41 = ConvBlock(base_ch*8, base_ch*8)
        self.conv42 = ConvBlock(base_ch*8, base_ch*8)
        self.conv43 = ConvBlock(base_ch*8, base_ch*8)
        self.bp4 = BlurPool(base_ch*8, filt_size=2)

        self.conv51 = ConvBlock(base_ch*8, base_ch*8)
        self.conv52 = ConvBlock(base_ch*8, base_ch*8)
        self.conv53 = ConvBlock(base_ch*8, base_ch*8)
        self.bp5 = BlurPool(base_ch*8, filt_size=2)

        # Bottleneck, hard-wired to a 4x4 spatial map.
        self.dense1 = nn.Linear(4*4*base_ch*8, 512)
        self.dense2 = nn.Linear(512, 4*4*base_ch*8)

        # Decoder: the uconvX3 / uconvX2 input widths include the concatenated skip tensor
        # (e.g. 4*base_ch from up5 + 8*base_ch from x5 = 12*base_ch).
        self.up5 = UpConvBlock(base_ch*8, base_ch*4)
        self.uconv53 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv52 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv51 = ConvBlock(base_ch*8, base_ch*8)

        self.up4 = UpConvBlock(base_ch*8, base_ch*4)
        self.uconv43 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv42 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv41 = ConvBlock(base_ch*8, base_ch*8)

        self.up3 = UpConvBlock(base_ch*8, base_ch*4)
        self.uconv33 = ConvBlock(base_ch*12, base_ch*8)
        self.uconv32 = ConvBlock(base_ch*8, base_ch*8)
        self.uconv31 = ConvBlock(base_ch*8, base_ch*8)

        self.up2 = UpConvBlock(base_ch*8, base_ch*4)
        self.uconv22 = ConvBlock(base_ch*8, base_ch*4)
        self.uconv21 = ConvBlock(base_ch*4, base_ch*4)

        self.up1 = UpConvBlock(base_ch*4, base_ch*2)
        self.uconv12 = ConvBlock(base_ch*4, base_ch*2)
        self.uconv11 = ConvBlock(base_ch*2, base_ch*2)

        self.up0 = UpConvBlock(base_ch*2, base_ch)
        self.uconv02 = ConvBlock(base_ch*2, base_ch)
        self.uconv01 = ConvBlock(base_ch, base_ch)

        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=7, padding=3)

    def forward(self, inp):
        """Run the network.

        Args:
            inp: tensor of shape (N, in_ch, 256, 256).

        Returns:
            Tensor of shape (N, out_ch, 256, 256) taken directly from
            ``out_conv`` (no activation is applied here).
        """
        # Encoder; x0..x5 are kept for the skip connections
        # (channel count, spatial size for a 256x256 input).
        x = inp
        x = self.conv01(x)
        x = x0 = self.conv02(x)          # base_ch,   256
        x = self.bp0(x)
        x = self.conv11(x)
        x = x1 = self.conv12(x)          # base_ch*2, 128
        x = self.bp1(x)
        x = self.conv21(x)
        x = x2 = self.conv22(x)          # base_ch*4, 64
        x = self.bp2(x)
        x = self.conv31(x)
        x = self.conv32(x)
        x = x3 = self.conv33(x)          # base_ch*8, 32
        x = self.bp3(x)
        x = self.conv41(x)
        x = self.conv42(x)
        x = x4 = self.conv43(x)          # base_ch*8, 16
        x = self.bp4(x)
        x = self.conv51(x)
        x = self.conv52(x)
        x = x5 = self.conv53(x)          # base_ch*8, 8
        x = self.bp5(x)                  # base_ch*8, 4

        # Bottleneck: flatten the 4x4 map, project through 512 units, restore 4x4.
        x = x.view(x.shape[0], -1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = x.view(-1, self.base_ch*8, 4, 4)

        # Decoder: upsample x2, concatenate the skip tensor on the channel dim, refine.
        x = self.up5(x)
        x = self.uconv53(torch.cat([x, x5], dim=1))
        x = self.uconv52(x)
        x = self.uconv51(x)

        x = self.up4(x)
        x = self.uconv43(torch.cat([x, x4], dim=1))
        x = self.uconv42(x)
        x = self.uconv41(x)

        x = self.up3(x)
        x = self.uconv33(torch.cat([x, x3], dim=1))
        x = self.uconv32(x)
        x = self.uconv31(x)

        x = self.up2(x)
        x = self.uconv22(torch.cat([x, x2], dim=1))
        x = self.uconv21(x)

        x = self.up1(x)
        x = self.uconv12(torch.cat([x, x1], dim=1))
        x = self.uconv11(x)

        x = self.up0(x)
        x = self.uconv02(torch.cat([x, x0], dim=1))
        x = self.uconv01(x)

        x = self.out_conv(x)
        return x