# DeepFaceLab/samplelib/SampleGeneratorFaceCelebAMaskHQ.py
# Colombo 61472cdaf7 global refactoring and fixes,
# removed support of extracted (aligned) PNG faces. Use old builds to convert from PNG to JPG.
#
# fanseg model file in facelib/ is renamed
# 2020-03-13 08:09:00 +04:00
#
# 270 lines
# 10 KiB
# Python

import multiprocessing
import pickle
import time
import traceback
from enum import IntEnum
import cv2
import numpy as np
from core import imagelib, mplib, pathex
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import SampleGeneratorBase
class MaskType(IntEnum):
    """Integer ids for the CelebAMask-HQ part-annotation labels.

    Member names are exactly the label strings that appear after the first
    '_' in annotation file stems ('<id>_<name>'), so ``MaskType[name]`` and
    ``member.name`` convert between the two.

    Values are plain ints. The previous trailing commas made each value a
    1-tuple that only worked because IntEnum unpacks tuples into ``int()``.
    """
    none = 0
    cloth = 1
    ear_r = 2
    eye_g = 3
    hair = 4
    hat = 5
    l_brow = 6
    l_ear = 7
    l_eye = 8
    l_lip = 9
    mouth = 10
    neck = 11
    neck_l = 12
    nose = 13
    r_brow = 14
    r_ear = 15
    r_eye = 16
    skin = 17
    u_lip = 18
# Integer mask id -> label name used in annotation file stems ('<id>_<name>').
# Built from the enum so the table can never drift from MaskType; insertion
# order follows the enum's definition order.
MaskType_to_name = {int(member): member.name for member in MaskType}

# Reverse lookup: label name -> integer mask id.
MaskType_from_name = {name: mask_id for mask_id, name in MaskType_to_name.items()}
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
    """Endless generator of augmented (image, skin mask) batches from CelebAMask-HQ.

    Expected layout under ``root_path``::

        CelebAMask-HQ/CelebA-HQ-img/<id>.jpg
        CelebAMask-HQ/CelebAMask-HQ-mask-anno/**/<id>_<label>.<image ext>

    Each yielded batch is ``[images, masks]``:

    * images are BGR float32 in [0, 1], with random HSV, motion/gaussian blur
      and downscale degradations;
    * masks are the skin annotation with hair and hat regions removed,
      binarized to {0, 1}, one channel.

    Both share the same random warp.
    """

    def __init__ (self, root_path, debug=False, batch_size=1, resolution=256,
                        generators_count=4, data_format="NHWC",
                        **kwargs):
        """
        Args:
            root_path: path-like supporting ``/`` (pathlib.Path) that contains
                the 'CelebAMask-HQ' directory.
            debug: if True, generate in the calling thread with one generator.
            batch_size: samples per yielded batch.
            resolution: output side length in pixels (default 256, the value
                that used to be hard-coded).
            generators_count: number of worker subprocesses (at least 1).
            data_format: "NHWC" (default) or "NCHW" for the yielded arrays.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: a required directory is missing, or no usable samples
                (image plus skin mask) were found.
        """
        super().__init__(debug, batch_size)
        self.initialized = False
        self.resolution = resolution

        dataset_path = root_path / 'CelebAMask-HQ'
        images_path = dataset_path / 'CelebA-HQ-img'
        masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
        for required_path in (dataset_path, images_path, masks_path):
            if not required_path.exists():
                raise ValueError(f'Unable to find {required_path}')

        self.generators_count = 1 if self.debug else max(1, generators_count)

        source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
        mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)
        if len(source_images_paths) == 0 or len(mask_images_paths) == 0:
            raise ValueError('No training data provided.')

        # file id -> { mask type id : mask path relative to masks_path }
        mask_file_id_hash = {}
        for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
            file_id, mask_type = filepath.stem.split('_', 1)
            id_masks = mask_file_id_hash.setdefault(int(file_id), {})
            id_masks[MaskType_from_name[mask_type]] = str(filepath.relative_to(masks_path))

        source_file_id_set = {int(filepath.stem) for filepath in source_images_paths}

        # Drop ids batch_func could never load (missing source image or missing
        # skin mask); otherwise they would fail on every draw, forever.
        for file_id in list(mask_file_id_hash):
            if file_id not in source_file_id_set:
                io.log_err (f"Corrupted dataset: {file_id} not in {images_path}")
                del mask_file_id_hash[file_id]
            elif MaskType.skin not in mask_file_id_hash[file_id]:
                io.log_err (f"Corrupted dataset: no skin mask for {file_id} in {masks_path}")
                del mask_file_id_hash[file_id]

        if len(mask_file_id_hash) == 0:
            raise ValueError('No training data provided.')

        generator_param = (images_path, masks_path, mask_file_id_hash, data_format)
        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, generator_param )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, generator_param, start_now=False )
                               for _ in range(self.generators_count) ]
            SubprocessGenerator.start_in_parallel( self.generators )

        self.generator_counter = -1
        self.initialized = True

    #overridable
    def is_initialized(self):
        """Return True once __init__ has finished setting up the generators."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, pulling from the generators round-robin."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    @staticmethod
    def _load_sample(images_path, masks_path, file_id, masks):
        """Read one image and build its target mask (skin minus hair and hat).

        Args:
            images_path: directory holding '<file_id>.jpg'.
            masks_path: root of the annotation masks.
            file_id: integer sample id.
            masks: dict mask type id -> mask path relative to ``masks_path``;
                must contain MaskType.skin.

        Returns:
            (img, mask) as float32 arrays scaled to [0, 1]. ``mask`` keeps only
            the first channel of each mask file.
        """
        img = cv2_imread(images_path / f'{file_id}.jpg').astype(np.float32) / 255.0
        mask = cv2_imread(masks_path / masks[MaskType.skin])[...,0:1].astype(np.float32) / 255.0

        # Hair and hats are not part of the target region: subtract them from skin.
        for cover_type in (MaskType.hair, MaskType.hat):
            cover_rel_path = masks.get(cover_type, None)
            if cover_rel_path is None:
                continue
            cover_path = masks_path / cover_rel_path
            if cover_path.exists():
                cover = cv2_imread(cover_path)[...,0:1].astype(np.float32) / 255.0
                mask *= (1 - cover)
        return img, mask

    def batch_func(self, param ):
        """Worker generator: yield ``[images, masks]`` batches forever.

        Args:
            param: tuple (images_path, masks_path, mask_file_id_hash,
                data_format) as built in __init__.

        Samples are drawn without replacement from a reshuffled id list, which
        is refilled when exhausted. A sample that raises is logged and replaced
        by another draw, so a batch is always full.
        """
        images_path, masks_path, mask_file_id_hash, data_format = param
        file_ids = list(mask_file_id_hash.keys())
        shuffle_file_ids = []

        resolution = self.resolution
        random_flip = True
        rotation_range=[-15,15]
        scale_range=[-0.10, 0.95]
        tx_range=[-0.3, 0.3]
        ty_range=[-0.3, 0.3]
        random_bilinear_resize = (25,75)   # (chance %, max downscale %)
        motion_blur = (25, 5)              # (chance %, max kernel size)
        gaussian_blur = (25, 5)            # (chance %, max kernel half-size)

        bs = self.batch_size
        while True:
            batches = [ [], [] ]
            n_batch = 0
            while n_batch < bs:
                try:
                    if len(shuffle_file_ids) == 0:
                        shuffle_file_ids = file_ids.copy()
                        np.random.shuffle(shuffle_file_ids)

                    file_id = shuffle_file_ids.pop()
                    img, mask = self._load_sample(images_path, masks_path, file_id, mask_file_id_hash[file_id])

                    # One set of warp params so image and mask stay aligned.
                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

                    # NOTE: interpolation must be passed by keyword; cv2.resize's
                    # third positional parameter is `dst`.
                    img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LANCZOS4 )

                    # Random hue rotation and saturation/value jitter
                    # (float HSV: H in degrees, S and V in [0, 1]).
                    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
                    h = ( h + np.random.randint(360) ) % 360
                    s = np.clip ( s + np.random.random()-0.5, 0, 1 )
                    v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
                    img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )

                    if motion_blur is not None:
                        chance, mb_max_size = motion_blur
                        chance = np.clip(chance, 0, 100)
                        mblur_rnd_chance = np.random.randint(100)
                        mblur_rnd_kernel = np.random.randint(mb_max_size)+1
                        mblur_rnd_deg = np.random.randint(360)
                        if mblur_rnd_chance < chance:
                            img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )

                    img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    if gaussian_blur is not None:
                        chance, kernel_max_size = gaussian_blur
                        chance = np.clip(chance, 0, 100)
                        gblur_rnd_chance = np.random.randint(100)
                        gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1
                        if gblur_rnd_chance < chance:
                            img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

                    if random_bilinear_resize is not None:
                        chance, max_size_per = random_bilinear_resize
                        chance = np.clip(chance, 0, 100)
                        pick_chance = np.random.randint(100)
                        resize_to = resolution - int( np.random.rand()* int(resolution*(max_size_per/100.0)) )
                        # Previously pick_chance was never tested, so every sample was degraded.
                        if pick_chance < chance:
                            img = cv2.resize (img, (resize_to,resize_to), interpolation=cv2.INTER_LINEAR )
                            img = cv2.resize (img, (resolution,resolution), interpolation=cv2.INTER_LINEAR )

                    # cv2.resize drops the single channel axis; restore it.
                    mask = cv2.resize( mask, (resolution,resolution), interpolation=cv2.INTER_LANCZOS4 )[...,None]
                    mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
                    # Binarize (also removes Lanczos over/undershoot).
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0

                    if data_format == "NCHW":
                        img = np.transpose(img, (2,0,1) )
                        mask = np.transpose(mask, (2,0,1) )

                    batches[0].append ( img )
                    batches[1].append ( mask )
                    n_batch += 1
                except Exception:
                    io.log_err ( traceback.format_exc() )

            yield [ np.array(batch) for batch in batches]