DeepFaceLive/xlib/avecl/_internal/info/TileInfo.py
2021-09-30 18:21:30 +04:00

52 lines
1.3 KiB
Python

import itertools

import numpy as np

from ..AShape import AShape, AShape
class TileInfo:
    """
    Tile info.

    Describes the result of repeating `shape` along every axis
    `tiles[axis]` times.

    arguments

        shape   AShape

        tiles   Iterable of ints
                number of repetitions per axis, one entry per axis of shape

    errors during the construction:

        ValueError  the number of tiles does not match shape.ndim
                    (AShape construction may presumably raise as well
                    for invalid resulting dims -- confirm against AShape)

    result:

        .o_shape        AShape
                        every dim of shape multiplied by its tile count

        .axes_slices    tuple with one entry per tile; every entry is a
                        tuple of slice(), one slice per axis, selecting a
                        region of o_shape that has the size of the
                        original shape.
                        Tiles are ordered with the last axis varying
                        fastest.
    """

    __slots__ = ['o_shape', 'axes_slices']

    def __init__(self, shape, tiles):
        # Materialize once: the documented contract is "Iterable of ints",
        # while the code below needs len() and more than one pass.
        tiles = tuple(tiles)
        if len(tiles) != shape.ndim:
            raise ValueError(f'tiles should match shape.ndim {shape.ndim}')

        dims = tuple(shape)
        self.o_shape = AShape(dim * tile for dim, tile in zip(dims, tiles))

        # itertools.product walks the per-axis tile indices with the last
        # axis changing fastest, the same order a manual carry counter over
        # the axes produces. It works on plain Python ints, so no
        # fixed-width integer or float value is involved in the tile count.
        tile_offsets = itertools.product(*(range(tile) for tile in tiles))

        # For every tile: along each axis take the tile_index-th chunk of
        # length axis_size.
        self.axes_slices = tuple(
            tuple(slice(axis_size * tile_index, axis_size * (tile_index + 1))
                  for axis_size, tile_index in zip(dims, offsets))
            for offsets in tile_offsets
        )