DeepFaceLive/xlib/avecl/_internal/HArgs.py
2021-09-30 18:21:30 +04:00

107 lines
3.2 KiB
Python

from typing import List, Tuple

import numpy as np

from .backend import Device
from .HTensor import HTensor
from .HType import HType
from .Tensor import Tensor
from .AShape import AShape
class HArgs:
    """
    Helper functions for list of arguments
    """

    @staticmethod
    def decompose(args):
        """
        decompose list of args of Tensor and supported numeric values

        returns ( shape_tuple,        # Tensor arg -> arg.shape, scalar arg -> None
                  dtype_tuple,        # Tensor arg -> arg.dtype, scalar arg -> numpy scalar type
                  kernel_args_tuple,  # Tensor arg -> arg.get_buffer(), scalar arg -> numpy scalar value
                )

        Scalar conversion:
            Python int   -> np.int32
            Python float -> np.float32
            object accepted by HType.is_obj_of_np_scalar_type -> passed through,
                                                              dtype is its class

        raises ValueError if an arg is neither a Tensor nor a supported number.
        """
        shape_list = []
        dtype_list = []
        kernel_args_list = []
        for arg in args:
            if isinstance(arg, Tensor):
                shape_list.append(arg.shape)
                dtype_list.append(arg.dtype)
                kernel_args_list.append(arg.get_buffer())
                continue

            # NOTE: bool subclasses int and np.float64 subclasses float, so such
            # values are caught by the int/float branches (-> int32/float32)
            # before the numpy-scalar branch is ever reached.
            if isinstance(arg, int):
                dtype, arg = np.int32, np.int32(arg)
            elif isinstance(arg, float):
                dtype, arg = np.float32, np.float32(arg)
            elif HType.is_obj_of_np_scalar_type(arg):
                dtype = arg.__class__
            else:
                raise ValueError(f'Unsupported type of arg: {arg.__class__} Use Tensor or number type.')
            shape_list.append(None)
            dtype_list.append(dtype)
            kernel_args_list.append(arg)
        return tuple(shape_list), tuple(dtype_list), tuple(kernel_args_list)

    @staticmethod
    def get_shapes(args : List[Tensor]) -> Tuple[AShape, ...]:
        """
        returns a tuple with the .shape of every tensor in args, in order
        """
        return tuple(t.shape for t in args)

    @staticmethod
    def check_zero_get_length(args) -> int:
        """
        raises ValueError if len(args) == 0, otherwise returns len(args)
        """
        args_len = len(args)
        if args_len == 0:
            raise ValueError('args must be specified')
        return args_len

    @staticmethod
    def check_get_same_device(args : List[Tensor]) -> Device:
        """
        check all device of tensors are the same and return the device

        raises ValueError if args is empty or the tensors' devices differ
        """
        # Guard first: args[0] below would otherwise fail with a bare IndexError.
        HArgs.check_zero_get_length(args)
        if not HTensor.all_same_device(args):
            raise ValueError('all Tensors must have the same device')
        return args[0].get_device()

    @staticmethod
    def check_all_tensors(args : List[Tensor]):
        """
        raises ValueError if any element of args is not a Tensor
        (an empty args passes, since all() of nothing is True)
        """
        if not all(isinstance(tensor, Tensor) for tensor in args):
            raise ValueError('All values must have type of Tensor')

    @staticmethod
    def check_get_same_shape(args : List[Tensor]) -> AShape:
        """
        check all shapes of tensors are the same and return the shape

        raises ValueError if args is empty or the tensors' shapes differ
        """
        # Guard first: args[0] below would otherwise fail with a bare IndexError.
        HArgs.check_zero_get_length(args)
        shape = args[0].shape
        if not all(t.shape == shape for t in args):
            raise ValueError('All tensors must have the same shape')
        return shape

    @staticmethod
    def filter_tensor(args, raise_on_empty : bool):
        """
        get only tensors from the list (order preserved)

        raises ValueError if raise_on_empty is set and no Tensor was found
        """
        tensor_args = [arg for arg in args if isinstance(arg, Tensor)]
        if raise_on_empty and not tensor_args:
            raise ValueError('At least one arg must be a Tensor')
        return tensor_args


__all__ = ['HArgs']