Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
04846cf
1
Parent(s):
2d60859
main.py
CHANGED
|
@@ -25,7 +25,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| 25 |
DEBUG_MODE = False
|
| 26 |
DEBUG_MODE_2 = False
|
| 27 |
NUM_MAX_FRAMES = 1
|
| 28 |
-
|
| 29 |
SCREEN_WIDTH = 512
|
| 30 |
SCREEN_HEIGHT = 384
|
| 31 |
NUM_SAMPLING_STEPS = 32
|
|
@@ -48,8 +48,15 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd
|
|
| 48 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
|
| 49 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
|
| 50 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
|
| 51 |
-
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-22k"
|
| 52 |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
|
|
@@ -67,9 +74,17 @@ LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
|
|
| 67 |
|
| 68 |
if 'origunet' in MODEL_NAME:
|
| 69 |
if 'x0' in MODEL_NAME:
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
else:
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
else:
|
| 74 |
model = initialize_model("config_final_model.yaml", MODEL_NAME)
|
| 75 |
|
|
@@ -205,12 +220,12 @@ def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
|
|
| 205 |
sample_latent = output_from_rnn[:, :16]
|
| 206 |
else:
|
| 207 |
#NUM_SAMPLING_STEPS = 8
|
| 208 |
-
if num_sampling_steps >=
|
| 209 |
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
|
| 210 |
else:
|
| 211 |
if num_sampling_steps == 1:
|
| 212 |
x = torch.randn([1, *LATENT_DIMS], device=device)
|
| 213 |
-
t = torch.full((1,),
|
| 214 |
sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
|
| 215 |
else:
|
| 216 |
sampler = DDIMSampler(model)
|
|
|
|
| 25 |
DEBUG_MODE = False
|
| 26 |
DEBUG_MODE_2 = False
|
| 27 |
NUM_MAX_FRAMES = 1
|
| 28 |
+
TIMESTEPS = 1000
|
| 29 |
SCREEN_WIDTH = 512
|
| 30 |
SCREEN_HEIGHT = 384
|
| 31 |
NUM_SAMPLING_STEPS = 32
|
|
|
|
| 48 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
|
| 49 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
|
| 50 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
|
|
|
|
| 51 |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
|
| 52 |
+
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-46k"
|
| 53 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-142k"
|
| 54 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-338k"
|
| 55 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-x0-140k"
|
| 56 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-eps-144k"
|
| 57 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-70k"
|
| 58 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-joint-onlineonly-eps22-40k"
|
| 59 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-22-38k"
|
| 60 |
|
| 61 |
|
| 62 |
print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
|
|
|
|
| 74 |
|
| 75 |
if 'origunet' in MODEL_NAME:
|
| 76 |
if 'x0' in MODEL_NAME:
|
| 77 |
+
if 'ddpm32' in MODEL_NAME:
|
| 78 |
+
TIMESTEPS = 32
|
| 79 |
+
model = initialize_model("config_final_model_origunet_nospatial_x0_ddpm32.yaml", MODEL_NAME)
|
| 80 |
+
else:
|
| 81 |
+
model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
|
| 82 |
else:
|
| 83 |
+
if 'ddpm32' in MODEL_NAME:
|
| 84 |
+
TIMESTEPS = 32
|
| 85 |
+
model = initialize_model("config_final_model_origunet_nospatial_ddpm32.yaml", MODEL_NAME)
|
| 86 |
+
else:
|
| 87 |
+
model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
|
| 88 |
else:
|
| 89 |
model = initialize_model("config_final_model.yaml", MODEL_NAME)
|
| 90 |
|
|
|
|
| 220 |
sample_latent = output_from_rnn[:, :16]
|
| 221 |
else:
|
| 222 |
#NUM_SAMPLING_STEPS = 8
|
| 223 |
+
if num_sampling_steps >= TIMESTEPS:
|
| 224 |
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
|
| 225 |
else:
|
| 226 |
if num_sampling_steps == 1:
|
| 227 |
x = torch.randn([1, *LATENT_DIMS], device=device)
|
| 228 |
+
t = torch.full((1,), TIMESTEPS-1, device=device, dtype=torch.long)
|
| 229 |
sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
|
| 230 |
else:
|
| 231 |
sampler = DDIMSampler(model)
|