DeepFaceLive/xlib/avecl/_internal/TensorImpl.py
2021-10-20 18:02:50 +04:00

95 lines
2.6 KiB
Python

from .op import *
from .AShape import *
from .backend import *
from .Tensor import Tensor
def Tensor__str__(self : Tensor):
    """
    One-line summary of the tensor: 'T <name> <shape> <dtype name>'.
    Does not read the tensor's data.
    """
    return f"T {self.name} {self.shape} {self.dtype.name}"

setattr(Tensor, '__str__', Tensor__str__)
def Tensor__repr__(self : Tensor):
    """
    Detailed representation, three newline-separated parts:
    the __str__ summary, the tensor contents as returned by self.np(),
    and the __str__ summary again after the data.
    """
    summary = self.__str__()
    return '\n'.join( (summary, str(self.np()), summary) )

setattr(Tensor, '__repr__', Tensor__repr__)
def Tensor__add__(self : Tensor, value) -> Tensor:
    """t + value, computed by add(self, value)."""
    return add(self, value)

def Tensor__radd__(self : Tensor, value) -> Tensor:
    """
    value + t (reflected addition, e.g. `1 + t`).
    The original operand order is preserved when calling add().
    """
    return add(value, self)

Tensor.__add__, Tensor.__radd__ = Tensor__add__, Tensor__radd__
def Tensor__sub__(self : Tensor, value) -> Tensor:
    """t - value, computed by sub(self, value)."""
    return sub(self, value)

def Tensor__rsub__(self : Tensor, value) -> Tensor:
    """
    value - t (reflected subtraction, e.g. `1 - t`).
    The operand order matters here, so value stays on the left.
    """
    return sub(value, self)

Tensor.__sub__, Tensor.__rsub__ = Tensor__sub__, Tensor__rsub__
def Tensor__mul__(self : Tensor, value) -> Tensor:
    """
    t * value, computed by mul(self, value).

    Multiplying a tensor by the very same object (t * t) is dispatched
    to square(self) instead.
    """
    # Identity test, not `==`: `==` may dispatch to an overloaded or
    # reflected __eq__ (e.g. numpy's elementwise one), whose array result
    # would make this `if` raise "truth value ... is ambiguous".
    # Only "same object" is meant here.
    if value is self:
        return square(self)
    return mul(self, value)
Tensor.__mul__ = Tensor__mul__

def Tensor__rmul__(self : Tensor, value) -> Tensor:
    """
    value * t (reflected multiplication, e.g. `2 * t`), computed by
    mul(value, self) so the operand order is preserved.
    """
    # Same identity-vs-equality reasoning as in Tensor__mul__.
    if value is self:
        return square(self)
    return mul(value, self)
Tensor.__rmul__ = Tensor__rmul__
def Tensor__truediv__(self : Tensor, value) -> Tensor:
    """t / value, computed by div(self, value)."""
    return div(self, value)

def Tensor__rtruediv__(self : Tensor, value) -> Tensor:
    """
    value / t (reflected true division, e.g. `1 / t`).
    value is kept as the dividend when calling div().
    """
    return div(value, self)

Tensor.__truediv__, Tensor.__rtruediv__ = Tensor__truediv__, Tensor__rtruediv__
def Tensor___neg__(self : Tensor):
    """
    Unary negation (-t). Not supported yet.

    raises NotImplementedError
    """
    raise NotImplementedError('Tensor.__neg__ is not implemented')

# Python resolves unary minus through `__neg__` (two leading underscores).
# Without this binding, `-t` fails with a generic TypeError instead of the
# explicit NotImplementedError above.
Tensor.__neg__ = Tensor___neg__
# Legacy misspelled attribute (three leading underscores), kept only for
# backward compatibility with any code that referenced it directly.
Tensor.___neg__ = Tensor___neg__
# Indexing: `t[...]` reads through slice_, `t[...] = v` writes through slice_set.
Tensor.__getitem__, Tensor.__setitem__ = slice_, slice_set
def Tensor_as_shape(self : Tensor, shape) -> Tensor:
    """
    Return a TensorRef that views this tensor's buffer under `shape`.

    The total element count must match; TensorRef raises ValueError otherwise.
    """
    return TensorRef(t=self, shape=shape)

setattr(Tensor, 'as_shape', Tensor_as_shape)
# Expose the cast op as a method: t.cast(...) calls cast(t, ...).
setattr(Tensor, 'cast', cast)
def Tensor_copy(self : Tensor) -> Tensor:
    """Return a new Tensor constructed from this one via Tensor.from_value()."""
    duplicate = Tensor.from_value(self)
    return duplicate

setattr(Tensor, 'copy', Tensor_copy)
# Expose reduction / reshaping ops as methods, e.g. t.sum(...) -> reduce_sum(t, ...).
# Tuple targets are assigned left to right, in the same order as listed.
( Tensor.max,
  Tensor.mean,
  Tensor.min,
  Tensor.reshape,
  Tensor.sum,
  Tensor.std,
  Tensor.transpose ) = ( reduce_max,
                         reduce_mean,
                         reduce_min,
                         reshape,
                         reduce_sum,
                         reduce_std,
                         transpose )
class TensorRef(Tensor):
    """
    TensorRef is used to interpret an existing Tensor with a different shape.

    It carries its own shape (and the original dtype), but its seq_id,
    buffer and device are all taken from the wrapped tensor. It therefore
    refers to the same data.

    Create one with the Tensor.as_shape() method.
    """
    def __init__(self, t : Tensor, shape):
        """
        Arguments

            t       Tensor to reinterpret

            shape   new shape, anything accepted by AShape();
                    its total size must equal t.shape.size

        raises ValueError if the element counts differ
        """
        shape = AShape(shape)
        if t.shape.size != shape.size:
            raise ValueError(f'Cannot interpret shape {t.shape} as ref shape {shape}')
        super().__init__(shape, t.dtype, device=t.get_device())
        self._t = t  # the tensor whose storage this ref reinterprets

    # Forward storage-related accessors to the original tensor,
    # so this ref reads and writes the same buffer.
    def get_seq_id(self) -> int:
        """seq_id of the wrapped tensor."""
        return self._t.get_seq_id()

    def get_buffer(self) -> Buffer:
        """Buffer of the wrapped tensor."""
        return self._t.get_buffer()

    def get_device(self) -> Device:
        """Device of the wrapped tensor."""
        return self._t.get_device()
# Nothing is exported on `from ... import *`; the module's role here is
# patching Tensor (TensorRef is reachable through Tensor.as_shape()).
__all__ = list()