Elea Zhong committed
Commit ec28976 · 1 Parent(s): 833555d

update regression

qwenimage/datamodels.py CHANGED
@@ -76,7 +76,7 @@ class QwenConfig(ExperimentTrainerParameters):
     train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
     validation_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
 
-    training_type: TrainingType
+    training_type: TrainingType|None=None
     train_range: DataRange|None=None
     val_range: DataRange|None=None
     test_range: DataRange|None=None
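The practical effect of the datamodels change: with training_type typed as TrainingType|None and defaulted to None, QwenConfig can now be constructed without naming a training type, which is what the updated regression script below relies on when it builds QwenConfig(vae_image_size=...). A minimal before/after sketch of that behaviour, assuming QwenConfig is a pydantic model (the Field(default_factory=...) usage suggests it is); a plain str stands in for TrainingType here:

    # Illustrative sketch only, not repository code.
    from pydantic import BaseModel, ValidationError

    class Before(BaseModel):
        training_type: str                  # required: Before() raises ValidationError

    class After(BaseModel):
        training_type: str | None = None    # optional: After() constructs with None

    try:
        Before()
    except ValidationError:
        print("Before() rejected: training_type is required")

    print(After().training_type)            # -> None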
scripts/save_regression_outputs.py CHANGED
@@ -12,6 +12,9 @@ from qwenimage.foundation import QwenImageFoundationSaveInterm
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--start-index", type=int, default=0)
+    parser.add_argument("--imsize", type=int, default=512)
+    parser.add_argument("--indir", type=str, default="/data/CrispEdit")
+    parser.add_argument("--outdir", type=str, default="/data/regression_output")
     args = parser.parse_args()
 
     total_per = 10
@@ -30,7 +33,7 @@ def main():
     for edit_type in EDIT_TYPES:
         to_concat = []
         for ds_n in range(total_per):
-            ds = load_dataset("parquet", data_files=f"/data/CrispEdit/{edit_type}_{ds_n:05d}.parquet", split="train")
+            ds = load_dataset("parquet", data_files=f"{args.indir}/{edit_type}_{ds_n:05d}.parquet", split="train")
             to_concat.append(ds)
         edit_type_concat = concatenate_datasets(to_concat)
         all_edit_datasets.append(edit_type_concat)
@@ -38,10 +41,10 @@ def main():
     # consistent ordering for indexing, also allow extension
     join_ds = interleave_datasets(all_edit_datasets)
 
-    save_base_dir = Path("/data/regression_output")
+    save_base_dir = Path(args.outdir)
     save_base_dir.mkdir(exist_ok=True, parents=True)
 
-    foundation = QwenImageFoundationSaveInterm(QwenConfig())
+    foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=args.imsize * args.imsize))
 
     dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
 
@@ -50,6 +53,7 @@ def main():
         output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
             image=[input_data["input_img"]],
             prompt=input_data["instruction"],
+            vae_image_override=args.imsize * args.imsize,
         ))
 
         torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
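Taken together, the script changes lift the hard-coded dataset directory, output directory, and image size into command-line flags while keeping the ordering logic intact: each edit type is still built from ten parquet shards, and interleave_datasets still cycles across edit types, which is what keeps the global example index stable ("consistent ordering for indexing") and lets --start-index resume a partial run via join_ds.select(range(args.start_index, len(join_ds))). A toy sketch of that ordering guarantee, with small stand-in datasets whose names are illustrative, not from the repository:

    # Illustrative sketch of the interleave/resume pattern used above.
    from datasets import Dataset, concatenate_datasets, interleave_datasets

    # Stand-ins for two edit types, each assembled from two small "shards".
    brighten = concatenate_datasets([
        Dataset.from_dict({"edit": ["brighten"] * 2}),
        Dataset.from_dict({"edit": ["brighten"] * 2}),
    ])
    crop = concatenate_datasets([
        Dataset.from_dict({"edit": ["crop"] * 2}),
        Dataset.from_dict({"edit": ["crop"] * 2}),
    ])

    # Default interleaving cycles through the sources one example at a time,
    # so a given index always maps to the same (edit type, example) pair.
    joined = interleave_datasets([brighten, crop])
    print(joined["edit"])        # ['brighten', 'crop', 'brighten', 'crop', ...]

    # Resuming from an index, as the script does with --start-index:
    resumed = joined.select(range(2, len(joined)))
    print(len(resumed))          # 6 examples left after skipping the first 2

Note also that --imsize is squared before use (the default 512 becomes 512 * 512 = 262,144), so vae_image_size and vae_image_override appear to take a total pixel budget rather than a side length; with all flags left at their defaults, a run reproduces the previous hard-coded behaviour.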