DeepFaceLive/xlib/avecl/_internal/AShape.py
2021-10-20 18:02:50 +04:00

157 lines
4.5 KiB
Python

# The ABC aliases were removed from `collections` in Python 3.10;
# `collections.abc` is the supported location.
from collections.abc import Iterable
from typing import List, Tuple

from .AAxes import AAxes
class AShape(Iterable):
    """
    Validated shape: a non-empty tuple of ints where every dim is >= 1.

    attributes

        shape   Tuple[int, ...]   the dims

        size    int               product of all dims

        ndim    int               number of dims, always >= 1
    """
    __slots__ = ['shape','size','ndim']

    def __init__(self, shape):
        """
        Constructs valid shape from user argument

        arguments

            shape       AShape      its shape/size/ndim are reused as-is

                        int,float   treated as one-dim shape (int(shape),)

                        Iterable    every element is converted with int()
                                    and must be >= 1

        AShape cannot be scalar shape, thus minimal AShape is (1,) :
        an empty Iterable produces shape (1,) with size 1.

        can raise ValueError during the construction
        """
        if isinstance(shape, AShape):
            self.shape = shape.shape
            self.size = shape.size
            self.ndim = shape.ndim
            return

        if isinstance(shape, (int, float)):
            shape = (int(shape),)

        if not isinstance(shape, Iterable):
            raise ValueError('Invalid type to create AShape')

        size = 1
        valid_shape = []
        for x in shape:
            if x is None:
                raise ValueError(f'Incorrent value {x} in shape {shape}')
            x = int(x)
            if x < 1:
                raise ValueError(f'Incorrent value {x} in shape {shape}')
            valid_shape.append(x)
            size *= x # Faster than np.prod()

        if len(valid_shape) == 0:
            # Force (1,) shape for scalar shape
            valid_shape.append(1)

        self.shape = tuple(valid_shape)
        self.ndim = len(self.shape)
        self.size = size

    def copy(self) -> 'AShape':
        """
        Returns new AShape with the same shape, size and ndim
        """
        return AShape(self)

    def as_list(self) -> List[int]:
        """
        Returns dims as a new list
        """
        return list(self.shape)

    def check_axis(self, axis : int) -> int:
        """
        Check axis and returns normalized axis value

        negative axis counts from the end, -1 is the last axis

        can raise ValueError
        """
        normalized = axis + self.ndim if axis < 0 else axis
        if normalized < 0 or normalized >= self.ndim:
            # Report the value the caller passed, not the shifted one
            raise ValueError(f'axis {axis} out of bound of ndim {self.ndim}')
        return normalized

    def axes_arange(self) -> 'AAxes':
        """
        Returns tuple of axes arange.

        Example (0,1,2) for ndim 3
        """
        return AAxes(range(self.ndim))

    def replaced_axes(self, axes, dims) -> 'AShape':
        """
        returns new AShape where axes replaced with new dims

        arguments

            axes    Iterable of int     negative values count from the end

            dims    Iterable            new dim for the axis at the same position

        axes and dims are paired with zip(), so extra elements of the
        longer one are ignored.

        can raise ValueError on out of range axis, or if the resulting
        shape is rejected by the AShape constructor
        """
        new_shape = list(self.shape)
        ndim = self.ndim
        for axis, dim in zip(axes, dims):
            index = axis + ndim if axis < 0 else axis
            if index < 0 or index >= ndim:
                # Report the value the caller passed, not the shifted one
                raise ValueError(f'invalid axis value {axis}')
            new_shape[index] = dim
        return AShape(new_shape)

    def split(self, axis) -> Tuple['AShape', 'AShape']:
        """
        split AShape at specified axis

        returns two AShape before+exclusive and inclusive+after

        The first part is empty when axis resolves to 0; since an empty
        shape is forced to (1,), the first AShape is (1,) in that case.

        can raise ValueError on out of range axis
        """
        index = axis + self.ndim if axis < 0 else axis
        if index < 0 or index >= self.ndim:
            # Report the value the caller passed, not the shifted one
            raise ValueError(f'invalid axis value {axis}')
        return self[:index], self[index:]

    def transpose_by_axes(self, axes) -> 'AShape':
        """
        Same as AShape[axes]

        Returns AShape transposed by axes.

            axes        AAxes
                        Iterable(list,tuple,set,generator)
        """
        return AShape(self.shape[axis] for axis in AAxes(axes) )

    def __hash__(self):
        # Hash of the dims tuple
        return self.shape.__hash__()

    def __eq__(self, other):
        """
        Equal to another AShape with the same dims, or to any Iterable
        whose elements form the same tuple. Anything else is not equal.
        """
        if isinstance(other, AShape):
            return self.shape == other.shape
        elif isinstance(other, Iterable):
            return self.shape == tuple(other)
        return False

    def __iter__(self):
        return self.shape.__iter__()

    def __len__(self):
        return len(self.shape)

    def __getitem__(self,key):
        """
        key     Iterable    returns AShape transposed by these axes
                            (self is returned for AAxes whose is_none_axes() is true)

                slice       returns AShape of sliced dims,
                            empty slice produces (1,)

                other       returns self.shape[key], i.e. a single dim for int key
        """
        if isinstance(key, Iterable):
            if isinstance(key, AAxes):
                if key.is_none_axes():
                    return self
            return self.transpose_by_axes(key)
        elif isinstance(key, slice):
            return AShape(self.shape[key])

        return self.shape[key]

    def __radd__(self, o):
        """
        Iterable + AShape : returns new AShape with dims of o prepended

        raises ValueError if o is not Iterable
        """
        if isinstance(o, Iterable):
            return AShape( tuple(o) + self.shape)
        else:
            raise ValueError(f'unable to use type {o.__class__} in AShape append')

    def __add__(self, o):
        """
        AShape + Iterable : returns new AShape with dims of o appended

        raises ValueError if o is not Iterable
        """
        if isinstance(o, Iterable):
            return AShape( self.shape + tuple(o) )
        else:
            raise ValueError(f'unable to use type {o.__class__} in AShape append')

    def __str__(self):
        return str(self.shape)

    def __repr__(self):
        return 'AShape' + self.__str__()
# Names exported by a star-import of this module.
__all__ = ['AShape']