# Mirror of https://github.com/iperov/DeepFaceLive.git
# Synced 2024-12-25 15:31:13 -08:00
# 23 lines, 560 B, Python
from typing import List, Union
|
|
|
|
from .Tensor import *
|
|
|
|
|
|
class HTensor:
    """
    Helper functions for Tensor
    """

    @staticmethod
    def all_same_device( tensor_or_list : Union['Tensor', List['Tensor'] ] ) -> bool:
        """
        Check whether all given tensors use the same device.

        Args:
            tensor_or_list: a single tensor, or a list/tuple of tensors.
                Each element must provide a ``get_device()`` method.

        Returns:
            True if every tensor reports a device equal to the first
            tensor's device. A single tensor and an empty list/tuple
            both yield True.
        """
        # Normalize a lone tensor to a one-element sequence so both input
        # forms go through the same comparison below.
        if not isinstance(tensor_or_list, (list,tuple)):
            tensor_or_list = (tensor_or_list,)

        # Nothing to compare: an empty collection is vacuously on one device.
        # Without this guard the [0] lookup below raises IndexError.
        if len(tensor_or_list) == 0:
            return True

        device = tensor_or_list[0].get_device()
        return all( device == tensor.get_device() for tensor in tensor_or_list )
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ['HTensor']
|