Elea Zhong committed on
Commit
9ac4ead
·
1 Parent(s): c4a6ce2

run train experiments

Browse files
configs/compare/5k_steps.yaml CHANGED
@@ -1,2 +1,5 @@
 
 
 
1
  num_train_epochs: 1
2
  max_train_steps: 5000
 
1
+ name_suffix:
2
+ max_steps: 5000
3
+
4
  num_train_epochs: 1
5
  max_train_steps: 5000
configs/optim/accum-4.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ name_suffix:
3
+ accum: 4
4
+
5
+ gradient_accumulation_steps: 4
configs/optim/cosine.yaml CHANGED
@@ -1,4 +1,6 @@
1
 
 
 
2
 
3
  lr_scheduler: cosine
4
- lr_warmup_steps: 250
 
1
 
2
+ name_suffix:
3
+ lr: cosine
4
 
5
  lr_scheduler: cosine
6
+ lr_warmup_steps: 50
configs/regression/base.yaml CHANGED
@@ -20,6 +20,8 @@ regression_gen_steps: 50
20
  editing_data_dir: "/data/CrispEdit"
21
  editing_total_per: 1
22
 
 
 
23
 
24
  validation_loss_terms:
25
  mse: 1.0
 
20
  editing_data_dir: "/data/CrispEdit"
21
  editing_total_per: 1
22
 
23
+ gradient_checkpointing: true
24
+ vae_tiling: false
25
 
26
  validation_loss_terms:
27
  mse: 1.0
configs/regression/lo_mse.yaml CHANGED
@@ -1,2 +1,5 @@
 
 
 
1
  train_loss_terms:
2
  mse: 0.1
 
1
+ name_suffix:
2
+ mse: 0.1
3
+
4
  train_loss_terms:
5
  mse: 0.1
configs/regression/triplet/mse-triplet-f.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-f"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-f"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: 0.0
9
+ triplet_min_abs_diff: 0.25
configs/regression/triplet/mse-triplet-g.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-g"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-g"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: -0.1
9
+ triplet_min_abs_diff: 0.25
configs/regression/triplet/mse-triplet-h.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-h"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-h"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: -0.1
9
+ triplet_min_abs_diff: 0.3
qwenimage/datamodels.py CHANGED
@@ -1,6 +1,6 @@
1
  import enum
2
  from pathlib import Path
3
- from typing import Literal
4
 
5
  import torch
6
  from diffusers.image_processor import PipelineImageInput
@@ -79,6 +79,7 @@ class QwenConfig(ExperimentTrainerParameters):
79
  offload_text_encoder: bool = True
80
  quantize_text_encoder: bool = False
81
  quantize_transformer: bool = False
 
82
 
83
 
84
  train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
@@ -103,4 +104,18 @@ class QwenConfig(ExperimentTrainerParameters):
103
  editing_total_per: int = 1
104
  regression_base_pipe_steps: int = 8
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
 
1
  import enum
2
  from pathlib import Path
3
+ from typing import Any, Literal
4
 
5
  import torch
6
  from diffusers.image_processor import PipelineImageInput
 
79
  offload_text_encoder: bool = True
80
  quantize_text_encoder: bool = False
81
  quantize_transformer: bool = False
82
+ vae_tiling: bool = False
83
 
84
 
85
  train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
 
104
  editing_total_per: int = 1
105
  regression_base_pipe_steps: int = 8
106
 
107
+ name_suffix: dict[str,Any]|None = None
108
+
109
+ def add_suffix_to_names(self):
110
+ if self.name_suffix is None:
111
+ return
112
+ suffix_sum = ""
113
+ for suf_name,suf_val in self.name_suffix.items():
114
+ suffix_sum += "_" + suf_name
115
+ suf_val = str(suf_val)
116
+ suffix_sum += "_" + suf_val
117
+ self.run_name += suffix_sum
118
+ self.output_dir = self.output_dir.removesuffix("/") # in case
119
+ self.output_dir += suffix_sum
120
+
121
 
qwenimage/foundation.py CHANGED
@@ -81,7 +81,17 @@ class QwenImageFoundation(WandModel):
81
  self.text_encoder.requires_grad_(False)
82
  self.text_encoder_device = None
83
  self.transformer.eval()
 
84
  self.transformer.requires_grad_(False)
 
 
 
 
 
 
 
 
 
85
 
86
  self.timestep_dist_utils = TimestepDistUtils(
87
  min_seq_len=self.scheduler.config.base_image_seq_len,
@@ -419,7 +429,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
419
  margin = loss_terms.triplet_margin
420
  triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
421
  print(f"{triplet_min_abs_diff=}")
422
- v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2, keepdim=True)
423
  zero_weight = torch.zeros_like(v_gt_neg_diff)
424
  v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
425
  ones = torch.ones_like(v_gt_neg_diff)
@@ -431,12 +441,11 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
431
 
432
  diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
433
  diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
434
- loss_unreduced = diffv_gt_pred - diffv_neg_pred
435
- loss_weighted = (loss_unreduced * v_weight).sum(dim=2)
436
- triplet_loss = F.relu(loss_weighted + margin).mean()
437
- ones = torch.ones_like(loss_weighted)
438
- zeros = torch.zeros_like(loss_weighted)
439
- loss_nonzero_nums = torch.sum(torch.where((loss_weighted + margin)>0, ones, zeros))
440
  wand_logger.log({
441
  "loss_nonzero_nums": loss_nonzero_nums,
442
  }, commit=False)
@@ -447,8 +456,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
447
  texam(v_weight, "v_weight")
448
  texam(diffv_gt_pred, "diffv_gt_pred")
449
  texam(diffv_neg_pred, "diffv_neg_pred")
450
- texam(loss_unreduced, "loss_unreduced")
451
- texam(loss_weighted, "loss_weighted")
452
 
453
 
454
 
@@ -467,7 +475,8 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
467
 
468
  if loss_accumulator.has_group("pixel"):
469
  x_0_pred = x_t_1d - t * v_pred_1d
470
- pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
 
471
  pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
472
 
473
  if loss_accumulator.has("pixel_lpips"):
 
81
  self.text_encoder.requires_grad_(False)
82
  self.text_encoder_device = None
83
  self.transformer.eval()
84
+
85
  self.transformer.requires_grad_(False)
86
+ if self.config.gradient_checkpointing:
87
+ self.transformer.enable_gradient_checkpointing()
88
+ if self.config.vae_tiling:
89
+ self.vae.enable_tiling(
90
+ 576,
91
+ 576,
92
+ 512,
93
+ 512
94
+ )
95
 
96
  self.timestep_dist_utils = TimestepDistUtils(
97
  min_seq_len=self.scheduler.config.base_image_seq_len,
 
429
  margin = loss_terms.triplet_margin
430
  triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
431
  print(f"{triplet_min_abs_diff=}")
432
+ v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2)
433
  zero_weight = torch.zeros_like(v_gt_neg_diff)
434
  v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
435
  ones = torch.ones_like(v_gt_neg_diff)
 
441
 
442
  diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
443
  diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
444
+ per_tok_diff = (diffv_gt_pred - diffv_neg_pred).sum(dim=2)
445
+ triplet_loss = torch.mean(F.relu((per_tok_diff + margin) * v_weight))
446
+ ones = torch.ones_like(per_tok_diff)
447
+ zeros = torch.zeros_like(per_tok_diff)
448
+ loss_nonzero_nums = torch.sum(torch.where(((per_tok_diff + margin) * v_weight)>0, ones, zeros))
 
449
  wand_logger.log({
450
  "loss_nonzero_nums": loss_nonzero_nums,
451
  }, commit=False)
 
456
  texam(v_weight, "v_weight")
457
  texam(diffv_gt_pred, "diffv_gt_pred")
458
  texam(diffv_neg_pred, "diffv_neg_pred")
459
+ texam(per_tok_diff, "per_tok_diff")
 
460
 
461
 
462
 
 
475
 
476
  if loss_accumulator.has_group("pixel"):
477
  x_0_pred = x_t_1d - t * v_pred_1d
478
+ with torch.no_grad():
479
+ pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
480
  pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
481
 
482
  if loss_accumulator.has("pixel_lpips"):
qwenimage/training.py CHANGED
@@ -118,6 +118,7 @@ def run_training(config_path: Path | str, update_config_paths: list[Path] | None
118
  config = QwenConfig(
119
  **config,
120
  )
 
121
 
122
  # Data
123
  if config.training_type.is_style:
 
118
  config = QwenConfig(
119
  **config,
120
  )
121
+ config.add_suffix_to_names()
122
 
123
  # Data
124
  if config.training_type.is_style:
scripts/train_multi.sh CHANGED
@@ -1,38 +1,21 @@
1
  #!/bin/bash
2
 
3
-
4
- # nohup python scripts/train.py configs/base.yaml --where modal \
5
- # --update configs/regression/base.yaml \
6
- # --update configs/regression/modal.yaml \
7
- # --update configs/regression/mse.yaml \
8
- # --update configs/compare/5k_steps.yaml \
9
- # > logs/mse.log 2>&1 &
10
-
11
-
12
- nohup python scripts/train.py configs/base.yaml --where modal \
13
- --update configs/regression/base.yaml \
14
- --update configs/regression/modal.yaml \
15
- --update configs/regression/triplet/mse-triplet-b.yaml \
16
- --update configs/compare/5k_steps.yaml \
17
- > logs/mse-triplet-b.log 2>&1 &
18
-
19
  nohup python scripts/train.py configs/base.yaml --where modal \
20
  --update configs/regression/base.yaml \
21
  --update configs/regression/modal.yaml \
22
- --update configs/regression/triplet/mse-triplet-c.yaml \
23
- --update configs/compare/5k_steps.yaml \
24
- > logs/mse-triplet-c.log 2>&1 &
25
 
26
  nohup python scripts/train.py configs/base.yaml --where modal \
27
  --update configs/regression/base.yaml \
28
  --update configs/regression/modal.yaml \
29
- --update configs/regression/triplet/mse-triplet-d.yaml \
30
- --update configs/compare/5k_steps.yaml \
31
- > logs/mse-triplet-d.log 2>&1 &
32
 
33
  nohup python scripts/train.py configs/base.yaml --where modal \
34
  --update configs/regression/base.yaml \
35
  --update configs/regression/modal.yaml \
36
- --update configs/regression/triplet/mse-triplet-e.yaml \
37
- --update configs/compare/5k_steps.yaml \
38
- > logs/mse-triplet-e.log 2>&1 &
 
1
  #!/bin/bash
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  nohup python scripts/train.py configs/base.yaml --where modal \
4
  --update configs/regression/base.yaml \
5
  --update configs/regression/modal.yaml \
6
+ --update configs/regression/mse-pixel-lpips.yaml \
7
+ > logs/mse-pixel-lpips.log 2>&1 &
 
8
 
9
  nohup python scripts/train.py configs/base.yaml --where modal \
10
  --update configs/regression/base.yaml \
11
  --update configs/regression/modal.yaml \
12
+ --update configs/regression/mse-pixel-lpips.yaml \
13
+ --update configs/optim/accum-4.yaml \
14
+ > logs/mse-pixel-lpips-accum4.log 2>&1 &
15
 
16
  nohup python scripts/train.py configs/base.yaml --where modal \
17
  --update configs/regression/base.yaml \
18
  --update configs/regression/modal.yaml \
19
+ --update configs/regression/mse-pixel-lpips.yaml \
20
+ --update configs/optim/cosine.yaml \
21
+ > logs/mse-pixel-lpips-cosine.log 2>&1 &