mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-26 07:51:13 -08:00
52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
import numpy as np
|
|
from ..AShape import AShape, AShape
|
|
|
|
class TileInfo:
    """
    Tile info.

    Precomputes the output shape and the per-tile slices needed to tile
    an array of `shape` by `tiles` repetitions along each axis.

    arguments

     shape      AShape

     tiles      Iterable of ints, one repetition count per axis of `shape`

    errors during the construction:

        ValueError  if the number of tiles does not equal shape.ndim

    result:

     .o_shape       AShape where axis i has size shape[i]*tiles[i]

     .axes_slices   tuple with one entry per tile; each entry is a tuple
                    of slice() objects (one per axis) that selects the
                    region of original shape occupied by that tile
                    inside o_shape.
                    Tiles are ordered with the last axis varying fastest.
    """

    __slots__ = ['o_shape', 'axes_slices']

    def __init__(self, shape, tiles):
        # Materialize once so any iterable (e.g. a generator) supports len()
        # and is not consumed twice below.
        tiles = tuple(tiles)

        if len(tiles) != shape.ndim:
            raise ValueError(f'tiles should match shape.ndim {shape.ndim}')

        self.o_shape = AShape(dim*tile for dim, tile in zip(shape, tiles))

        dims = tuple(shape)

        # np.ndindex enumerates every tile index in C order (last axis
        # fastest), the same order as a mixed-radix counter. Unlike
        # range(np.prod(tiles)), it also handles ndim == 0: there
        # np.prod(()) is the float 1.0 (which range() rejects), whereas
        # np.ndindex() yields a single empty index -> one tile, no slices.
        self.axes_slices = tuple(
            tuple(slice(dim*idx, dim*(idx+1)) for dim, idx in zip(dims, tile_idx))
            for tile_idx in np.ndindex(*tiles) )