DeepFaceLive/xlib/avecl/_internal/AAxes.py
2021-09-30 18:21:30 +04:00

169 lines
5.0 KiB
Python

try:
    from collections.abc import Iterable
except ImportError:  # very old Pythons only; the ABC aliases were removed from `collections` in 3.10
    from collections import Iterable
class AAxes(Iterable):
    """
    Tuple-like, hashable container of axis indices.

    An AAxes is either
      - a tuple of unique non-negative ints (negative inputs are normalized
        with shape_ndim at construction time), or
      - the special "none-axes" value, constructed from None, meaning all axes.

    none-axes has no concrete axis list, so iteration, indexing, sorting,
    swapping, inversion and +/- raise an Exception for it.
    Check .is_none_axes() first.

    Supported operators:
        A+B : concat A_axes with B_axes
        A-B : removes B_axes from A_axes
    """
    __slots__ = ['axes', 'ndim', '_inversed']

    def __init__(self, axes, shape_ndim=None):
        """
        Constructs AAxes from user argument

        arguments

            axes        AAxes       copied (including any cached inverse)
                        Int         a single axis
                        Iterable of ints
                        None        none-axes, i.e. all axes

            shape_ndim(None)    provide shape_ndim if axes contain negative values;
                                a negative axis a is converted to shape_ndim + a

        raises ValueError if an axis is None, if a negative axis is given
        without shape_ndim or is still negative after adding shape_ndim,
        or if the axes are not unique.
        """
        if isinstance(axes, AAxes):
            self.axes = axes.axes
            self.ndim = axes.ndim
            self._inversed = axes._inversed
        elif axes is None:
            self.axes = None
            self.ndim = None
            self._inversed = None
        else:
            if not isinstance(axes, Iterable):
                # a single scalar axis
                axes = (axes,)

            valid_axes = []
            seen = set()
            for x in axes:
                if x is None:
                    raise ValueError(f'Incorrect value {x} in axes {axes}')
                axis = int(x)
                if axis < 0:
                    if shape_ndim is None:
                        raise ValueError(f'Incorrect value {axis} in axes {axes}, or provide shape_ndim')
                    normalized = shape_ndim + axis
                    if normalized < 0:
                        # e.g. -5 with shape_ndim=3: lies before the first dimension
                        raise ValueError(f'Incorrect value {axis} in axes {axes}: out of range for shape_ndim {shape_ndim}')
                    axis = normalized
                if axis in seen:
                    raise ValueError('Axes must contain unique values.')
                seen.add(axis)
                valid_axes.append(axis)

            self.axes = tuple(valid_axes)
            self.ndim = len(self.axes)
            # lazily computed by inversed()
            self._inversed = None

    def is_none_axes(self):
        """
        returns True if AAxes is constructed with (None) argument, i.e. all-axes
        """
        return self.axes is None

    def _raise_if_none_axes(self, operation):
        # Shared guard for every operation that needs a concrete axis list.
        if self.is_none_axes():
            raise Exception(f'none-axes does not support {operation}. Handle none-axes by calling .is_none_axes()')

    def sorted(self) -> 'AAxes':
        """
        returns sorted AAxes

        raises Exception for none-axes
        """
        self._raise_if_none_axes('sorted()')
        return AAxes(sorted(self.axes))

    def swapped_axes(self, axis_a, axis_b) -> 'AAxes':
        """
        returns a new AAxes with the entries at positions axis_a and axis_b swapped.
        Negative positions count from the end.

        raises Exception for none-axes, IndexError for out-of-range positions
        """
        self._raise_if_none_axes('swapped_axes()')
        x = list(self.axes)
        if axis_a < 0:
            axis_a = len(x) + axis_a
        if axis_b < 0:
            axis_b = len(x) + axis_b
        x[axis_b], x[axis_a] = x[axis_a], x[axis_b]
        return AAxes(tuple(x))

    def inversed(self) -> 'AAxes':
        """
        Returns inversed axes order, i.e. the permutation that undoes this one.

        Example:

            for (0,2,3,1) returns (0,3,1,2)

        raises Exception for none-axes, or if the axes are not a permutation
        of range(ndim). The result is cached on the instance.
        """
        self._raise_if_none_axes('inversed()')
        if self._inversed is None:
            # axis value -> its position
            position_of = {axis: i for i, axis in enumerate(self.axes)}
            t = []
            for i in range(self.ndim):
                pos = position_of.get(i, None)
                if pos is None:
                    raise Exception(f'axes {self.axes} are inconsistent to do inverse order.')
                t.append(pos)
            self._inversed = AAxes(t)
        return self._inversed

    def __hash__(self):
        # consistent with __eq__ against plain tuples: hash(AAxes(t)) == hash(tuple(t))
        return self.axes.__hash__()

    def __eq__(self, other):
        if isinstance(other, AAxes):
            return self.axes == other.axes
        elif isinstance(other, Iterable):
            return self.axes == tuple(other)
        return False

    def __iter__(self):
        self._raise_if_none_axes('iteration')
        return self.axes.__iter__()

    def __len__(self):
        # NOTE: for none-axes ndim is None, so len() raises TypeError
        return self.ndim

    def __getitem__(self, key):
        self._raise_if_none_axes('indexing')
        if isinstance(key, slice):
            return AAxes(self.axes[key])
        return self.axes[key]

    def __radd__(self, o):
        """o + self : concatenation with self's axes appended after o."""
        if isinstance(o, Iterable):
            self._raise_if_none_axes('concatenation')
            return AAxes(tuple(o) + self.axes)
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes append')

    def __add__(self, o):
        """self + o : concatenation with o's axes appended after self."""
        if isinstance(o, Iterable):
            self._raise_if_none_axes('concatenation')
            return AAxes(self.axes + tuple(o))
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes append')

    def __rsub__(self, o):
        """o - self : axes of o (in order) that are not in self."""
        if isinstance(o, Iterable):
            self._raise_if_none_axes('subtraction')
            return AAxes([axis for axis in o if axis not in self.axes])
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes subtraction')

    def __sub__(self, o):
        """self - o : axes of self (in order) that are not in o."""
        if isinstance(o, Iterable):
            self._raise_if_none_axes('subtraction')
            o_axes = tuple(o)
            return AAxes([axis for axis in self.axes if axis not in o_axes])
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes subtraction')

    def __str__(self):
        if self.is_none_axes():
            return '(None)'
        return str(self.axes)

    def __repr__(self):
        return 'AAxes' + self.__str__()


__all__ = ['AAxes']