DeepFaceLive/xlib/avecl/_internal/op/remap.py
2021-09-30 18:21:30 +04:00

104 lines
3.7 KiB
Python

import numpy as np
from ..AShape import AShape
from ..backend import Kernel
from ..HKernel import HKernel
from ..info import BroadcastInfo
from ..SCacheton import SCacheton
from ..Tensor import Tensor
def remap (input_t : Tensor, coords_t : Tensor, dtype=None) -> Tensor:
"""
remap input_t in spatial axes using coords_t
arguments
input_t Tensor( ...,IH,IW )
coords_t Tensor( ...,OH,OW,D )
OH - output height
OW - output width
D is (2)[x,y] coords
dtype
...-head part of shapes will be broadcasted to each other
"""
op = SCacheton.get(_RemapOp, input_t.shape, input_t.dtype, coords_t.shape, coords_t.dtype, dtype)
output_t = Tensor( op.o_shape, op.o_dtype, device=input_t.get_device() )
input_t.get_device().run_kernel(op.forward_krn, output_t.get_buffer(), input_t.get_buffer(), coords_t.get_buffer())
return output_t
class _RemapOp():
def __init__(self, i_shape : AShape, i_dtype, c_shape : AShape, c_dtype, o_dtype):
if np.dtype(i_dtype).type == np.bool_:
raise ValueError('np.bool_ dtype of i_dtype is not supported.')
if np.dtype(c_dtype).type == np.bool_:
raise ValueError('np.bool_ dtype of c_dtype is not supported.')
if i_shape.ndim < 2:
raise ValueError('i_shape.ndim must be >= 2 (...,H,W)')
if c_shape.ndim < 3:
raise ValueError(f'Coords shape ndim must be >= 3(...,H,W,D)')
if c_shape[-1] != 2:
raise ValueError('Last coords dim must be == 2 (x,y)')
self.o_dtype = o_dtype = o_dtype if o_dtype is not None else i_dtype
if i_shape.ndim == 2 and c_shape.ndim == 3:
# nothing to broadcast
i_br_shape = i_shape
c_br_shape = c_shape
o_shape = c_shape[-3:-1]
else:
op = BroadcastInfo([ i_shape[:-2], c_shape[:-3] ])
i_br_shape = op.br_shapes[0] + i_shape[-2:]
c_br_shape = op.br_shapes[1] + c_shape[-3:]
o_shape = op.o_shape + c_shape[-3:-1]
self.o_shape = o_shape
self.forward_krn = Kernel(global_shape=(o_shape.size,), kernel_text=f"""
{HKernel.define_tensor('O', o_shape, o_dtype)}
{HKernel.define_tensor('I', i_br_shape, i_dtype)}
{HKernel.define_tensor('C', c_br_shape[:-1], c_dtype)}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I_PTR_NAME, __global const C_PTR_TYPE2* C_PTR_NAME)
{{
size_t gid = get_global_id(0);
{HKernel.decompose_idx_to_axes_idxs('gid', 'o', o_shape.ndim)}
C_TYPE2 c_value = C_GLOBAL_LOAD2(C_IDX_MOD({HKernel.axes_seq_enum('O', o_shape.ndim)}));
float cx01 = (float) c_value.x;
float cy01 = (float) c_value.y;
float cx0f = floor(cx01); int cx0 = (int)cx0f;
float cy0f = floor(cy01); int cy0 = (int)cy0f;
float cx1f = cx0f+1; int cx1 = (int)cx1f;
float cy1f = cy0f+1; int cy1 = (int)cy1f;
float p00 = I_GLOBAL_LOAD(I_IDX_MOD({HKernel.axes_seq_enum('O', o_shape.ndim-2, suffix='cy0,cx0')}));
float p01 = I_GLOBAL_LOAD(I_IDX_MOD({HKernel.axes_seq_enum('O', o_shape.ndim-2, suffix='cy0,cx1')}));
float p10 = I_GLOBAL_LOAD(I_IDX_MOD({HKernel.axes_seq_enum('O', o_shape.ndim-2, suffix='cy1,cx0')}));
float p11 = I_GLOBAL_LOAD(I_IDX_MOD({HKernel.axes_seq_enum('O', o_shape.ndim-2, suffix='cy1,cx1')}));
p00 *= (cx1f - cx01)*(cy1f - cy01)*(cy0 >= 0 & cy0 < Im2 & cx0 >= 0 & cx0 < Im1);
p01 *= (cx01 - cx0f)*(cy1f - cy01)*(cy0 >= 0 & cy0 < Im2 & cx1 >= 0 & cx1 < Im1);
p10 *= (cx1f - cx01)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx0 >= 0 & cx0 < Im1);
p11 *= (cx01 - cx0f)*(cy01 - cy0f)*(cy1 >= 0 & cy1 < Im2 & cx1 >= 0 & cx1 < Im1);
O_GLOBAL_STORE(gid, p00 + p01 + p10 + p11);
}}
""")