from pathlib import Path
from typing import List

import numpy as np

from xlib.file import SplittedFile
from xlib.image import ImageProcessor
from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
                              get_available_devices_info)


class LIA:
    """
    Latent Image Animator: Learning to Animate Images via Latent Space Navigation
    https://github.com/wyhsirius/LIA

    arguments

        device_info     ORTDeviceInfo

            use LIA.get_available_devices()
            to get the list of available devices accepted by the model

    raises

        Exception
    """

    @staticmethod
    def get_available_devices() -> List[ORTDeviceInfo]:
        return get_available_devices_info()

    def __init__(self, device_info : ORTDeviceInfo):
        if device_info not in LIA.get_available_devices():
            raise Exception(f'device_info {device_info} is not in available devices for LIA')

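        # merge generator.onnx from its split part-files if necessary;
        # the parts are kept on disk (delete_parts=False)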
        generator_path = Path(__file__).parent / 'generator.onnx'
        SplittedFile.merge(generator_path, delete_parts=False)
        if not generator_path.exists():
            raise FileNotFoundError(f'{generator_path} not found')

        self._generator = InferenceSession_with_device(str(generator_path), device_info)

    def get_input_size(self):
        """
        returns the optimal (Width, Height) for input images; resizing the source
        image to this size beforehand avoids extra load
        """
        return (256,256)

    def extract_motion(self, img : np.ndarray):
        """
        Extract motion from an image

        arguments

            img     np.ndarray      HW HWC 1HWC     uint8/float32
        """
        feed_img = ImageProcessor(img).resize(self.get_input_size()).ch(3).swap_ch().to_ufloat32(as_tanh=True).get_image('NCHW')
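
        # only the 'out_drv_motion' output is requested here, so the source image,
        # start-motion and power inputs are fed zero-valued placeholders of the
        # shapes the model expects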
        return self._generator.run(['out_drv_motion'], {'in_src'              : np.zeros((1,3,256,256), np.float32),
                                                        'in_drv'              : feed_img,
                                                        'in_drv_start_motion' : np.zeros((1,20), np.float32),
                                                        'in_power'            : np.zeros((1,), np.float32),
                                                        })[0]

    def generate(self, img_source : np.ndarray, img_driver : np.ndarray, driver_start_motion : np.ndarray, power):
        """
        arguments

            img_source      np.ndarray      HW HWC 1HWC     uint8/float32

            img_driver      np.ndarray      HW HWC 1HWC     uint8/float32

            driver_start_motion     reference motion for driver
        """
        ip = ImageProcessor(img_source)
        dtype = ip.get_dtype()
        _,H,W,_ = ip.get_dims()
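
        # run the generator; the raw output is NCHW, transposed here to NHWC
        # with the single batch item taken out (HWC)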
        out = self._generator.run(['out'], {'in_src'              : ip.resize(self.get_input_size()).ch(3).swap_ch().to_ufloat32(as_tanh=True).get_image('NCHW'),
                                            'in_drv'              : ImageProcessor(img_driver).resize(self.get_input_size()).ch(3).swap_ch().to_ufloat32(as_tanh=True).get_image('NCHW'),
                                            'in_drv_start_motion' : driver_start_motion,
                                            'in_power'            : np.array([power], np.float32),
                                            })[0].transpose(0,2,3,1)[0]
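
        # convert the generated frame back to the source image's dtype, size
        # and channel order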
        out = ImageProcessor(out).to_dtype(dtype, from_tanh=True).resize((W,H)).swap_ch().get_image('HWC')
        return out
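

# ---------------------------------------------------------------------------
# Illustrative usage sketch. It assumes OpenCV ('cv2') is installed and uses
# hypothetical file names 'source.jpg', 'driver_ref.jpg' and 'driver.jpg';
# the reference motion is extracted from a reference driving frame and passed
# as driver_start_motion, following the docstrings above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import cv2

    # pick the first available ONNX Runtime device
    device = LIA.get_available_devices()[0]
    lia = LIA(device)

    img_source  = cv2.imread('source.jpg')       # HWC uint8
    img_drv_ref = cv2.imread('driver_ref.jpg')   # reference driving frame
    img_drv     = cv2.imread('driver.jpg')       # current driving frame

    # motion of the reference driving frame, reused for every generated frame
    start_motion = lia.extract_motion(img_drv_ref)

    # animate the source image with the current driving frame
    out = lia.generate(img_source, img_drv, start_motion, power=1.0)
    cv2.imwrite('out.jpg', out)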