mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-03 21:10:16 -08:00
116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
from enum import IntEnum
|
|
from pathlib import Path
|
|
from typing import List, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from xlib.image import ImageProcessor
|
|
from xlib.net import ThreadFileDownloader
|
|
|
|
|
|
from xlib import tf
|
|
|
|
class DFMModelInfo:
|
|
"""
|
|
Model for face swapping.
|
|
|
|
"""
|
|
def __init__(self, celeb_type, device):
|
|
celeb_name = str(celeb_type).split('.')[-1]
|
|
model_path = self._model_path = Path(__file__).parent / 'CELEB_MODEL' / f'{celeb_name}.onnx'
|
|
|
|
self._dl_error = None
|
|
self._dl = None
|
|
if not model_path.exists():
|
|
self._dl = ThreadFileDownloader(rf'https://github.com/iperov/DeepFaceLive/releases/download/CELEB_MODEL/{celeb_name}.onnx', savepath=model_path)
|
|
|
|
self._device = device
|
|
self._sess = None
|
|
|
|
@staticmethod
|
|
def get_available_devices(celeb_type):
|
|
return tf.get_available_devices_info()
|
|
|
|
def get_download_error(self) -> Union[str,None]:
|
|
"""
|
|
returns download error or None if no error.
|
|
"""
|
|
if self._dl_error is not None:
|
|
return 'DFMModelInfo download error: ' + self._dl_error
|
|
return None
|
|
|
|
def update_download_progress(self) -> float:
|
|
"""
|
|
returns [0.0..100.0] where 100.0 mean the model is ready to convert.
|
|
"""
|
|
if self._dl_error is not None:
|
|
progress = 0.0
|
|
elif self._dl is None:
|
|
progress = 100.0
|
|
else:
|
|
err = self._dl_error = self._dl.get_error()
|
|
if err is not None:
|
|
progress = 0.0
|
|
else:
|
|
progress = self._dl.get_progress()
|
|
|
|
if progress == 100.0 and self._sess is None:
|
|
self._dl = None
|
|
|
|
model_path = r'D:\DevelopPPP\projects\DeepFaceLive\github_project\xlib\model_hub\tf\TOM_CRUISE.pb'
|
|
self._sess = tf.TFInferenceSession(model_path,
|
|
in_tensor_names=['in_face:0', 'morph_value:0'],
|
|
out_tensor_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
|
|
device_info=self._device)
|
|
|
|
|
|
return progress
|
|
|
|
def convert(self, img, morph_factor=0.75):
|
|
"""
|
|
img np.ndarray HW,HWC,NHWC uint8,float32
|
|
|
|
returns
|
|
|
|
img NHW3 same dtype as img
|
|
celeb_mask NHW1 same dtype as img
|
|
face_mask NHW1 same dtype as img
|
|
|
|
convert the face
|
|
if the model is not ready, it will return black images
|
|
"""
|
|
|
|
ip = ImageProcessor(img)
|
|
|
|
N,H,W,C = ip.get_dims()
|
|
dtype = ip.get_dtype()
|
|
|
|
progress = self.update_download_progress()
|
|
if progress == 100.0:
|
|
|
|
img = ip.resize( (224,224) ).ch(3).to_ufloat32().get_image('NHWC')
|
|
|
|
#out_face_mask, out_celeb, out_celeb_mask = self._sess.run(None, {self._input_name: img})
|
|
|
|
out_face_mask, out_celeb, out_celeb_mask = self._sess.run([img, [morph_factor] ])
|
|
|
|
out_celeb = ImageProcessor(out_celeb).resize((W,H)).ch(3).to_dtype(dtype).get_image('NHWC')
|
|
out_celeb_mask = ImageProcessor(out_celeb_mask).resize((W,H)).ch(1).to_dtype(dtype).get_image('NHWC')
|
|
out_face_mask = ImageProcessor(out_face_mask).resize((W,H)).ch(1).to_dtype(dtype).get_image('NHWC')
|
|
else:
|
|
out_celeb = np.zeros( (N,H,W,3), dtype=dtype )
|
|
out_celeb_mask = ImageProcessor(np.full( (N,H,W,1), 255, dtype=np.uint8 )).to_dtype(dtype).get_image('NHWC')
|
|
out_face_mask = ImageProcessor(np.full( (N,H,W,1), 255, dtype=np.uint8 )).to_dtype(dtype).get_image('NHWC')
|
|
|
|
return out_celeb, out_celeb_mask, out_face_mask
|
|
|
|
class CelebType(IntEnum):
|
|
TOM_CRUISE = 0
|
|
VLADIMIR_PUTIN = 1
|
|
|
|
CelebTypeNames = ['Tom Cruise',
|
|
'Vladimir Putin']
|
|
|
|
|
|
|