mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-26 07:51:13 -08:00
44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
from ..AShape import AShape
|
|
|
|
class ReshapeInfo:
    """
    Reshape info.

    Computes the output shape obtained by reshaping `shape` to `target_shape`.

    can raise ValueError, TypeError during the construction

    arguments

        shape           AShape

        target_shape    Iterable of ints
                        can be any len and contains at most one '-1';
                        the '-1' dimension is inferred so that the total
                        number of elements is preserved.
                        Without '-1' the product of target_shape must equal
                        shape.size.

    Example

        shape           (2, 512, 8, 8, 64)
        target_shape    (2, 512, -1)
        o_shape         (2, 512, 4096)

    attributes

        o_shape         AShape, the resulting shape
    """

    __slots__ = ['o_shape']

    def __init__(self, shape, target_shape):
        size = shape.size

        o_shape = []
        known_size = 1   # product of all explicitly specified dimensions
        unk_axis = None  # index of the single '-1' dimension, if any

        for t_size in target_shape:
            t_size = int(t_size)  # may raise TypeError / ValueError

            if t_size == -1:
                if unk_axis is not None:
                    raise ValueError('Can specify only one unknown dimension.')
                unk_axis = len(o_shape)
            elif t_size < 0:
                # Only -1 has a special meaning; other negatives are invalid.
                raise ValueError(f'Cannot reshape {shape} to {target_shape}.')
            else:
                known_size *= t_size

            o_shape.append(t_size)

        if unk_axis is None:
            # A reshape must preserve the total number of elements.
            if known_size != size:
                raise ValueError(f'Cannot reshape {shape} to {target_shape}.')
        else:
            # known_size == 0 would make the '-1' dimension ambiguous.
            # Integer arithmetic avoids float precision loss on large sizes.
            if known_size == 0 or size % known_size != 0:
                raise ValueError(f'Cannot reshape {shape} to {target_shape}.')
            o_shape[unk_axis] = size // known_size

        self.o_shape = AShape(o_shape)