Spaces:
Running
on
Zero
Running
on
Zero
Elea Zhong
commited on
Commit
·
ec28976
1
Parent(s):
833555d
update regression
Browse files
qwenimage/datamodels.py
CHANGED
|
@@ -76,7 +76,7 @@ class QwenConfig(ExperimentTrainerParameters):
|
|
| 76 |
train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
|
| 77 |
validation_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
|
| 78 |
|
| 79 |
-
training_type: TrainingType
|
| 80 |
train_range: DataRange|None=None
|
| 81 |
val_range: DataRange|None=None
|
| 82 |
test_range: DataRange|None=None
|
|
|
|
| 76 |
train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
|
| 77 |
validation_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
|
| 78 |
|
| 79 |
+
training_type: TrainingType|None=None
|
| 80 |
train_range: DataRange|None=None
|
| 81 |
val_range: DataRange|None=None
|
| 82 |
test_range: DataRange|None=None
|
scripts/save_regression_outputs.py
CHANGED
|
@@ -12,6 +12,9 @@ from qwenimage.foundation import QwenImageFoundationSaveInterm
|
|
| 12 |
def main():
|
| 13 |
parser = argparse.ArgumentParser()
|
| 14 |
parser.add_argument("--start-index", type=int, default=0)
|
|
|
|
|
|
|
|
|
|
| 15 |
args = parser.parse_args()
|
| 16 |
|
| 17 |
total_per = 10
|
|
@@ -30,7 +33,7 @@ def main():
|
|
| 30 |
for edit_type in EDIT_TYPES:
|
| 31 |
to_concat = []
|
| 32 |
for ds_n in range(total_per):
|
| 33 |
-
ds = load_dataset("parquet", data_files=f"/
|
| 34 |
to_concat.append(ds)
|
| 35 |
edit_type_concat = concatenate_datasets(to_concat)
|
| 36 |
all_edit_datasets.append(edit_type_concat)
|
|
@@ -38,10 +41,10 @@ def main():
|
|
| 38 |
# consistent ordering for indexing, also allow extension
|
| 39 |
join_ds = interleave_datasets(all_edit_datasets)
|
| 40 |
|
| 41 |
-
save_base_dir = Path(
|
| 42 |
save_base_dir.mkdir(exist_ok=True, parents=True)
|
| 43 |
|
| 44 |
-
foundation = QwenImageFoundationSaveInterm(QwenConfig())
|
| 45 |
|
| 46 |
dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
|
| 47 |
|
|
@@ -50,6 +53,7 @@ def main():
|
|
| 50 |
output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
|
| 51 |
image=[input_data["input_img"]],
|
| 52 |
prompt=input_data["instruction"],
|
|
|
|
| 53 |
))
|
| 54 |
|
| 55 |
torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
|
|
|
|
| 12 |
def main():
|
| 13 |
parser = argparse.ArgumentParser()
|
| 14 |
parser.add_argument("--start-index", type=int, default=0)
|
| 15 |
+
parser.add_argument("--imsize", type=int, default=512)
|
| 16 |
+
parser.add_argument("--indir", type=str, default="/data/CrispEdit")
|
| 17 |
+
parser.add_argument("--outdir", type=str, default="/data/regression_output")
|
| 18 |
args = parser.parse_args()
|
| 19 |
|
| 20 |
total_per = 10
|
|
|
|
| 33 |
for edit_type in EDIT_TYPES:
|
| 34 |
to_concat = []
|
| 35 |
for ds_n in range(total_per):
|
| 36 |
+
ds = load_dataset("parquet", data_files=f"{args.indir}/{edit_type}_{ds_n:05d}.parquet", split="train")
|
| 37 |
to_concat.append(ds)
|
| 38 |
edit_type_concat = concatenate_datasets(to_concat)
|
| 39 |
all_edit_datasets.append(edit_type_concat)
|
|
|
|
| 41 |
# consistent ordering for indexing, also allow extension
|
| 42 |
join_ds = interleave_datasets(all_edit_datasets)
|
| 43 |
|
| 44 |
+
save_base_dir = Path(args.outdir)
|
| 45 |
save_base_dir.mkdir(exist_ok=True, parents=True)
|
| 46 |
|
| 47 |
+
foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=args.imsize * args.imsize))
|
| 48 |
|
| 49 |
dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
|
| 50 |
|
|
|
|
| 53 |
output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
|
| 54 |
image=[input_data["input_img"]],
|
| 55 |
prompt=input_data["instruction"],
|
| 56 |
+
vae_image_override=args.imsize * args.imsize,
|
| 57 |
))
|
| 58 |
|
| 59 |
torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
|