DeepFaceLive/xlib/avecl/_internal/op/slice_.py
from typing import List
import numpy as np
from ..AShape import AShape
from ..AAxes import AAxes
from ..backend import Kernel
from ..HKernel import HKernel
from ..HType import HType
from ..info import SliceInfo
from ..SCacheton import SCacheton
from ..Tensor import Tensor

def split(input_t : Tensor, axis, keepdims=False) -> List[Tensor]:
    """
    Split input_t into a list of tensors along the given axis, one per index.

    arguments

        input_t     Tensor

        axis        int     axis to split along

        keepdims    bool    keep the split axis with size 1 instead of dropping it
    """
    shape = input_t.shape
    result = []
    for i in range(shape[axis]):
        slices = [slice(None, None, None)]*shape.ndim
        slices[axis] = i if not keepdims else slice(i,i+1,1)
        result.append( slice_(input_t, slices) )
    return result
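
# Usage sketch (illustrative only; the exact Tensor construction API is assumed
# here and lives elsewhere in avecl):
#
#   t = Tensor( (2,3,4), np.float32 )            # hypothetical tensor of shape (2,3,4)
#   parts   = split(t, axis=1)                   # list of 3 tensors of shape (2,4)
#   parts_k = split(t, axis=1, keepdims=True)    # list of 3 tensors of shape (2,1,4)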

def slice_(input_t : Tensor, slices, dtype : np.dtype = None, output_t=None, is_add_to_output=False) -> Tensor:
    """
    arguments:

        input_t             input tensor

        slices              argument received from class.__getitem__(slices)

        dtype               np.dtype    dtype of the output tensor (default: same as input_t)

        output_t            compute the result into this Tensor.
                            The Tensor may have a different shape, but its total size must match.
                            gradfn will not be set.

        is_add_to_output    add the result to output_t if output_t is set.

    Remark.

    Slicing logic is not the same as in numpy:
    for example, a [2:0:1] slice produces an empty (zero-length) array in numpy,
    but nn.slice() selects index 2, the same as val_t[2].
    """
    op = SCacheton.get(_SliceOp, input_t.shape, input_t.dtype, dtype, HType.hashable_slices(slices), False if output_t is None else is_add_to_output )

    o_shape = op.slice_info.o_shape

    if output_t is None:
        if op.slice_info.just_reshaped:
            # The slice keeps all of the input data, so a reshape of the input is enough.
            return input_t.reshape(o_shape)
        else:
            output_t = Tensor(o_shape, op.o_dtype, device=input_t.get_device())
    elif output_t.shape.size != o_shape.size:
        raise ValueError(f'output_t must have size {o_shape.size}')

    input_t.get_device().run_kernel(op.forward_krn, output_t.get_buffer(), input_t.get_buffer() )

    return output_t
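
# Usage sketch of slice_ (illustrative only; `slices` is given in the same form
# Tensor.__getitem__ would pass through, e.g. a tuple/list of ints and slice objects):
#
#   t = Tensor( (4,8), np.float32 )                        # hypothetical input tensor
#   a = slice_(t, (slice(0,2,1), slice(None,None,None)))   # rows 0..1  -> shape (2,8)
#   b = slice_(t, (2, slice(None,None,None)))              # row 2 only -> axis 0 dropped, like t[2]
#   c = slice_(t, (slice(2,0,1), slice(None,None,None)))   # unlike numpy's empty result,
#                                                          # selects index 2 (see Remark above)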

class _SliceOp:
    def __init__(self, i_shape : AShape, i_dtype : np.dtype, o_dtype : np.dtype, slices, is_add_to_output):
        self.slice_info = slice_info = SliceInfo(i_shape, slices)
        self.o_dtype = o_dtype = o_dtype if o_dtype is not None else i_dtype

        # One work-item per output element: decompose the global id into per-axis
        # output indices, then map each output index to an input index as begin + o*step.
        self.forward_krn = Kernel(global_shape=(slice_info.o_shape_kd.size,), kernel_text=f"""
{HKernel.define_tensor('O', slice_info.o_shape_kd, o_dtype )}
{HKernel.define_tensor('I', i_shape, i_dtype )}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I_PTR_NAME)
{{
    size_t gid = get_global_id(0);
    {HKernel.decompose_idx_to_axes_idxs('gid', 'o', slice_info.o_shape_kd.ndim)}
    {chr(10).join( f'size_t i{i} = {b} + o{i} * {s}; ' for i, (b,e,s) in enumerate(slice_info.axes_bes) ) }
    {'O_STORE_ADD' if is_add_to_output else 'O_GLOBAL_STORE'}(gid, I_GLOBAL_LOAD( I_IDX({HKernel.axes_seq_enum('i', i_shape.ndim)}) ) );
}}
""")