# DeepFaceLab/nnlib/DeepPortraitRelighting.py

import math
from pathlib import Path
import cv2
import numpy as np

class DeepPortraitRelighting(object):
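    """Wrapper around the Deep Portrait Relighting (DPR) hourglass network.

    Loads the pretrained PyTorch model through nnlib and exposes relight(),
    which relights a BGR face image for a given light direction.
    """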
def __init__(self):
from nnlib import nnlib
nnlib.import_torch()
self.torch = nnlib.torch
self.torch_device = nnlib.torch_device
self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device)
def SH_basis(self, alt, azi):
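        """Return the 9 spherical harmonics (SH) basis coefficients for a
        directional light at altitude `alt` and azimuth `azi` (degrees).

        Covers SH bands 0..2; each band is attenuated by the Lambertian
        factors pi * [1, 2/3, 1/4].
        """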
alt = alt * math.pi / 180.0
azi = azi * math.pi / 180.0
x = math.cos(alt)*math.sin(azi)
y = -math.cos(alt)*math.cos(azi)
z = math.sin(alt)
normal = np.array([x,y,z])
norm_X = normal[0]
norm_Y = normal[1]
norm_Z = normal[2]
        sh_basis = np.zeros(9)
        att = np.pi * np.array([1.0, 2.0/3.0, 1.0/4.0])
sh_basis[0] = 0.5/np.sqrt(np.pi)*att[0]
sh_basis[1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1]
sh_basis[2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1]
sh_basis[3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1]
sh_basis[4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2]
sh_basis[5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2]
sh_basis[6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2]
sh_basis[7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2]
sh_basis[8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2]
return sh_basis
def relight(self, img, alt, azi, intensity=1.0, lighten=False):
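        """Relight a BGR uint8 image.

        alt, azi  : target light direction in degrees (altitude, azimuth).
        intensity : blend factor between the original and the relit luminance.
        lighten   : if True, use the network output directly as the new
                    luminance instead of modulating the input luminance.
        """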
torch = self.torch
        sh = self.SH_basis(alt, azi)
        sh = sh.reshape((1, 9, 1, 1)).astype(np.float32)
        # torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor suffices
        sh = torch.from_numpy(sh).to(self.torch_device)
        row, col, _ = img.shape
        img = cv2.resize(img, (512, 512))
        # the network operates on the luminance (L) channel only
        Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        inputL = Lab[:, :, 0]
        x = torch.from_numpy(inputL[None, None, ...].astype(np.float32) / 255.0).to(self.torch_device)
        outputImg, outputSH = self.model(x, sh, 0)
        outputImg = outputImg[0].detach().cpu().numpy()
        outputImg = np.squeeze(outputImg.transpose((1, 2, 0)))
        outputImg = np.clip(outputImg, 0.0, 1.0)
        outputImg = cv2.blur(outputImg, (3, 3))
if not lighten:
outputImg = inputL*(1.0-intensity) + (inputL*outputImg)*intensity
else:
outputImg = inputL*(1.0-intensity) + (outputImg*255.0)*intensity
        outputImg = np.clip(outputImg, 0, 255).astype(np.uint8)
Lab[:,:,0] = outputImg
result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
result = cv2.resize(result, (col, row))
return result
@staticmethod
def build_model(torch, torch_device):
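        """Build the DPR hourglass network, load the pretrained weights from
        DeepPortraitRelighting.t7 next to this file, and return the model in
        eval mode on `torch_device`."""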
nn = torch.nn
F = torch.nn.functional
def conv3X3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
# define the network
class BasicBlock(nn.Module):
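            """Residual block: two 3x3 convolutions with batch or instance
            normalization, and a 1x1 projection shortcut when the number of
            channels changes."""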
def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None):
super(BasicBlock, self).__init__()
                # batchNorm_type 0: batch normalization; 1: instance normalization
self.inplanes = inplanes
self.outplanes = outplanes
self.conv1 = conv3X3(inplanes, outplanes, 1)
self.conv2 = conv3X3(outplanes, outplanes, 1)
if batchNorm_type == 0:
self.bn1 = nn.BatchNorm2d(outplanes)
self.bn2 = nn.BatchNorm2d(outplanes)
else:
self.bn1 = nn.InstanceNorm2d(outplanes)
self.bn2 = nn.InstanceNorm2d(outplanes)
self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.inplanes != self.outplanes:
out += self.shortcuts(x)
else:
out += x
out = F.relu(out)
return out
class HourglassBlock(nn.Module):
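            """One hourglass level: an identity-resolution upper branch plus a
            downsample -> middleNet -> upsample lower branch; the two are summed
            when skip connections are enabled at this depth."""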
def __init__(self, inplane, mid_plane, middleNet, skipLayer=True):
super(HourglassBlock, self).__init__()
# upper branch
                self.skipLayer = skipLayer
self.upper = BasicBlock(inplane, inplane, batchNorm_type=1)
# lower branch
self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
self.low1 = BasicBlock(inplane, mid_plane)
self.middle = middleNet
self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1)
def forward(self, x, light, count, skip_count):
                # `count` tracks which layer we are in;
                # `skip_count` indicates from which layer on the skip connection is applied
out_upper = self.upper(x)
out_lower = self.downSample(x)
out_lower = self.low1(out_lower)
out_lower, out_middle = self.middle(out_lower, light, count+1, skip_count)
out_lower = self.low2(out_lower)
out_lower = self.upSample(out_lower)
if count >= skip_count and self.skipLayer:
out = out_lower + out_upper
else:
out = out_lower
return out, out_middle
class lightingNet(nn.Module):
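            """Bottleneck module: predicts the source lighting (SH coefficients)
            from the innermost features, then replaces those features with the
            target lighting mapped back through 1x1 convolutions."""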
def __init__(self, ncInput, ncOutput, ncMiddle):
super(lightingNet, self).__init__()
self.ncInput = ncInput
self.ncOutput = ncOutput
self.ncMiddle = ncMiddle
self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
self.predict_relu1 = nn.PReLU()
self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False)
self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
self.post_relu1 = nn.PReLU()
self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False)
                self.post_relu2 = nn.ReLU()  # to be consistent with the original feature
def forward(self, innerFeat, target_light, count, skip_count):
x = innerFeat[:,0:self.ncInput,:,:] # lighting feature
_, _, row, col = x.shape
# predict lighting
feat = x.mean(dim=(2,3), keepdim=True)
light = self.predict_relu1(self.predict_FC1(feat))
light = self.predict_FC2(light)
                # map the target lighting into feature space and broadcast it spatially
                upFeat = self.post_relu1(self.post_FC1(target_light))
                upFeat = self.post_relu2(self.post_FC2(upFeat))
                upFeat = upFeat.repeat((1, 1, row, col))
                innerFeat[:, 0:self.ncInput, :, :] = upFeat
                return innerFeat, light
class HourglassNet(nn.Module):
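            """Four nested hourglass levels wrapped around lightingNet: takes a
            single-channel luminance image and a target SH lighting, returns the
            relit luminance and the predicted source lighting."""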
def __init__(self, baseFilter = 16, gray=True):
super(HourglassNet, self).__init__()
self.ncLight = 27 # number of channels for input to lighting network
self.baseFilter = baseFilter
                # number of channels for the lighting network output
                if gray:
                    self.ncOutLight = 9   # grayscale: 1 channel x 9 SH coefficients
                else:
                    self.ncOutLight = 27  # color: 3 channels x 9 SH coefficients
self.ncPre = self.baseFilter # number of channels for pre-convolution
# number of channels
self.ncHG3 = self.baseFilter
self.ncHG2 = 2*self.baseFilter
self.ncHG1 = 4*self.baseFilter
self.ncHG0 = 8*self.baseFilter + self.ncLight
self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2)
self.pre_bn = nn.BatchNorm2d(self.ncPre)
self.light = lightingNet(self.ncLight, self.ncOutLight, 128)
self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light)
self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0)
self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1)
self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2)
self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1)
self.bn_1 = nn.BatchNorm2d(self.ncPre)
self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
self.bn_2 = nn.BatchNorm2d(self.ncPre)
self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
self.bn_3 = nn.BatchNorm2d(self.ncPre)
self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0)
def forward(self, x, target_light, skip_count):
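                """x: 1-channel luminance in [0, 1]; target_light: (N, 9, 1, 1)
                SH coefficients; returns (relit luminance, predicted source SH)."""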
feat = self.pre_conv(x)
feat = F.relu(self.pre_bn(feat))
# get the inner most features
feat, out_light = self.HG3(feat, target_light, 0, skip_count)
feat = F.relu(self.bn_1(self.conv_1(feat)))
feat = F.relu(self.bn_2(self.conv_2(feat)))
feat = F.relu(self.bn_3(self.conv_3(feat)))
out_img = self.output(feat)
out_img = torch.sigmoid(out_img)
return out_img, out_light
        model = HourglassNet()
        # load on CPU first (map_location keeps this working on CPU-only machines),
        # then move the model to the requested device
        t_dict = torch.load(Path(__file__).parent / 'DeepPortraitRelighting.t7', map_location='cpu')
        model.load_state_dict(t_dict)
        model.to(torch_device)
        model.eval()
        return model
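

# A minimal usage sketch (an assumption, not part of the original module): it
# presumes nnlib can initialize torch on this machine, that the pretrained
# DeepPortraitRelighting.t7 weights sit next to this file, and that
# 'face.jpg' / 'face_relit.jpg' are hypothetical example paths.
if __name__ == "__main__":
    dpr = DeepPortraitRelighting()
    face = cv2.imread('face.jpg')  # BGR uint8 image
    # light from 45 degrees altitude, -30 degrees azimuth, full intensity
    relit = dpr.relight(face, alt=45.0, azi=-30.0, intensity=1.0)
    cv2.imwrite('face_relit.jpg', relit)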