import torch, warnings, glob, os
import numpy as np
from PIL import Image
from einops import repeat, reduce
from typing import Optional, Union
from dataclasses import dataclass
from huggingface_hub import snapshot_download as hf_snapshot_download
from modelscope import snapshot_download


class BasePipeline(torch.nn.Module):
    def __init__(
        self,
        device="cuda",
        torch_dtype=torch.float16,
        height_division_factor=64,
        width_division_factor=64,
        time_division_factor=None,
        time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for storing intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        # The following parameters are used for shape checks.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        # Shape check: round height, width (and optionally num_frames) up to valid values.
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames
    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a PIL.Image to a torch.Tensor scaled to [min_value, max_value].
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a list of PIL.Image to a torch.Tensor.
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to a PIL.Image.
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to a list of PIL.Image.
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def load_models_to_device(self, model_names=[]):
        if self.vram_management_enabled:
            # Offload models that are not needed for the next computation.
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
            torch.cuda.empty_cache()
            # Onload the requested models.
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "onload"):
                                module.onload()
                    else:
                        model.to(self.device)

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        # Initialize Gaussian noise.
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def enable_cpu_offload(self):
        warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.")
        self.vram_management_enabled = True

    def get_vram(self):
        # Total VRAM of the current device, in GB.
        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)

    def freeze_except(self, model_names):
        for name, model in self.named_children():
            if name in model_names:
                model.train()
                model.requires_grad_(True)
            else:
                model.eval()
                model.requires_grad_(False)

    def blend_with_mask(self, base, addition, mask):
        return base * (1 - mask) + addition * mask

    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next
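
# A minimal usage sketch (illustrative only, not part of the original API): it shows how the
# preprocessing helpers above round-trip a PIL image through the tensor layout expected by the
# pipeline. The function name and the CPU/float32 settings are assumptions made for the example.
def _example_image_round_trip():
    pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
    image = Image.new("RGB", (512, 512), color=(128, 64, 32))
    # "B C H W" follows einops notation, so this yields a (1, 3, 512, 512) tensor scaled to [-1, 1].
    tensor = pipe.preprocess_image(image, pattern="B C H W")
    # Mapping back averages over the batch axis and rescales to uint8, returning a 512x512 PIL.Image.
    restored = pipe.vae_output_to_image(tensor, pattern="B C H W")
    return tensor.shape, restored.size
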
@dataclass
class ModelConfig:
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    local_model_path: str = None
    skip_download: bool = False

    def download_if_necessary(self, use_usp=False):
        if self.path is None:
            # Check model_id and origin_file_pattern
            if self.model_id is None:
                raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
            # Skip if not in rank 0
            if use_usp:
                import torch.distributed as dist
                skip_download = self.skip_download or dist.get_rank() != 0
            else:
                skip_download = self.skip_download
            # Check whether the origin path is a folder
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.origin_file_pattern = ""
                allow_file_pattern = None
                is_folder = True
            elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
                allow_file_pattern = self.origin_file_pattern + "*"
                is_folder = True
            else:
                allow_file_pattern = self.origin_file_pattern
                is_folder = False
            # Download
            if self.local_model_path is None:
                self.local_model_path = "./models"
            if not skip_download:
                downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
                if self.download_resource.lower() == "modelscope":
                    snapshot_download(
                        self.model_id,
                        local_dir=os.path.join(self.local_model_path, self.model_id),
                        allow_file_pattern=allow_file_pattern,
                        ignore_file_pattern=downloaded_files,
                        local_files_only=False,
                    )
                elif self.download_resource.lower() == "huggingface":
                    hf_snapshot_download(
                        self.model_id,
                        local_dir=os.path.join(self.local_model_path, self.model_id),
                        allow_patterns=allow_file_pattern,
                        ignore_patterns=downloaded_files,
                        local_files_only=False,
                    )
                else:
                    raise ValueError("`download_resource` should be `modelscope` or `huggingface`.")
            # Let rank 1, 2, ... wait for rank 0
            if use_usp:
                import torch.distributed as dist
                dist.barrier(device_ids=[dist.get_rank()])
            # Return downloaded files
            if is_folder:
                self.path = os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern)
            else:
                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]
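
# A usage sketch for ModelConfig (illustrative only): a local `path` is used as-is, while a
# `model_id` plus `origin_file_pattern` is resolved by `download_if_necessary()`, which fetches
# matching files from ModelScope (default) or Hugging Face into `./models/<model_id>/` and then
# fills in `path`. The repository id and file names below are placeholders, not real models.
def _example_model_config_usage():
    local_config = ModelConfig(path="models/example/diffusion_pytorch_model.safetensors")
    remote_config = ModelConfig(
        model_id="namespace/model-name",        # placeholder repository id
        origin_file_pattern="*.safetensors",    # glob pattern of the files to fetch
        download_resource="HuggingFace",
    )
    remote_config.download_if_necessary()       # resolves remote_config.path to the downloaded files
    return local_config.path, remote_config.path
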
class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None,
    ):
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names

    def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict:
        raise NotImplementedError("`process` is not implemented.")


class PipelineUnitRunner:
    def __init__(self):
        pass

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict, dict]:
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                # When CFG is disabled (cfg_scale == 1), reuse the positive outputs for the negative side.
                inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
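
# An end-to-end sketch (illustrative only) of how pipeline units are meant to plug together: a
# custom PipelineUnit declares the shared inputs it consumes, and PipelineUnitRunner routes those
# inputs into `process` and merges the returned dict back into the shared state. The unit class,
# function name, and field values below are invented for this example.
class _ExampleResizeUnit(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("height", "width"))

    def process(self, pipe: BasePipeline, height, width) -> dict:
        # Snap the requested resolution to the pipeline's division factors.
        height, width = pipe.check_resize_height_width(height, width)
        return {"height": height, "width": width}


def _example_run_unit():
    pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)
    runner = PipelineUnitRunner()
    inputs_shared, inputs_posi, inputs_nega = {"height": 500, "width": 700, "cfg_scale": 3.0}, {}, {}
    inputs_shared, inputs_posi, inputs_nega = runner(_ExampleResizeUnit(), pipe, inputs_shared, inputs_posi, inputs_nega)
    return inputs_shared  # {"height": 512, "width": 704, "cfg_scale": 3.0}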