Spaces:
Running
on
Zero
Running
on
Zero
Elea Zhong
committed on
Commit
·
9ac4ead
1
Parent(s):
c4a6ce2
run train experiments
Browse files
- configs/compare/5k_steps.yaml +3 -0
- configs/optim/accum-4.yaml +5 -0
- configs/optim/cosine.yaml +3 -1
- configs/regression/base.yaml +2 -0
- configs/regression/lo_mse.yaml +3 -0
- configs/regression/triplet/mse-triplet-f.yaml +9 -0
- configs/regression/triplet/mse-triplet-g.yaml +9 -0
- configs/regression/triplet/mse-triplet-h.yaml +9 -0
- qwenimage/datamodels.py +16 -1
- qwenimage/foundation.py +19 -10
- qwenimage/training.py +1 -0
- scripts/train_multi.sh +8 -25
configs/compare/5k_steps.yaml
CHANGED
|
@@ -1,2 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
num_train_epochs: 1
|
| 2 |
max_train_steps: 5000
|
|
|
|
| 1 |
+
name_suffix:
|
| 2 |
+
max_steps: 5000
|
| 3 |
+
|
| 4 |
num_train_epochs: 1
|
| 5 |
max_train_steps: 5000
|
configs/optim/accum-4.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
name_suffix:
|
| 3 |
+
accum: 4
|
| 4 |
+
|
| 5 |
+
gradient_accumulation_steps: 4
|
configs/optim/cosine.yaml
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
|
|
|
|
|
|
|
| 2 |
|
| 3 |
lr_scheduler: cosine
|
| 4 |
-
lr_warmup_steps:
|
|
|
|
| 1 |
|
| 2 |
+
name_suffix:
|
| 3 |
+
lr: cosine
|
| 4 |
|
| 5 |
lr_scheduler: cosine
|
| 6 |
+
lr_warmup_steps: 50
|
configs/regression/base.yaml
CHANGED
|
@@ -20,6 +20,8 @@ regression_gen_steps: 50
|
|
| 20 |
editing_data_dir: "/data/CrispEdit"
|
| 21 |
editing_total_per: 1
|
| 22 |
|
|
|
|
|
|
|
| 23 |
|
| 24 |
validation_loss_terms:
|
| 25 |
mse: 1.0
|
|
|
|
| 20 |
editing_data_dir: "/data/CrispEdit"
|
| 21 |
editing_total_per: 1
|
| 22 |
|
| 23 |
+
gradient_checkpointing: true
|
| 24 |
+
vae_tiling: false
|
| 25 |
|
| 26 |
validation_loss_terms:
|
| 27 |
mse: 1.0
|
configs/regression/lo_mse.yaml
CHANGED
|
@@ -1,2 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
train_loss_terms:
|
| 2 |
mse: 0.1
|
|
|
|
| 1 |
+
name_suffix:
|
| 2 |
+
mse: 0.1
|
| 3 |
+
|
| 4 |
train_loss_terms:
|
| 5 |
mse: 0.1
|
configs/regression/triplet/mse-triplet-f.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-f"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-f"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: 0.0
|
| 9 |
+
triplet_min_abs_diff: 0.25
|
configs/regression/triplet/mse-triplet-g.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-g"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-g"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: -0.1
|
| 9 |
+
triplet_min_abs_diff: 0.25
|
configs/regression/triplet/mse-triplet-h.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-h"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-h"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: -0.1
|
| 9 |
+
triplet_min_abs_diff: 0.3
|
qwenimage/datamodels.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import enum
|
| 2 |
from pathlib import Path
|
| 3 |
-
from typing import Literal
|
| 4 |
|
| 5 |
import torch
|
| 6 |
from diffusers.image_processor import PipelineImageInput
|
|
@@ -79,6 +79,7 @@ class QwenConfig(ExperimentTrainerParameters):
|
|
| 79 |
offload_text_encoder: bool = True
|
| 80 |
quantize_text_encoder: bool = False
|
| 81 |
quantize_transformer: bool = False
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
|
|
@@ -103,4 +104,18 @@ class QwenConfig(ExperimentTrainerParameters):
|
|
| 103 |
editing_total_per: int = 1
|
| 104 |
regression_base_pipe_steps: int = 8
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
|
|
|
| 1 |
import enum
|
| 2 |
from pathlib import Path
|
| 3 |
+
from typing import Any, Literal
|
| 4 |
|
| 5 |
import torch
|
| 6 |
from diffusers.image_processor import PipelineImageInput
|
|
|
|
| 79 |
offload_text_encoder: bool = True
|
| 80 |
quantize_text_encoder: bool = False
|
| 81 |
quantize_transformer: bool = False
|
| 82 |
+
vae_tiling: bool = False
|
| 83 |
|
| 84 |
|
| 85 |
train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
|
|
|
|
| 104 |
editing_total_per: int = 1
|
| 105 |
regression_base_pipe_steps: int = 8
|
| 106 |
|
| 107 |
+
name_suffix: dict[str,Any]|None = None
|
| 108 |
+
|
| 109 |
+
def add_suffix_to_names(self):
|
| 110 |
+
if self.name_suffix is None:
|
| 111 |
+
return
|
| 112 |
+
suffix_sum = ""
|
| 113 |
+
for suf_name,suf_val in self.name_suffix.items():
|
| 114 |
+
suffix_sum += "_" + suf_name
|
| 115 |
+
suf_val = str(suf_val)
|
| 116 |
+
suffix_sum += "_" + suf_val
|
| 117 |
+
self.run_name += suffix_sum
|
| 118 |
+
self.output_dir = self.output_dir.removesuffix("/") # in case
|
| 119 |
+
self.output_dir += suffix_sum
|
| 120 |
+
|
| 121 |
|
qwenimage/foundation.py
CHANGED
|
@@ -81,7 +81,17 @@ class QwenImageFoundation(WandModel):
|
|
| 81 |
self.text_encoder.requires_grad_(False)
|
| 82 |
self.text_encoder_device = None
|
| 83 |
self.transformer.eval()
|
|
|
|
| 84 |
self.transformer.requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
self.timestep_dist_utils = TimestepDistUtils(
|
| 87 |
min_seq_len=self.scheduler.config.base_image_seq_len,
|
|
@@ -419,7 +429,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 419 |
margin = loss_terms.triplet_margin
|
| 420 |
triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
|
| 421 |
print(f"{triplet_min_abs_diff=}")
|
| 422 |
-
v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2
|
| 423 |
zero_weight = torch.zeros_like(v_gt_neg_diff)
|
| 424 |
v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
|
| 425 |
ones = torch.ones_like(v_gt_neg_diff)
|
|
@@ -431,12 +441,11 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 431 |
|
| 432 |
diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
|
| 433 |
diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
loss_nonzero_nums = torch.sum(torch.where((loss_weighted + margin)>0, ones, zeros))
|
| 440 |
wand_logger.log({
|
| 441 |
"loss_nonzero_nums": loss_nonzero_nums,
|
| 442 |
}, commit=False)
|
|
@@ -447,8 +456,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 447 |
texam(v_weight, "v_weight")
|
| 448 |
texam(diffv_gt_pred, "diffv_gt_pred")
|
| 449 |
texam(diffv_neg_pred, "diffv_neg_pred")
|
| 450 |
-
texam(
|
| 451 |
-
texam(loss_weighted, "loss_weighted")
|
| 452 |
|
| 453 |
|
| 454 |
|
|
@@ -467,7 +475,8 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 467 |
|
| 468 |
if loss_accumulator.has_group("pixel"):
|
| 469 |
x_0_pred = x_t_1d - t * v_pred_1d
|
| 470 |
-
|
|
|
|
| 471 |
pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
|
| 472 |
|
| 473 |
if loss_accumulator.has("pixel_lpips"):
|
|
|
|
| 81 |
self.text_encoder.requires_grad_(False)
|
| 82 |
self.text_encoder_device = None
|
| 83 |
self.transformer.eval()
|
| 84 |
+
|
| 85 |
self.transformer.requires_grad_(False)
|
| 86 |
+
if self.config.gradient_checkpointing:
|
| 87 |
+
self.transformer.enable_gradient_checkpointing()
|
| 88 |
+
if self.config.vae_tiling:
|
| 89 |
+
self.vae.enable_tiling(
|
| 90 |
+
576,
|
| 91 |
+
576,
|
| 92 |
+
512,
|
| 93 |
+
512
|
| 94 |
+
)
|
| 95 |
|
| 96 |
self.timestep_dist_utils = TimestepDistUtils(
|
| 97 |
min_seq_len=self.scheduler.config.base_image_seq_len,
|
|
|
|
| 429 |
margin = loss_terms.triplet_margin
|
| 430 |
triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
|
| 431 |
print(f"{triplet_min_abs_diff=}")
|
| 432 |
+
v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2)
|
| 433 |
zero_weight = torch.zeros_like(v_gt_neg_diff)
|
| 434 |
v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
|
| 435 |
ones = torch.ones_like(v_gt_neg_diff)
|
|
|
|
| 441 |
|
| 442 |
diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
|
| 443 |
diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
|
| 444 |
+
per_tok_diff = (diffv_gt_pred - diffv_neg_pred).sum(dim=2)
|
| 445 |
+
triplet_loss = torch.mean(F.relu((per_tok_diff + margin) * v_weight))
|
| 446 |
+
ones = torch.ones_like(per_tok_diff)
|
| 447 |
+
zeros = torch.zeros_like(per_tok_diff)
|
| 448 |
+
loss_nonzero_nums = torch.sum(torch.where(((per_tok_diff + margin) * v_weight)>0, ones, zeros))
|
|
|
|
| 449 |
wand_logger.log({
|
| 450 |
"loss_nonzero_nums": loss_nonzero_nums,
|
| 451 |
}, commit=False)
|
|
|
|
| 456 |
texam(v_weight, "v_weight")
|
| 457 |
texam(diffv_gt_pred, "diffv_gt_pred")
|
| 458 |
texam(diffv_neg_pred, "diffv_neg_pred")
|
| 459 |
+
texam(per_tok_diff, "per_tok_diff")
|
|
|
|
| 460 |
|
| 461 |
|
| 462 |
|
|
|
|
| 475 |
|
| 476 |
if loss_accumulator.has_group("pixel"):
|
| 477 |
x_0_pred = x_t_1d - t * v_pred_1d
|
| 478 |
+
with torch.no_grad():
|
| 479 |
+
pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
|
| 480 |
pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
|
| 481 |
|
| 482 |
if loss_accumulator.has("pixel_lpips"):
|
qwenimage/training.py
CHANGED
|
@@ -118,6 +118,7 @@ def run_training(config_path: Path | str, update_config_paths: list[Path] | None
|
|
| 118 |
config = QwenConfig(
|
| 119 |
**config,
|
| 120 |
)
|
|
|
|
| 121 |
|
| 122 |
# Data
|
| 123 |
if config.training_type.is_style:
|
|
|
|
| 118 |
config = QwenConfig(
|
| 119 |
**config,
|
| 120 |
)
|
| 121 |
+
config.add_suffix_to_names()
|
| 122 |
|
| 123 |
# Data
|
| 124 |
if config.training_type.is_style:
|
scripts/train_multi.sh
CHANGED
|
@@ -1,38 +1,21 @@
|
|
| 1 |
#!/bin/bash
|
| 2 |
|
| 3 |
-
|
| 4 |
-
# nohup python scripts/train.py configs/base.yaml --where modal \
|
| 5 |
-
# --update configs/regression/base.yaml \
|
| 6 |
-
# --update configs/regression/modal.yaml \
|
| 7 |
-
# --update configs/regression/mse.yaml \
|
| 8 |
-
# --update configs/compare/5k_steps.yaml \
|
| 9 |
-
# > logs/mse.log 2>&1 &
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 13 |
-
--update configs/regression/base.yaml \
|
| 14 |
-
--update configs/regression/modal.yaml \
|
| 15 |
-
--update configs/regression/triplet/mse-triplet-b.yaml \
|
| 16 |
-
--update configs/compare/5k_steps.yaml \
|
| 17 |
-
> logs/mse-triplet-b.log 2>&1 &
|
| 18 |
-
|
| 19 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 20 |
--update configs/regression/base.yaml \
|
| 21 |
--update configs/regression/modal.yaml \
|
| 22 |
-
--update configs/regression/
|
| 23 |
-
|
| 24 |
-
> logs/mse-triplet-c.log 2>&1 &
|
| 25 |
|
| 26 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 27 |
--update configs/regression/base.yaml \
|
| 28 |
--update configs/regression/modal.yaml \
|
| 29 |
-
--update configs/regression/
|
| 30 |
-
--update configs/
|
| 31 |
-
> logs/mse-
|
| 32 |
|
| 33 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 34 |
--update configs/regression/base.yaml \
|
| 35 |
--update configs/regression/modal.yaml \
|
| 36 |
-
--update configs/regression/
|
| 37 |
-
--update configs/
|
| 38 |
-
> logs/mse-
|
|
|
|
| 1 |
#!/bin/bash
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 4 |
--update configs/regression/base.yaml \
|
| 5 |
--update configs/regression/modal.yaml \
|
| 6 |
+
--update configs/regression/mse-pixel-lpips.yaml \
|
| 7 |
+
> logs/mse-pixel-lpips.log 2>&1 &
|
|
|
|
| 8 |
|
| 9 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 10 |
--update configs/regression/base.yaml \
|
| 11 |
--update configs/regression/modal.yaml \
|
| 12 |
+
--update configs/regression/mse-pixel-lpips.yaml \
|
| 13 |
+
--update configs/optim/accum-4.yaml \
|
| 14 |
+
> logs/mse-pixel-lpips-accum4.log 2>&1 &
|
| 15 |
|
| 16 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 17 |
--update configs/regression/base.yaml \
|
| 18 |
--update configs/regression/modal.yaml \
|
| 19 |
+
--update configs/regression/mse-pixel-lpips.yaml \
|
| 20 |
+
--update configs/optim/cosine.yaml \
|
| 21 |
+
> logs/mse-pixel-lpips-cosine.log 2>&1 &
|