DeepFaceLive/xlib/avecl/_internal/initializer/InitConst.py
2021-10-22 16:35:05 +04:00

38 lines
1018 B
Python

from ..backend import Kernel
from ..HKernel import HKernel
from ..SCacheton import SCacheton
from ..Tensor import Tensor
from .Initializer import Initializer
class InitConst(Initializer):
    """
    Initializer that fills every element of a tensor with one constant value.

    arguments

        value(0)    constant written to each element. It is formatted
                    directly into the kernel source text, so it must render
                    as a valid literal for the tensor's element type
                    (e.g. 0, 1, 0.5).
    """

    def __init__(self, value=0):
        """
        arguments

            value(0)    constant to fill the tensor with
        """
        super().__init__()
        self._value = value

    def initialize_tensor(self, tensor : Tensor):
        """
        Fill ``tensor`` in place with the constant value.

        A kernel is generated once per (value, dtype) pair, stored in
        SCacheton, and then launched with one work-item per tensor element.

        arguments

            tensor      Tensor to initialize
        """
        # The constant is baked into the kernel source text, so it has to be
        # part of the cache key alongside the dtype.
        # NOTE(review): the kernel text also embeds tensor.shape.size via
        # define_tensor, but the size is not part of the key. This appears
        # harmless because each work-item stores at its flat get_global_id(0)
        # index - confirm against HKernel.define_tensor.
        key = (InitConst, self._value, tensor.dtype)
        kernel = SCacheton.get_var(key)
        if kernel is None:
            kernel = Kernel(kernel_text=f"""
{HKernel.define_tensor('O', (tensor.shape.size,), tensor.dtype )}
__kernel void impl(__global O_PTR_TYPE* O_PTR_NAME)
{{
    O_GLOBAL_STORE(get_global_id(0), (O_TYPE){self._value} );
}}
""")
            SCacheton.set_var(key, kernel)

        # One work-item per element: the whole buffer is written.
        tensor.get_device().run_kernel( kernel, tensor.get_buffer(),
                                        global_shape=(tensor.shape.size,) )

    def __str__(self):
        # Report the attribute this class actually stores; the previous
        # version referenced nonexistent _low/_high and raised AttributeError.
        return f'InitConst value={self._value}'