# Mirror of https://github.com/iperov/DeepFaceLive.git
# (synced 2024-12-25 07:21:13 -08:00; 17 lines, 401 B, Python)
import numpy as np
|
|
|
|
def get_NHWC_shape(img: np.ndarray) -> tuple:
    """
    Return the shape of `img` as an (N, H, W, C) tuple, where missing dims are 1.

    The dimensions of `img` are interpreted according to `img.ndim`:

        2 -> (H, W)         N and C are set to 1
        3 -> (H, W, C)      N is set to 1
        4 -> (N, H, W, C)   returned unchanged

    Args:
        img: array whose `ndim` is 2, 3 or 4.

    Returns:
        The tuple (N, H, W, C).

    Raises:
        ValueError: if `img.ndim` is not 2, 3 or 4.
    """
    ndim = img.ndim
    if ndim not in (2, 3, 4):
        raise ValueError(f'img.ndim must be 2,3,4, not {ndim}.')

    if ndim == 2:
        # Only height and width are present; pad batch and channel with 1.
        N, (H, W), C = 1, img.shape, 1
    elif ndim == 3:
        # Single image with channels; pad the batch dimension with 1.
        N, (H, W, C) = 1, img.shape
    else:
        # Already a 4-D shape; pass it through as is.
        N, H, W, C = img.shape
    return N, H, W, C