mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-26 07:51:13 -08:00
74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
import numpy as np
|
|
|
|
from ..AShape import AShape
|
|
from ..backend import Kernel
|
|
from ..HKernel import HKernel
|
|
from ..HType import HType
|
|
from ..info import BroadcastInfo, SliceInfo
|
|
from ..SCacheton import SCacheton
|
|
from ..Tensor import Tensor
|
|
|
|
|
|
def slice_set(input_t : Tensor, slices, value) -> Tensor:
    """
    Write `value` into the region of `input_t` selected by `slices`, in place.

    arguments:

        input_t     input tensor; its buffer is written in place

        slices      argument received from class.__getitem__(slices)

        value       a scalar (as judged by HType.is_scalar_type), a Tensor,
                    or any value accepted by Tensor.from_value; such values
                    are converted with input_t's dtype on input_t's device

    returns input_t itself (not a copy).

    Remark.

        A non-scalar value must have no more axes than input_t and must
        broadcast against the selected region (checked by _SliceSetOp).
    """
    if HType.is_scalar_type(value):
        # A scalar is emitted into the kernel source; no value buffer is passed.
        v_shape = None
        v_dtype = None
        v_scalar = value
    else:
        # Both Tensor values and convertible values reach this branch, so the
        # shape/dtype below are always bound for non-scalar input.
        if not isinstance(value, Tensor):
            value = Tensor.from_value(value, dtype=input_t.dtype, device=input_t.get_device())
        v_shape = value.shape
        v_dtype = value.dtype
        v_scalar = None

    op = SCacheton.get(_SliceSetOp, input_t.shape, input_t.dtype, v_shape, v_dtype, v_scalar, HType.hashable_slices(slices) )

    device = input_t.get_device()
    if v_scalar is not None:
        device.run_kernel(op.forward_krn, input_t.get_buffer() )
    else:
        device.run_kernel(op.forward_krn, input_t.get_buffer(), value.get_buffer() )

    return input_t
|
|
|
|
class _SliceSetOp:
    """
    Builds the OpenCL kernel used by slice_set().

    The kernel is launched with one work item per element of the input tensor
    (global_shape = i_shape.size). Each work item splits its flat index into
    per-axis coordinates o0, o1, ... and, when those coordinates fall inside the
    slice on every axis, stores either the scalar `v_scalar` or the matching
    element of the value tensor I.

    slice_set() obtains instances through SCacheton.get(...) with the
    constructor arguments; presumably one kernel is compiled per distinct
    argument tuple, which includes the scalar value itself.
    """

    def __init__(self, i_shape : AShape, i_dtype : np.dtype, v_shape : AShape, v_dtype : np.dtype, v_scalar, slices):
        """
        arguments

            i_shape, i_dtype    shape and dtype of the tensor being written

            v_shape, v_dtype    shape and dtype of the value tensor,
                                None when a scalar is written

            v_scalar            scalar to write, or None when a value tensor is used

            slices              slices in the form produced by HType.hashable_slices

        raises ValueError when the value tensor has more axes than the input
        tensor (BroadcastInfo presumably raises too when shapes cannot broadcast).
        """
        slice_info = SliceInfo(i_shape, slices)

        # One (begin, end, step) triple per input axis, in input coordinates.
        # The range test assumes begin <= end; step == 0 selects the single
        # coordinate `begin` (presumably an integer index on that axis).
        axes_bes = slice_info.axes_abs_bes

        if v_scalar is None:
            if v_shape.ndim > i_shape.ndim:
                raise ValueError(f'v_shape.ndim {v_shape.ndim} cannot be larger than i_shape.ndim {i_shape.ndim}')

            # Check that v_shape can broadcast with the keep-dims shape of the slice.
            br_info = BroadcastInfo([slice_info.o_shape_kd, v_shape])
            v_br_shape = br_info.br_shapes[1]

            # Value indices are counted from the slice begin, so input element
            # (b0, b1, ...) receives value element (0, 0, ...). Indexing with the
            # absolute input coordinates would misplace values whenever begin != 0
            # or step > 1. I_IDX_MOD presumably wraps each index modulo the value's
            # dims, which covers broadcast (size-1) axes — confirm in HKernel.
            # NOTE(review): if SliceInfo folds negative-step slices into an ascending
            # (begin, end, |step|) triple, a value tensor is applied in ascending
            # order along such an axis (numpy would reverse it) — confirm.
            v_idxs = ', '.join(self._axis_value_idx(i, b, s) for i, (b, _, s) in enumerate(axes_bes))

            i_tensor_def = HKernel.define_tensor('I', v_br_shape, v_dtype )
            i_kernel_arg = ', __global const I_PTR_TYPE* I_PTR_NAME'
            store_expr = f'I_GLOBAL_LOAD( I_IDX_MOD({v_idxs}) )'
        else:
            i_tensor_def = ''
            i_kernel_arg = ''
            store_expr = f'(O_TYPE)({self._scalar_literal(v_scalar)})'

        # Conjunction of the per-axis membership tests; '1' (always true) when
        # there are no axes, which would otherwise emit an invalid `if ()`.
        in_slice = ' && '.join(self._axis_condition(i, b, e, s) for i, (b, e, s) in enumerate(axes_bes)) or '1'

        self.forward_krn = Kernel(global_shape=(i_shape.size,), kernel_text=f"""
{HKernel.define_tensor('O', i_shape, i_dtype )}
{i_tensor_def}

__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME{i_kernel_arg})
{{
    size_t gid = get_global_id(0);

    {HKernel.decompose_idx_to_axes_idxs('gid', 'O', slice_info.o_shape_kd.ndim)}

    if ({in_slice})
        O_GLOBAL_STORE(gid, {store_expr});
}}
""")

    @staticmethod
    def _axis_condition(i, b, e, s) -> str:
        """Return the C condition that coordinate o{i} is selected by the (b, e, s) triple of axis i."""
        if s == 0:
            return f'(o{i} == {b})'
        cond = f'(o{i} >= {b} && o{i} < {e})'
        if s > 1:
            # The stride is measured from the slice begin, not from coordinate 0:
            # begin=1, step=2 must select 1, 3, 5, ...
            # `&&` short-circuits, so (o - b) is only evaluated once o >= b holds.
            cond += f' && (((o{i} - {b}) % {s}) == 0)'
        return cond

    @staticmethod
    def _axis_value_idx(i, b, s) -> str:
        """Return the C expression for the value-tensor index on axis i, counted from the slice begin."""
        if s == 0:
            # Single-coordinate axis: the slice has extent 1 there.
            return '0'
        if s == 1:
            return f'(o{i} - {b})'
        return f'((o{i} - {b}) / {s})'

    @staticmethod
    def _scalar_literal(v_scalar) -> str:
        """
        Render a Python/numpy scalar as OpenCL C source text.

        Ordinary numbers use their str() form, as before. Bools, NaN and
        infinities are mapped explicitly because 'True', 'nan' and 'inf'
        are not valid OpenCL C tokens.
        """
        if isinstance(v_scalar, (bool, np.bool_)):
            return '1' if v_scalar else '0'
        if isinstance(v_scalar, (float, np.floating)):
            if np.isnan(v_scalar):
                return 'NAN'
            if np.isinf(v_scalar):
                return 'INFINITY' if v_scalar > 0 else '-INFINITY'
        return f'{v_scalar}'
|