|
|
import torch |
|
|
import os |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import glob |
|
|
import insightface |
|
|
import cv2 |
|
|
import subprocess |
|
|
import argparse |
|
|
from decord import VideoReader |
|
|
from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip |
|
|
from facexlib.parsing import init_parsing_model |
|
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper |
|
|
from insightface.app import FaceAnalysis |
|
|
|
|
|
from diffusers.models import AutoencoderKLCogVideoX |
|
|
from diffusers.utils import export_to_video, load_image |
|
|
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel |
|
|
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor |
|
|
|
|
|
from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel |
|
|
from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline |
|
|
from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor |
|
|
from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor |
|
|
from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d |
|
|
|
|
|
def crop_and_resize(image, height, width):
    """Fit ``image`` to ``width`` x ``height`` by adjusting its width, then resizing.

    When the source is proportionally wider than the target, an equal amount
    is cropped from the left and right. Otherwise the source is centered on a
    black (zero-filled) 3-channel canvas that is widened to the target aspect
    ratio. The image height is never cropped or padded.

    Args:
        image: Anything ``np.array`` turns into an ``(H, W, C)`` array. The
            padding branch writes into a 3-channel uint8 canvas.
        height: Target height in pixels.
        width: Target width in pixels.

    Returns:
        A ``PIL.Image.Image`` of size ``(width, height)``.
    """
    pixels = np.array(image)
    src_h, src_w, _ = pixels.shape

    if src_h / src_w >= height / width:
        # Source is as narrow as, or narrower than, the target: pad both sides.
        margin = int((((width / height) * src_h) - src_w) / 2.)
        canvas = np.zeros((src_h, src_w + margin * 2, 3), dtype=np.uint8)
        canvas[:, margin:margin + src_w] = pixels
        return Image.fromarray(canvas).resize((width, height))

    # Source is wider than the target: keep a horizontally centered window.
    kept_w = int(src_h / height * width)
    offset = (src_w - kept_w) // 2
    window = pixels[:, offset: offset + kept_w]
    return Image.fromarray(window).resize((width, height))
|
|
|
|
|
def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
    """Encode a sequence of frames into a video file using moviepy.

    Args:
        video_path: Destination file path.
        samples: Frame sequence handed to ``ImageSequenceClip``.
        fps: Frame rate of the resulting clip.
        audio_bitrate: Forwarded to ``write_videofile`` together with the
            ``aac`` audio codec.
    """
    # Extra encoder flags forwarded verbatim to ffmpeg.
    encoder_flags = ["-crf", "18", "-preset", "slow"]
    sequence = ImageSequenceClip(samples, fps=fps)
    sequence.write_videofile(
        video_path,
        audio_codec="aac",
        audio_bitrate=audio_bitrate,
        ffmpeg_params=encoder_flags,
    )
|
|
|
|
|
def parse_video(driving_video_path, max_frame_num, target_fps=12):
    """Sample control frames from a driving video at a fixed rate.

    The video is resampled to ``target_fps`` by picking, for each target
    timestamp, the source frame at ``int(timestamp * source_fps)``. At most
    ``max_frame_num - 1`` frames are decoded. The returned sequence is then
    laid out as: the first frame twice, the sampled frames between the first
    and the last, and finally the last sampled frame repeated as often as
    needed to reach ``max_frame_num - 1`` frames.

    Args:
        driving_video_path: Path of the video to read with decord.
        max_frame_num: Frame budget; the result holds ``max_frame_num - 1``
            frames.
        target_fps: Sampling rate in frames per second (default 12).

    Returns:
        ``np.ndarray`` stacking the selected frames along axis 0.

    Raises:
        ValueError: If the video has no frames or no frame can be sampled.
    """
    vr = VideoReader(driving_video_path)
    fps = vr.get_avg_fps()
    video_length = len(vr)
    if video_length == 0:
        raise ValueError(f"Driving video has no frames: {driving_video_path}")

    target_len = max_frame_num - 1

    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / target_fps)
    frame_indices = (target_times * fps).astype(np.int32)
    frame_indices = frame_indices[frame_indices < video_length]
    # Truncate the index list *before* decoding so long videos do not get
    # fully decoded only to be sliced down afterwards.
    frame_indices = frame_indices[:target_len]
    if frame_indices.size == 0:
        raise ValueError(
            f"No frames could be sampled from {driving_video_path} "
            f"with max_frame_num={max_frame_num}"
        )

    control_frames = vr.get_batch(frame_indices).asnumpy()

    first_frame = control_frames[:1]
    last_frame = control_frames[-1:]
    # Duplicate the first frame and drop the last one from the middle section;
    # the last frame only reappears as padding when the clip is too short.
    frames = np.concatenate((first_frame, first_frame, control_frames[1:-1]), axis=0)

    missing = target_len - len(frames)
    if missing > 0:
        padding = np.repeat(last_frame, missing, axis=0)
        frames = np.concatenate((frames, padding), axis=0)

    return frames
|
|
|
|
|
def exec_cmd(cmd):
    """Run a command and return its ``subprocess.CompletedProcess``.

    Standard error is merged into standard output, which is captured on the
    result's ``stdout`` attribute as bytes.

    Args:
        cmd: Either a command string, which is executed through the shell, or
            a sequence of arguments, which is executed directly without a
            shell so the arguments need no quoting.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    # Only string commands need the shell; an argv list combined with
    # shell=True would run just its first element on POSIX.
    use_shell = isinstance(cmd, (str, bytes))
    return subprocess.run(
        cmd,
        shell=use_shell,
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
|
|
|
|
|
def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
    """Mux the audio of one video onto the video stream of another with ffmpeg.

    The video stream of ``silent_video_path`` is copied without re-encoding
    and combined with the audio stream of ``audio_video_path``; the output
    ends with the shorter of the two streams. Failures are reported on
    standard output instead of being raised.

    Args:
        silent_video_path: Video providing the picture (input 0).
        audio_video_path: Video providing the audio (input 1).
        output_video_path: Destination file; overwritten if it exists.
    """
    # Pass an argument list and run without a shell, so paths containing
    # spaces, quotes or ``$`` are handed to ffmpeg verbatim.
    cmd = [
        'ffmpeg',
        '-y',
        '-i', silent_video_path,
        '-i', audio_video_path,
        '-map', '0:v',
        '-map', '1:a',
        '-c:v', 'copy',
        '-shortest',
        output_video_path,
    ]

    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred: {e}")
        if e.output:
            # Surface ffmpeg's own diagnostics, which were captured above.
            print(e.output.decode(errors="replace"))
    except OSError as e:
        # E.g. the ffmpeg executable is not installed; stay best-effort.
        print(f"Error occurred: {e}")
    else:
        print(f"Video with audio generated successfully: {output_video_path}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: animate the face in a source image with the
    # motion of a driving video, then write the generated clip, a side-by-side
    # comparison clip, and copies of both with the driving video's audio.
    parser = argparse.ArgumentParser(description="Process video and image for face animation.")
    parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
    parser.add_argument('--driving_video_path', type=str, default="assets/driving_video/1.mp4", help='Path to the driving video.')
    parser.add_argument('--output_path', type=str, default="outputs", help='Path to save the output video.')
    args = parser.parse_args()

    # ---- Fixed inference settings -------------------------------------
    guidance_scale = 3.0
    seed = 43
    num_inference_steps = 10
    sample_size = [480, 720]  # [height, width]
    max_frame_num = 49
    weight_dtype = torch.bfloat16
    save_path = args.output_path
    generator = torch.Generator(device="cuda").manual_seed(seed)
    model_name = "pretrained_models/SkyReels-A1-5B/"
    siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"

    # ---- Face / landmark helpers --------------------------------------
    lmk_extractor = LMKExtractor()
    processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
    vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
    face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",)

    # ---- Model components ---------------------------------------------
    siglip = SiglipVisionModel.from_pretrained(siglip_name)
    siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)

    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_name,
        subfolder="transformer",
    ).to(weight_dtype)

    vae = AutoencoderKLCogVideoX.from_pretrained(
        model_name,
        subfolder="vae",
    ).to(weight_dtype)

    lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
        model_name,
        subfolder="pose_guider",
    ).to(weight_dtype)

    pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
        model_name,
        transformer=transformer,
        vae=vae,
        lmk_encoder=lmk_encoder,
        image_encoder=siglip,
        feature_extractor=siglip_normalize,
        torch_dtype=weight_dtype,
    )

    # NOTE(review): the diffusers docs advise against moving a pipeline to
    # CUDA before enabling model CPU offload; kept as-is here — confirm
    # whether the explicit ``.to("cuda")`` is still wanted.
    pipe.to("cuda")
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    # ---- Driving video: sample frames and crop the face in each --------
    control_frames = parse_video(args.driving_video_path, max_frame_num)

    driving_video_crop = []
    for control_frame in control_frames:
        frame, _, _ = processor.face_crop(control_frame)
        driving_video_crop.append(frame)

    # ---- Source image: resize, crop the face, build motion frames ------
    image = load_image(image=args.image_path)
    image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
    image_array = np.array(image)

    ref_image, x1, y1 = processor.face_crop(image_array)
    face_h, face_w, _ = ref_image.shape
    source_image = ref_image
    out_frames = processor.preprocess_lmk3d(source_image, driving_video_crop)

    # One motion frame per driving frame; the reference-landmark frame built
    # below is prepended to reach ``max_frame_num`` control frames.
    num_motion_frames = max_frame_num - 1
    rescale_motions = np.zeros_like(image_array)[np.newaxis, :].repeat(num_motion_frames, axis=0)
    for ii in range(rescale_motions.shape[0]):
        # Paste each rendered face crop back at the face location in the full frame.
        rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]

    ref_image = cv2.resize(ref_image, (512, 512))
    # Channel order is reversed before landmark extraction.
    ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
    ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)

    first_motion = np.zeros_like(image_array)
    first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
    first_motion = first_motion[np.newaxis, :]

    motions = np.concatenate([first_motion, rescale_motions])
    input_video = motions[:max_frame_num]

    # ---- Aligned face crop of the source image -------------------------
    face_helper.clean_all()
    face_helper.read_image(image_array[:, :, ::-1])
    face_helper.get_face_landmarks_5(only_center_face=True)
    face_helper.align_warp_face()
    if len(face_helper.cropped_faces) == 0:
        raise RuntimeError(f"No face detected in source image: {args.image_path}")
    align_face = face_helper.cropped_faces[0]
    image_face = align_face[:, :, ::-1]

    # (frames, H, W, C) -> (1, C, frames, H, W), scaled to [0, 1].
    input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
    input_video = input_video / 255

    # ---- Generation ----------------------------------------------------
    with torch.no_grad():
        sample = pipe(
            image=image,
            image_face=image_face,
            control_video=input_video,
            prompt="",
            negative_prompt="",
            height=sample_size[0],
            width=sample_size[1],
            num_frames=max_frame_num,
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
    # Drop the first two generated frames — presumably the reference-landmark
    # frame and the duplicated first driving frame; confirm against parse_video.
    out_samples = list(sample.frames[0])[2:]

    # ---- Outputs -------------------------------------------------------
    # File names are kept exactly as before, e.g. "<image>-<video>.mp4" for
    # the comparison clip and "<image>-<video>.mp4-output.mp4" for the raw one.
    save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_video_path).split(".")[0] + ".mp4"

    os.makedirs(save_path, exist_ok=True)
    video_path = os.path.join(save_path, save_path_name + "-output.mp4")
    export_to_video(out_samples, video_path, fps=12)

    # Side-by-side comparison: source image | driving frame | generated frame.
    target_h, target_w = sample_size[0], sample_size[1]
    final_images = []
    control_frames = control_frames[1:]
    for q, generated in enumerate(out_samples):
        frame1 = image
        frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[q])).convert("RGB"), target_h, target_w)
        frame3 = Image.fromarray(np.array(generated)).convert("RGB")

        result = Image.new('RGB', (target_w * 3, target_h))
        result.paste(frame1, (0, 0))
        result.paste(frame2, (target_w, 0))
        result.paste(frame3, (target_w * 2, 0))
        final_images.append(np.array(result))

    video_out_path = os.path.join(save_path, save_path_name)
    write_mp4(video_out_path, final_images, fps=12)

    add_audio_to_video(video_out_path, args.driving_video_path, video_out_path + ".audio.mp4")
    add_audio_to_video(video_path, args.driving_video_path, video_path + ".audio.mp4")
|
|
|