DeepFaceLive/xlib/avecl/_internal/op/reshape.py
2021-09-30 18:21:30 +04:00

27 lines
704 B
Python

from typing import Iterable
from ..Tensor import Tensor
from ..SCacheton import SCacheton
from ..info import ReshapeInfo
def reshape(input_t : Tensor, new_shape : Iterable, copy=True) -> Tensor:
"""
reshape operator
arguments
new_shape Iterable of ints
copy(True) if True, produces new Tensor
otherwise result tensor points to the same memory
Produces reference Tensor with new shape.
"""
info = SCacheton.get(ReshapeInfo, input_t.shape, tuple(int(x) for x in new_shape) )
if copy:
return Tensor(info.o_shape, input_t.dtype, device=input_t.get_device()).set(input_t)
return input_t.as_shape( info.o_shape )