mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-25 15:31:13 -08:00
107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
from typing import List, Tuple

import numpy as np

from .backend import Device
from .HTensor import HTensor
from .HType import HType
from .Tensor import Tensor
from .AShape import AShape
|
|
|
|
class HArgs:
    """
    Helper functions for a list of arguments.

    Every method is a @staticmethod; the class only groups them.
    """

    @staticmethod
    def decompose(args):
        """
        Decompose a list of args made of Tensors and supported numeric values.

        Each arg contributes exactly one entry to each returned tuple:

          Tensor        -> (arg.shape, arg.dtype,  arg.get_buffer())
          int           -> (None,      np.int32,   np.int32(arg))
          float         -> (None,      np.float32, np.float32(arg))
          numpy scalar accepted by HType.is_obj_of_np_scalar_type
                        -> (None,      type(arg),  arg)

        returns ( shape_list,        # tuple; None for a scalar value
                  dtype_list,        # tuple of dtypes, one per arg
                  kernel_args_list,  # tuple of tensor buffers / numpy scalars
                )

        raises ValueError   if an arg has an unsupported type.
        """
        shape_list = []
        dtype_list = []
        kernel_args_list = []

        for arg in args:
            if isinstance(arg, Tensor):
                shape_list.append(arg.shape)
                dtype_list.append(arg.dtype)
                kernel_args_list.append(arg.get_buffer())
                continue

            # Python numbers are narrowed to 32-bit numpy scalars.
            # NOTE(review): bool is a subclass of int and np.float64 is a
            # subclass of float, so both are caught by the first two branches
            # (np.float64 is narrowed to np.float32) before the numpy-scalar
            # check is reached -- confirm this is intended.
            if isinstance(arg, int):
                dtype, arg = np.int32, np.int32(arg)
            elif isinstance(arg, float):
                dtype, arg = np.float32, np.float32(arg)
            elif HType.is_obj_of_np_scalar_type(arg):
                dtype = arg.__class__
            else:
                raise ValueError(f'Unsupported type of arg: {arg.__class__} Use Tensor or number type.')

            shape_list.append(None)
            dtype_list.append(dtype)
            kernel_args_list.append(arg)

        return tuple(shape_list), tuple(dtype_list), tuple(kernel_args_list)

    @staticmethod
    def get_shapes(args : List[Tensor]) -> Tuple[AShape, ...]:
        """
        Return a tuple with the .shape of every tensor in args, in order.
        """
        return tuple(t.shape for t in args)

    @staticmethod
    def check_zero_get_length(args) -> int:
        """
        Return len(args).

        raises ValueError   if args is empty.
        """
        args_len = len(args)
        if args_len == 0:
            raise ValueError('args must be specified')
        return args_len

    @staticmethod
    def check_get_same_device(args : List[Tensor]) -> Device:
        """
        Check that all tensors are on the same device and return that device.

        raises ValueError   if args is empty or the devices differ.
        """
        # Guard first: the return below indexes args[0].
        HArgs.check_zero_get_length(args)
        if not HTensor.all_same_device(args):
            raise ValueError('all Tensors must have the same device')
        return args[0].get_device()

    @staticmethod
    def check_all_tensors(args : List[Tensor]):
        """
        Check that every item of args is a Tensor.

        raises ValueError   if any item is not a Tensor.
        """
        if not all(isinstance(tensor, Tensor) for tensor in args):
            raise ValueError('All values must have type of Tensor')

    @staticmethod
    def check_get_same_shape(args : List[Tensor]) -> AShape:
        """
        Check that all tensors have the same shape and return that shape.

        raises ValueError   if args is empty or the shapes differ.
        """
        # Guard first: the reference shape is taken from args[0].
        HArgs.check_zero_get_length(args)
        shape = args[0].shape
        if not all(t.shape == shape for t in args):
            raise ValueError('All tensors must have the same shape')
        return shape

    @staticmethod
    def filter_tensor(args, raise_on_empty : bool):
        """
        Return a list of only the Tensor items of args, keeping their order.

        raises ValueError   if raise_on_empty is true and no arg is a Tensor.
        """
        tensor_args = [arg for arg in args if isinstance(arg, Tensor)]
        if raise_on_empty and not tensor_args:
            raise ValueError('At least one arg must be a Tensor')
        return tensor_args
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    'HArgs',
]
|