mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-25 15:31:13 -08:00
95 lines
2.6 KiB
Python
95 lines
2.6 KiB
Python
from .op import *
|
|
from .AShape import *
|
|
from .backend import *
|
|
from .Tensor import Tensor
|
|
|
|
def Tensor__str__(self : Tensor): return f"T {self.name} {self.shape} {self.dtype.name}"
|
|
Tensor.__str__ = Tensor__str__
|
|
|
|
def Tensor__repr__(self : Tensor):
|
|
s = self.__str__() + '\n'
|
|
s += str(self.np()) + '\n'
|
|
s += self.__str__()
|
|
return s
|
|
Tensor.__repr__ = Tensor__repr__
|
|
|
|
def Tensor__add__(self : Tensor, value) -> Tensor:
|
|
return add(self, value)
|
|
Tensor.__add__ = Tensor__add__
|
|
|
|
def Tensor__radd__(self : Tensor, value) -> Tensor:
|
|
return add(value, self)
|
|
Tensor.__radd__ = Tensor__radd__
|
|
|
|
def Tensor__sub__(self : Tensor, value) -> Tensor:
|
|
return sub(self, value)
|
|
Tensor.__sub__ = Tensor__sub__
|
|
|
|
def Tensor__rsub__(self : Tensor, value) -> Tensor:
|
|
return sub(value, self)
|
|
Tensor.__rsub__ = Tensor__rsub__
|
|
|
|
def Tensor__mul__(self : Tensor, value) -> Tensor:
|
|
if self == value:
|
|
return square(self)
|
|
return mul(self, value)
|
|
Tensor.__mul__ = Tensor__mul__
|
|
|
|
def Tensor__rmul__(self : Tensor, value) -> Tensor:
|
|
if self == value:
|
|
return square(self)
|
|
return mul(value, self)
|
|
Tensor.__rmul__ = Tensor__rmul__
|
|
|
|
def Tensor__truediv__(self : Tensor, value) -> Tensor:
|
|
return div(self, value)
|
|
Tensor.__truediv__ = Tensor__truediv__
|
|
|
|
def Tensor__rtruediv__(self : Tensor, value) -> Tensor:
|
|
return div(value, self)
|
|
Tensor.__rtruediv__ = Tensor__rtruediv__
|
|
|
|
def Tensor___neg__(self : Tensor):
|
|
raise NotImplementedError()
|
|
Tensor.___neg__ = Tensor___neg__
|
|
|
|
Tensor.__getitem__ = slice_
|
|
Tensor.__setitem__ = slice_set
|
|
|
|
def Tensor_as_shape(self : Tensor, shape) -> Tensor:
|
|
return TensorRef(self, shape)
|
|
Tensor.as_shape = Tensor_as_shape
|
|
|
|
Tensor.cast = cast
|
|
def Tensor_copy(self : Tensor) -> Tensor:
|
|
return Tensor.from_value(self)
|
|
Tensor.copy = Tensor_copy
|
|
|
|
Tensor.max = reduce_max
|
|
Tensor.mean = reduce_mean
|
|
Tensor.min = reduce_min
|
|
Tensor.reshape = reshape
|
|
Tensor.sum = reduce_sum
|
|
Tensor.std = reduce_std
|
|
Tensor.transpose = transpose
|
|
|
|
class TensorRef(Tensor):
|
|
"""
|
|
TensorRef used to interpret existing Tensor with different shape.
|
|
use Tensor.as_ref() method
|
|
"""
|
|
|
|
def __init__(self, t : Tensor, shape):
|
|
shape = AShape(shape)
|
|
if t.shape.size != shape.size:
|
|
raise ValueError(f'Cannot interpet shape {t.shape} as ref shape {shape}')
|
|
super().__init__(shape, t.dtype, device=t.get_device())
|
|
self._t = t
|
|
|
|
# Forward methods to original tensor
|
|
def get_seq_id(self) -> int: return self._t.get_seq_id()
|
|
def get_buffer(self) -> Buffer: return self._t.get_buffer()
|
|
def get_device(self) -> Device: return self._t.get_device()
|
|
|
|
__all__ = []
|