mirror of
https://github.com/iperov/DeepFaceLive.git
synced 2024-12-26 07:51:13 -08:00
94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
from ..AShape import AShape
|
|
from ..AAxes import AAxes
|
|
from ..backend import Kernel
|
|
from ..HKernel import HKernel
|
|
from ..HType import HType
|
|
from ..info import SliceInfo
|
|
from ..SCacheton import SCacheton
|
|
from ..Tensor import Tensor
|
|
|
|
|
|
def split(input_t : Tensor, axis, keepdims=False) -> List[Tensor]:
    """
    Split input_t along `axis` into a list of tensors, one per index of that axis.

    arguments

        input_t     Tensor

        axis        index of the axis to split along

        keepdims    if False, the split axis is indexed with an integer,
                    if True, with a one-element slice (i, i+1, 1).

    returns a list with input_t.shape[axis] tensors, each produced by slice_().
    """
    in_shape = input_t.shape

    def _selector(idx):
        # Full-range slice on every axis, except the split axis which picks `idx`.
        sel = [slice(None, None, None)] * in_shape.ndim
        sel[axis] = slice(idx, idx+1, 1) if keepdims else idx
        return sel

    return [ slice_(input_t, _selector(idx)) for idx in range(in_shape[axis]) ]
|
|
|
|
|
|
def slice_(input_t : Tensor, slices, dtype : np.dtype = None, output_t=None, is_add_to_output=False) -> Tensor:
    """
    Slice input_t with `slices`.

    arguments:

        input_t     input tensor

        slices      argument received from class.__getitem__(slices)

        dtype       dtype of the result.
                    None means the dtype of input_t is used.

        output_t    compute result to this Tensor.
                    Tensor may be with different shape, but should match total size.
                    gradfn will not be set.

        is_add_to_output    add result to output_t if output_t is set.

    returns output_t if it was given, otherwise a Tensor with the sliced shape.

    raises ValueError if output_t is given and its total size differs
    from the total size of the sliced shape.

    Remark.

    Slicing logic is not the same as numpy:
    For example np[2:0:1] slice will produce invalid array with zero index,
    but nn.slice() will select 2 index, same as val_t[2].
    """
    # is_add_to_output is meaningful only with a user-provided output_t,
    # so it is forced to False otherwise (also keeps the op cache key stable).
    op = SCacheton.get(_SliceOp, input_t.shape, input_t.dtype, dtype, HType.hashable_slices(slices), False if output_t is None else is_add_to_output )
    o_shape = op.slice_info.o_shape

    if output_t is None:
        # The reshape shortcut skips the kernel, so it cannot convert dtype.
        # Take it only when no conversion is requested; otherwise fall through
        # to the kernel path so the requested dtype is honoured.
        no_cast = dtype is None or dtype == input_t.dtype
        if op.slice_info.just_reshaped and no_cast:
            return input_t.reshape(o_shape)

        output_t = Tensor(o_shape, op.o_dtype, device=input_t.get_device())
    elif output_t.shape.size != o_shape.size:
        raise ValueError(f'output_t must have size {o_shape.size}')

    input_t.get_device().run_kernel(op.forward_krn, output_t.get_buffer(), input_t.get_buffer() )

    return output_t
|
|
|
|
|
|
class _SliceOp:
    """
    Per-configuration state of the slice operation.

    Instances are obtained by slice_() through SCacheton.get(), keyed by
    the constructor arguments.

    attributes

        slice_info      SliceInfo built from the input shape and the slices

        o_dtype         output dtype, falls back to the input dtype when None is given

        forward_krn     Kernel that reads the selected input elements and
                        stores (or adds) them to the output buffer
    """
    def __init__(self, i_shape : AShape, i_dtype : np.dtype, o_dtype : np.dtype, slices, is_add_to_output):
        info = SliceInfo(i_shape, slices)
        if o_dtype is None:
            o_dtype = i_dtype

        self.slice_info = info
        self.o_dtype = o_dtype

        # Macro used to write the result: accumulate into the output or overwrite it.
        store_macro = 'O_STORE_ADD' if is_add_to_output else 'O_GLOBAL_STORE'

        # One line per input axis: input index = begin + output index * step.
        in_idx_lines = '\n'.join( f'size_t i{n} = {b} + o{n} * {s}; ' for n, (b, _e, s) in enumerate(info.axes_bes) )

        # One work item per element of the keep-dims output shape.
        self.forward_krn = Kernel(global_shape=(info.o_shape_kd.size,), kernel_text=f"""
{HKernel.define_tensor('O', info.o_shape_kd, o_dtype )}
{HKernel.define_tensor('I', i_shape, i_dtype )}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME, __global const I_PTR_TYPE* I_PTR_NAME)
{{
    size_t gid = get_global_id(0);

    {HKernel.decompose_idx_to_axes_idxs('gid', 'o', info.o_shape_kd.ndim)}

    {in_idx_lines}

    {store_macro}(gid, I_GLOBAL_LOAD( I_IDX({HKernel.axes_seq_enum('i', i_shape.ndim)}) ) );
}}
""")
|