DeepFaceLive/xlib/avecl/_internal/HTensor.py
2021-09-30 18:21:30 +04:00

23 lines
560 B
Python

from typing import List, Union
from .Tensor import *
class HTensor:
"""
Helper functions for Tensor
"""
@staticmethod
def all_same_device( tensor_or_list : Union[Tensor, List[Tensor] ] ) -> bool:
"""
check if all tensors in a list use the same device
"""
if not isinstance(tensor_or_list, (list,tuple)):
tensor_or_list = (tensor_or_list,)
device = tensor_or_list[0].get_device()
return all( device == tensor.get_device() for tensor in tensor_or_list )
__all__ = ['HTensor']