mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-25 15:31:13 -08:00
169 lines
5.0 KiB
Python
169 lines
5.0 KiB
Python
try:
    # The ABC aliases were removed from `collections` in Python 3.10.
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
|
|
|
|
class AAxes(Iterable):
    """
    Ordered collection of unique axes.

    When constructed from None it represents "none-axes", i.e. all-axes:
    `axes` and `ndim` are None, and iteration, indexing and inversed()
    raise an Exception. Check for that case with .is_none_axes().

    AAxes supports:

        A+B : concat A_axes with B_axes

        A-B : removes B_axes from A_axes
    """
    __slots__ = ['axes','ndim','_inversed']

    def __init__(self, axes, shape_ndim=None):
        """
        Constructs AAxes from user argument

        arguments

            axes        AAxes
                        Int
                        Iterable of ints
                        None

            shape_ndim(None)    provide shape_ndim if axes contain negative values

        raises

            ValueError  if an axis is None,
                        if an axis is negative and shape_ndim is not provided,
                        if an axis is still negative after adding shape_ndim,
                        if axes contain duplicate values.
        """
        if isinstance(axes, AAxes):
            # Copy constructor: share the (immutable) tuple and the cached inverse.
            self.axes = axes.axes
            self.ndim = axes.ndim
            self._inversed = axes._inversed
            return

        if axes is None:
            # none-axes (all-axes)
            self.axes = None
            self.ndim = None
            self._inversed = None
            return

        if not isinstance(axes, Iterable):
            # Single value -> one-element sequence
            axes = (axes,)

        valid_axes = []
        for x in axes:
            if x is None:
                raise ValueError(f'Incorrent value {x} in axes {axes}')
            x = int(x)

            if x < 0:
                if shape_ndim is None:
                    raise ValueError(f'Incorrent value {x} in axes {axes}, or provide shape_ndim')
                normalized = shape_ndim + x
                if normalized < 0:
                    # e.g. axis -5 with shape_ndim 3 : there is no such axis,
                    # do not silently keep a negative value.
                    raise ValueError(f'Axis {x} is out of range for shape_ndim {shape_ndim} in axes {axes}')
                x = normalized

            if x in valid_axes:
                raise ValueError('Axes must contain unique values.')
            valid_axes.append(x)

        self.axes = tuple(valid_axes)
        self.ndim = len(self.axes)
        # Inverse order is computed lazily by inversed()
        self._inversed = None

    def is_none_axes(self) -> bool:
        """
        returns True if AAxes is constructed with (None) argument, i.e. all-axes
        """
        return self.axes is None

    def sorted(self) -> 'AAxes':
        """
        returns sorted AAxes
        """
        return AAxes(sorted(self.axes))

    def swapped_axes(self, axis_a, axis_b) -> 'AAxes':
        """
        returns new AAxes where the values at positions axis_a and axis_b are swapped

        axis_a, axis_b  are positions inside this AAxes, negative values count from the end
        """
        x = list(self.axes)
        if axis_a < 0:
            axis_a = len(x) + axis_a
        if axis_b < 0:
            axis_b = len(x) + axis_b

        x[axis_b], x[axis_a] = x[axis_a], x[axis_b]

        return AAxes( tuple(x) )

    def inversed(self) -> 'AAxes':
        """
        Returns inversed axes order

        Example:

         for (0,2,3,1) returns (0,3,1,2)

        The result is cached after the first call.

        raises Exception for none-axes, or if axes are not a permutation of range(ndim)
        """
        if self.is_none_axes():
            raise Exception(f'none-axes does not support inversed(). Handle none-axes by calling .is_none_axes()')

        if self._inversed is None:
            # axis value -> its position in self.axes
            x = { axis:i for i,axis in enumerate(self.axes) }
            t = []
            for i in range(self.ndim):
                axis = x.get(i, None)
                if axis is None:
                    raise Exception(f'axes {self.axes} are inconsistent to do inverse order.')
                t.append(axis)
            self._inversed = AAxes(t)

        return self._inversed

    # Hash matches the hash of the underlying tuple, consistent with __eq__ against tuples.
    def __hash__(self): return self.axes.__hash__()

    def __eq__(self, other):
        """AAxes is equal to another AAxes or to any Iterable yielding the same axes in the same order."""
        if isinstance(other, AAxes):
            return self.axes == other.axes
        elif isinstance(other, Iterable):
            return self.axes == tuple(other)
        return False

    def __iter__(self):
        if self.is_none_axes():
            raise Exception(f'none-axes does not support iteration. Handle none-axes by calling .is_none_axes()')
        return self.axes.__iter__()

    def __len__(self): return self.ndim

    def __getitem__(self,key):
        """Integer key returns the axis value, slice key returns a new AAxes."""
        if self.is_none_axes():
            raise Exception(f'none-axes does not support indexing. Handle none-axes by calling .is_none_axes()')

        elif isinstance(key, slice):
            return AAxes(self.axes[key])

        return self.axes[key]

    def __radd__(self, o):
        """o + self : axes of Iterable o followed by axes of self"""
        if isinstance(o, Iterable):
            return AAxes( tuple(o) + self.axes)
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes append')

    def __add__(self, o):
        """self + o : axes of self followed by axes of Iterable o"""
        if isinstance(o, Iterable):
            return AAxes( self.axes + tuple(o) )
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes append')

    def __rsub__(self, o):
        """o - self : axes of Iterable o which are not present in self, order of o is kept"""
        if isinstance(o, Iterable):
            return AAxes([axis for axis in o if axis not in self.axes])
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes substraction')

    def __sub__(self, o):
        """self - o : axes of self which are not present in Iterable o, order of self is kept"""
        if isinstance(o, Iterable):
            # Materialize once: o may be a one-shot iterator
            o_axes = tuple(o)
            return AAxes([axis for axis in self.axes if axis not in o_axes])
        else:
            raise ValueError(f'unable to use type {o.__class__} in AAxes substraction')

    def __str__(self):
        if self.is_none_axes():
            return '(None)'
        return str(self.axes)

    def __repr__(self): return 'AAxes' + self.__str__()
|
|
|
|
# Public API of this module.
__all__ = ['AAxes']