DeepFaceLive/xlib/avecl/_internal/op/slice_set.py
2021-09-30 18:21:30 +04:00

74 lines
2.5 KiB
Python

import numpy as np
from ..AShape import AShape
from ..backend import Kernel
from ..HKernel import HKernel
from ..HType import HType
from ..info import BroadcastInfo, SliceInfo
from ..SCacheton import SCacheton
from ..Tensor import Tensor
def slice_set(input_t : Tensor, slices, value) -> Tensor:
    """
    Write `value` into the region of `input_t` selected by `slices`.

    The kernel writes directly into the buffer of `input_t`, so the tensor is
    modified in place and the same object is returned.

    arguments:

        input_t     tensor that receives the values

        slices      argument received from class.__getitem__(slices)

        value       one of:
                      * a scalar (anything HType.is_scalar_type accepts);
                        it is embedded into the kernel source
                      * a Tensor; its buffer is passed to the kernel
                      * any other object; it is first converted with
                        Tensor.from_value using the dtype and device
                        of input_t

    returns input_t
    """
    if HType.is_scalar_type(value):
        v_shape = None
        v_dtype = None
        v_scalar = value
    else:
        # Only non-Tensor values need a conversion; shape/dtype must be
        # collected for already-built Tensors as well.
        if not isinstance(value, Tensor):
            value = Tensor.from_value(value, dtype=input_t.dtype, device=input_t.get_device())
        v_shape = value.shape
        v_dtype = value.dtype
        v_scalar = None

    # The op (and its compiled kernel) is cached per unique combination of arguments.
    op = SCacheton.get(_SliceSetOp, input_t.shape, input_t.dtype, v_shape, v_dtype, v_scalar, HType.hashable_slices(slices))

    device = input_t.get_device()
    if v_scalar is not None:
        # Scalar is part of the kernel text, only the destination buffer is needed.
        device.run_kernel(op.forward_krn, input_t.get_buffer())
    else:
        device.run_kernel(op.forward_krn, input_t.get_buffer(), value.get_buffer())

    return input_t
class _SliceSetOp:
    """
    Builds the kernel used by slice_set for one combination of
    input shape/dtype, value shape/dtype (or scalar) and slices.

    The resulting kernel is stored in `self.forward_krn`. It is launched with
    one work item per element of the input tensor; a work item stores a value
    only if its per-axis indices satisfy the slice conditions.
    """

    def __init__(self, i_shape : AShape, i_dtype : np.dtype, v_shape : AShape, v_dtype : np.dtype, v_scalar, slices):
        """
        arguments:

            i_shape     shape of the tensor being written to
            i_dtype     dtype of the tensor being written to
            v_shape     shape of the value tensor, ignored when v_scalar is given
            v_dtype     dtype of the value tensor, ignored when v_scalar is given
            v_scalar    scalar to store, or None when the value is a tensor
            slices      hashable slices describing the target region

        raises ValueError if the value tensor has more dimensions than the input.
        SliceInfo / BroadcastInfo presumably raise on invalid slices or
        non-broadcastable shapes -- confirm against their implementation.
        """
        slice_info = SliceInfo(i_shape, slices)
        axes_abs_bes = slice_info.axes_abs_bes

        if v_scalar is None:
            if v_shape.ndim > i_shape.ndim:
                raise ValueError(f'v_shape.ndim {v_shape.ndim} cannot be larger than i_shape.ndim {i_shape.ndim}')

            # Check that v_shape can broadcast with slice_info.shape
            br_info = BroadcastInfo([slice_info.o_shape_kd, v_shape])
            # Shape used to declare the value tensor 'I' inside the kernel.
            v_br_shape = br_info.br_shapes[1]

            i_define = HKernel.define_tensor('I', v_br_shape, v_dtype )
            i_arg = ', __global const I_PTR_TYPE* I_PTR_NAME'
            store_value = f"I_GLOBAL_LOAD( I_IDX_MOD({HKernel.axes_seq_enum('O', i_shape.ndim)}) ) "
        else:
            # Scalar value: no second tensor, the scalar is written into the kernel text.
            i_define = ''
            i_arg = ''
            store_value = f" (O_TYPE)({v_scalar})"

        # Per-axis range test: a (begin, end) window, or an exact index when step == 0.
        in_range = [f'o{i} >= {b} & o{i} < {e}' if s != 0 else f'o{i} == {b}'
                    for i, (b, e, s) in enumerate(axes_abs_bes)]
        # Per-axis step test, emitted only for steps larger than 1.
        # NOTE(review): the test is `o % step == 0`, i.e. anchored at index 0 and not
        # at the slice begin -- confirm this is intended for begins that are not
        # multiples of step.
        on_step = [f'((o{i} % {s}) == 0)'
                   for i, (_, _, s) in enumerate(axes_abs_bes) if s > 1]
        cond = ' & '.join(in_range + on_step)

        # 'O' is declared from i_shape, so the flat index is decomposed over
        # i_shape.ndim axes, the same rank used to enumerate the axes for the load.
        self.forward_krn = Kernel(global_shape=(i_shape.size,), kernel_text=f"""
{HKernel.define_tensor('O', i_shape, i_dtype )}
{i_define}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME
{i_arg})
{{
size_t gid = get_global_id(0);
{HKernel.decompose_idx_to_axes_idxs('gid', 'O', i_shape.ndim)}
if ({cond} )
O_GLOBAL_STORE(gid, {store_value} );
}}
""")