DeepFaceLive/xlib/avecl/_internal/info/ConcatInfo.py
2021-09-30 18:21:30 +04:00

63 lines
1.7 KiB
Python

from ..AShape import AShape
class ConcatInfo:
    """
    Shape bookkeeping for concatenating several shapes along one axis.

    Validates the inputs and precomputes the output shape together with the
    size and the starting offset of every input along the concat axis.
    """

    __slots__ = ['o_shape', 'axis', 'axis_sizes', 'axis_offsets']

    def __init__(self, shapes, axis):
        """
        Concat info

        arguments

            shapes      Iterable of shapes. Every shape must be iterable and
                        indexable by Int; the first one must expose .ndim

            axis        Int     axis to concatenate along.
                                A negative value is counted from the end,
                                relative to the first shape's ndim.

        errors during the construction:

            ValueError  shapes is empty,
                        axis is out of range for the first shape,
                        or the shapes differ in rank or in any dimension
                        other than axis

        result

            .o_shape        AShape

            .axis           Int         fixed (non-negative) axis argument

            .axis_sizes     List[Int]   axis sizes for every shape in shapes

            .axis_offsets   Tuple[Int]  axis offset in o_shape
                                        for every shape in shapes
        """
        shapes = tuple(shapes)
        if not shapes:
            raise ValueError('shapes is empty')

        # The first shape is the reference for rank and for every non-concat dim.
        first = shapes[0]
        ndim = first.ndim

        if axis < 0:
            axis = ndim + axis
        if axis < 0 or axis >= ndim:
            raise ValueError(f'Wrong axis {axis}')

        first_dims = tuple(first)
        req_rest = first_dims[:axis] + first_dims[axis+1:]

        for shape in shapes[1:]:
            dims = tuple(shape)
            # The rank must be compared explicitly: a shorter shape can have the
            # same "all dims except axis" tuple as the reference (e.g. (2, 3)
            # vs (2,) with axis=1) and would otherwise only fail later on the
            # out-of-range shape[axis] lookup instead of raising ValueError here.
            if len(dims) != len(first_dims) or dims[:axis] + dims[axis+1:] != req_rest:
                raise ValueError(f'All shapes must match shape {tuple(a if i != axis else "*" for i, a in enumerate(first))}')

        axis_sizes = [shape[axis] for shape in shapes]

        # Running sum: each input starts where the previous one ended.
        axis_offsets = []
        axis_offset = 0
        for axis_size in axis_sizes:
            axis_offsets.append(axis_offset)
            axis_offset += axis_size

        self.o_shape = AShape(first_dims[:axis] + (sum(axis_sizes),) + first_dims[axis+1:])
        self.axis = axis
        self.axis_sizes = axis_sizes
        self.axis_offsets = tuple(axis_offsets)