mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-12-24 23:11:12 -08:00
325 lines
11 KiB
Python
325 lines
11 KiB
Python
import pickle
|
|
import struct
|
|
import traceback
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from core import imagelib
|
|
from core.cv2ex import *
|
|
from core.imagelib import SegIEPolys
|
|
from core.interact import interact as io
|
|
from core.structex import *
|
|
from facelib import FaceType
|
|
|
|
|
|
class DFLJPG(object):
|
|
def __init__(self, filename):
|
|
self.filename = filename
|
|
self.data = b""
|
|
self.length = 0
|
|
self.chunks = []
|
|
self.dfl_dict = None
|
|
self.shape = None
|
|
self.img = None
|
|
|
|
@staticmethod
|
|
def load_raw(filename, loader_func=None):
|
|
try:
|
|
if loader_func is not None:
|
|
data = loader_func(filename)
|
|
else:
|
|
with open(filename, "rb") as f:
|
|
data = f.read()
|
|
except:
|
|
raise FileNotFoundError(filename)
|
|
|
|
try:
|
|
inst = DFLJPG(filename)
|
|
inst.data = data
|
|
inst.length = len(data)
|
|
inst_length = inst.length
|
|
chunks = []
|
|
data_counter = 0
|
|
while data_counter < inst_length:
|
|
chunk_m_l, chunk_m_h = struct.unpack ("BB", data[data_counter:data_counter+2])
|
|
data_counter += 2
|
|
|
|
if chunk_m_l != 0xFF:
|
|
raise ValueError(f"No Valid JPG info in {filename}")
|
|
|
|
chunk_name = None
|
|
chunk_size = None
|
|
chunk_data = None
|
|
chunk_ex_data = None
|
|
is_unk_chunk = False
|
|
|
|
if chunk_m_h & 0xF0 == 0xD0:
|
|
n = chunk_m_h & 0x0F
|
|
|
|
if n >= 0 and n <= 7:
|
|
chunk_name = "RST%d" % (n)
|
|
chunk_size = 0
|
|
elif n == 0x8:
|
|
chunk_name = "SOI"
|
|
chunk_size = 0
|
|
if len(chunks) != 0:
|
|
raise Exception("")
|
|
elif n == 0x9:
|
|
chunk_name = "EOI"
|
|
chunk_size = 0
|
|
elif n == 0xA:
|
|
chunk_name = "SOS"
|
|
elif n == 0xB:
|
|
chunk_name = "DQT"
|
|
elif n == 0xD:
|
|
chunk_name = "DRI"
|
|
chunk_size = 2
|
|
else:
|
|
is_unk_chunk = True
|
|
elif chunk_m_h & 0xF0 == 0xC0:
|
|
n = chunk_m_h & 0x0F
|
|
if n == 0:
|
|
chunk_name = "SOF0"
|
|
elif n == 2:
|
|
chunk_name = "SOF2"
|
|
elif n == 4:
|
|
chunk_name = "DHT"
|
|
else:
|
|
is_unk_chunk = True
|
|
elif chunk_m_h & 0xF0 == 0xE0:
|
|
n = chunk_m_h & 0x0F
|
|
chunk_name = "APP%d" % (n)
|
|
else:
|
|
is_unk_chunk = True
|
|
|
|
#if is_unk_chunk:
|
|
# #raise ValueError(f"Unknown chunk {chunk_m_h} in {filename}")
|
|
# io.log_info(f"Unknown chunk {chunk_m_h} in {filename}")
|
|
|
|
if chunk_size == None: #variable size
|
|
chunk_size, = struct.unpack (">H", data[data_counter:data_counter+2])
|
|
chunk_size -= 2
|
|
data_counter += 2
|
|
|
|
if chunk_size > 0:
|
|
chunk_data = data[data_counter:data_counter+chunk_size]
|
|
data_counter += chunk_size
|
|
|
|
if chunk_name == "SOS":
|
|
c = data_counter
|
|
while c < inst_length and (data[c] != 0xFF or data[c+1] != 0xD9):
|
|
c += 1
|
|
|
|
chunk_ex_data = data[data_counter:c]
|
|
data_counter = c
|
|
|
|
chunks.append ({'name' : chunk_name,
|
|
'm_h' : chunk_m_h,
|
|
'data' : chunk_data,
|
|
'ex_data' : chunk_ex_data,
|
|
})
|
|
inst.chunks = chunks
|
|
|
|
return inst
|
|
except Exception as e:
|
|
raise Exception (f"Corrupted JPG file {filename} {e}")
|
|
|
|
@staticmethod
|
|
def load(filename, loader_func=None):
|
|
try:
|
|
inst = DFLJPG.load_raw (filename, loader_func=loader_func)
|
|
inst.dfl_dict = {}
|
|
|
|
for chunk in inst.chunks:
|
|
if chunk['name'] == 'APP0':
|
|
d, c = chunk['data'], 0
|
|
c, id, _ = struct_unpack (d, c, "=4sB")
|
|
|
|
if id == b"JFIF":
|
|
c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB")
|
|
else:
|
|
raise Exception("Unknown jpeg ID: %s" % (id) )
|
|
elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2':
|
|
d, c = chunk['data'], 0
|
|
c, precision, height, width = struct_unpack (d, c, ">BHH")
|
|
inst.shape = (height, width, 3)
|
|
|
|
elif chunk['name'] == 'APP15':
|
|
if type(chunk['data']) == bytes:
|
|
inst.dfl_dict = pickle.loads(chunk['data'])
|
|
|
|
return inst
|
|
except Exception as e:
|
|
io.log_err (f'Exception occured while DFLJPG.load : {traceback.format_exc()}')
|
|
return None
|
|
|
|
def has_data(self):
|
|
return len(self.dfl_dict.keys()) != 0
|
|
|
|
def save(self):
|
|
try:
|
|
with open(self.filename, "wb") as f:
|
|
f.write ( self.dump() )
|
|
except:
|
|
raise Exception( f'cannot save {self.filename}' )
|
|
|
|
def dump(self):
|
|
data = b""
|
|
|
|
dict_data = self.dfl_dict
|
|
|
|
# Remove None keys
|
|
for key in list(dict_data.keys()):
|
|
if dict_data[key] is None:
|
|
dict_data.pop(key)
|
|
|
|
for chunk in self.chunks:
|
|
if chunk['name'] == 'APP15':
|
|
self.chunks.remove(chunk)
|
|
break
|
|
|
|
last_app_chunk = 0
|
|
for i, chunk in enumerate (self.chunks):
|
|
if chunk['m_h'] & 0xF0 == 0xE0:
|
|
last_app_chunk = i
|
|
|
|
dflchunk = {'name' : 'APP15',
|
|
'm_h' : 0xEF,
|
|
'data' : pickle.dumps(dict_data),
|
|
'ex_data' : None,
|
|
}
|
|
self.chunks.insert (last_app_chunk+1, dflchunk)
|
|
|
|
|
|
for chunk in self.chunks:
|
|
data += struct.pack ("BB", 0xFF, chunk['m_h'] )
|
|
chunk_data = chunk['data']
|
|
if chunk_data is not None:
|
|
data += struct.pack (">H", len(chunk_data)+2 )
|
|
data += chunk_data
|
|
|
|
chunk_ex_data = chunk['ex_data']
|
|
if chunk_ex_data is not None:
|
|
data += chunk_ex_data
|
|
|
|
return data
|
|
|
|
def get_img(self):
|
|
if self.img is None:
|
|
self.img = cv2_imread(self.filename)
|
|
return self.img
|
|
|
|
def get_shape(self):
|
|
if self.shape is None:
|
|
img = self.get_img()
|
|
if img is not None:
|
|
self.shape = img.shape
|
|
return self.shape
|
|
|
|
def get_height(self):
|
|
for chunk in self.chunks:
|
|
if type(chunk) == IHDR:
|
|
return chunk.height
|
|
return 0
|
|
|
|
def get_dict(self):
|
|
return self.dfl_dict
|
|
|
|
def set_dict (self, dict_data=None):
|
|
self.dfl_dict = dict_data
|
|
|
|
def get_face_type(self): return self.dfl_dict.get('face_type', FaceType.toString (FaceType.FULL) )
|
|
def set_face_type(self, face_type): self.dfl_dict['face_type'] = face_type
|
|
|
|
def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] )
|
|
def set_landmarks(self, landmarks): self.dfl_dict['landmarks'] = landmarks
|
|
|
|
def get_eyebrows_expand_mod(self): return self.dfl_dict.get ('eyebrows_expand_mod', 1.0)
|
|
def set_eyebrows_expand_mod(self, eyebrows_expand_mod): self.dfl_dict['eyebrows_expand_mod'] = eyebrows_expand_mod
|
|
|
|
def get_source_filename(self): return self.dfl_dict.get ('source_filename', None)
|
|
def set_source_filename(self, source_filename): self.dfl_dict['source_filename'] = source_filename
|
|
|
|
def get_source_rect(self): return self.dfl_dict.get ('source_rect', None)
|
|
def set_source_rect(self, source_rect): self.dfl_dict['source_rect'] = source_rect
|
|
|
|
def get_source_landmarks(self): return np.array ( self.dfl_dict.get('source_landmarks', None) )
|
|
def set_source_landmarks(self, source_landmarks): self.dfl_dict['source_landmarks'] = source_landmarks
|
|
|
|
def get_image_to_face_mat(self):
|
|
mat = self.dfl_dict.get ('image_to_face_mat', None)
|
|
if mat is not None:
|
|
return np.array (mat)
|
|
return None
|
|
def set_image_to_face_mat(self, image_to_face_mat): self.dfl_dict['image_to_face_mat'] = image_to_face_mat
|
|
|
|
def has_seg_ie_polys(self):
|
|
return self.dfl_dict.get('seg_ie_polys',None) is not None
|
|
|
|
def get_seg_ie_polys(self):
|
|
d = self.dfl_dict.get('seg_ie_polys',None)
|
|
if d is not None:
|
|
d = SegIEPolys.load(d)
|
|
else:
|
|
d = SegIEPolys()
|
|
|
|
return d
|
|
|
|
def set_seg_ie_polys(self, seg_ie_polys):
|
|
if seg_ie_polys is not None:
|
|
if not isinstance(seg_ie_polys, SegIEPolys):
|
|
raise ValueError('seg_ie_polys should be instance of SegIEPolys')
|
|
|
|
if seg_ie_polys.has_polys():
|
|
seg_ie_polys = seg_ie_polys.dump()
|
|
else:
|
|
seg_ie_polys = None
|
|
|
|
self.dfl_dict['seg_ie_polys'] = seg_ie_polys
|
|
|
|
def has_xseg_mask(self):
|
|
return self.dfl_dict.get('xseg_mask',None) is not None
|
|
|
|
def get_xseg_mask_compressed(self):
|
|
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
|
if mask_buf is None:
|
|
return None
|
|
|
|
return mask_buf
|
|
|
|
def get_xseg_mask(self):
|
|
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
|
if mask_buf is None:
|
|
return None
|
|
|
|
img = cv2.imdecode(mask_buf, cv2.IMREAD_UNCHANGED)
|
|
if len(img.shape) == 2:
|
|
img = img[...,None]
|
|
|
|
return img.astype(np.float32) / 255.0
|
|
|
|
|
|
def set_xseg_mask(self, mask_a):
|
|
if mask_a is None:
|
|
self.dfl_dict['xseg_mask'] = None
|
|
return
|
|
|
|
mask_a = imagelib.normalize_channels(mask_a, 1)
|
|
img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8)
|
|
|
|
data_max_len = 50000
|
|
|
|
ret, buf = cv2.imencode('.png', img_data)
|
|
|
|
if not ret or len(buf) > data_max_len:
|
|
for jpeg_quality in range(100,-1,-1):
|
|
ret, buf = cv2.imencode( '.jpg', img_data, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] )
|
|
if ret and len(buf) <= data_max_len:
|
|
break
|
|
|
|
if not ret:
|
|
raise Exception("set_xseg_mask: unable to generate image data for set_xseg_mask")
|
|
|
|
self.dfl_dict['xseg_mask'] = buf
|