DeepFaceLab/samplelib/SampleGeneratorImage.py
2020-03-09 13:09:46 +04:00

67 lines
2.1 KiB
Python

import traceback
import cv2
import numpy as np
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
class SampleGeneratorImage(SampleGeneratorBase):
    """Endless batch generator over the image samples loaded from a path.

    Samples are loaded once via ``SampleLoader.load(SampleType.IMAGE, ...)``.
    Batches are then assembled by ``batch_func``, which walks the samples in
    shuffled order and reshuffles every time all indices have been consumed.
    """

    def __init__(self, samples_path, debug, batch_size,
                 sample_process_options=None, output_sample_types=None,
                 raise_on_no_data=True, **kwargs):
        """Load the samples and set up the batch generator.

        Args:
            samples_path: location handed to ``SampleLoader.load``.
            debug: forwarded to the base class; when ``self.debug`` is truthy
                a ``ThisThreadGenerator`` is used instead of a
                ``SubprocessGenerator``.
            batch_size: forwarded to the base class; ``self.batch_size``
                samples are processed per yielded batch.
            sample_process_options: options passed to
                ``SampleProcessor.process``. ``None`` (the default) creates a
                fresh ``SampleProcessor.Options()`` for this instance.
            output_sample_types: output description passed to
                ``SampleProcessor.process``. ``None`` (the default) means an
                empty list.
            raise_on_no_data: when true, an empty sample set raises; when
                false the instance is left with ``initialized == False``.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: no samples were loaded and ``raise_on_no_data`` is true.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        # Defaults are built per call instead of in the signature, so that
        # instances never share one mutable default object.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []
        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types

        samples = SampleLoader.load(SampleType.IMAGE, samples_path)
        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            # Leave the instance uninitialized; __next__ checks the flag.
            return

        # self.debug is read here without being assigned in this class, so it
        # is expected to come from SampleGeneratorBase.__init__.
        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, samples)]
        else:
            self.generators = [SubprocessGenerator(self.batch_func, samples)]

        self.generator_counter = -1
        self.initialized = True

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch, cycling round-robin over ``self.generators``.

        Raises:
            StopIteration: the generator was constructed without any samples
                (``raise_on_no_data=False``), so there is nothing to yield.
        """
        if not self.initialized:
            # Without this guard the attributes below do not exist and the
            # call would fail with AttributeError.
            raise StopIteration

        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, samples):
        """Generator body: yield batches built from ``samples`` forever.

        Each yielded value is a list of ``np.array`` objects, one per output
        slot returned by ``SampleProcessor.process`` for a single sample
        (presumably one slot per ``output_sample_types`` entry -- confirm
        against SampleProcessor). Each array stacks ``self.batch_size``
        processed samples.
        """
        idxs = [*range(len(samples))]
        shuffle_idxs = []

        while True:
            batches = None
            for _ in range(self.batch_size):
                # Refill and reshuffle once every index has been handed out,
                # so each sample is used once per pass.
                if len(shuffle_idxs) == 0:
                    shuffle_idxs = idxs.copy()
                    np.random.shuffle(shuffle_idxs)

                sample = samples[shuffle_idxs.pop()]

                # process() receives a one-element list; the single result is
                # unpacked into x, a sequence of per-slot outputs.
                x, = SampleProcessor.process([sample], self.sample_process_options, self.output_sample_types, self.debug)

                if batches is None:
                    # The slot count is only known after the first sample.
                    batches = [[] for _ in range(len(x))]

                for i, value in enumerate(x):
                    batches[i].append(value)

            yield [np.array(batch) for batch in batches]