mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-11-23 00:20:06 -08:00
62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
from pathlib import Path
|
|
from typing import List
|
|
|
|
from xlib.image import ImageProcessor
|
|
from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
|
|
get_available_devices_info)
|
|
|
|
|
|
class FaceMesh:
|
|
"""
|
|
Google FaceMesh detection model.
|
|
|
|
arguments
|
|
|
|
device_info ORTDeviceInfo
|
|
|
|
use FaceMesh.get_available_devices()
|
|
to determine a list of avaliable devices accepted by model
|
|
|
|
raises
|
|
Exception
|
|
"""
|
|
|
|
@staticmethod
|
|
def get_available_devices() -> List[ORTDeviceInfo]:
|
|
return get_available_devices_info()
|
|
|
|
def __init__(self, device_info : ORTDeviceInfo):
|
|
if device_info not in FaceMesh.get_available_devices():
|
|
raise Exception(f'device_info {device_info} is not in available devices for FaceMesh')
|
|
|
|
path = Path(__file__).parent / 'FaceMesh.onnx'
|
|
if not path.exists():
|
|
raise FileNotFoundError(f'{path} not found')
|
|
|
|
self._sess = sess = InferenceSession_with_device(str(path), device_info)
|
|
self._input_name = sess.get_inputs()[0].name
|
|
self._input_width = 192
|
|
self._input_height = 192
|
|
|
|
def extract(self, img):
|
|
"""
|
|
arguments
|
|
|
|
img np.ndarray HW,HWC,NHWC uint8/float32
|
|
|
|
returns (N,468,3)
|
|
"""
|
|
ip = ImageProcessor(img)
|
|
N,H,W,_ = ip.get_dims()
|
|
|
|
h_scale = H / self._input_height
|
|
w_scale = W / self._input_width
|
|
|
|
feed_img = ip.resize( (self._input_width, self._input_height) ).to_ufloat32().ch(3).get_image('NHWC')
|
|
|
|
lmrks = self._sess.run(None, {self._input_name: feed_img})[0]
|
|
lmrks = lmrks.reshape( (N,468,3))
|
|
lmrks *= (w_scale, h_scale, 1)
|
|
|
|
return lmrks
|