Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
2d60859
1
Parent(s):
f0fd514
main.py
CHANGED
|
@@ -48,7 +48,8 @@ MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd
|
|
| 48 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
|
| 49 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
|
| 50 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
|
| 51 |
-
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0"
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
|
|
@@ -207,14 +208,19 @@ def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
|
|
| 207 |
if num_sampling_steps >= 1000:
|
| 208 |
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
|
| 209 |
else:
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
timing['unet'] = time.perf_counter() - start
|
| 219 |
|
| 220 |
# Decoding
|
|
|
|
| 48 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
|
| 49 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
|
| 50 |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
|
| 51 |
+
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-22k"
|
| 52 |
+
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
|
| 53 |
|
| 54 |
|
| 55 |
print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
|
|
|
|
| 208 |
if num_sampling_steps >= 1000:
|
| 209 |
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
|
| 210 |
else:
|
| 211 |
+
if num_sampling_steps == 1:
|
| 212 |
+
x = torch.randn([1, *LATENT_DIMS], device=device)
|
| 213 |
+
t = torch.full((1,), 999, device=device, dtype=torch.long)
|
| 214 |
+
sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
|
| 215 |
+
else:
|
| 216 |
+
sampler = DDIMSampler(model)
|
| 217 |
+
sample_latent, _ = sampler.sample(
|
| 218 |
+
S=num_sampling_steps,
|
| 219 |
+
conditioning={'c_concat': output_from_rnn},
|
| 220 |
+
batch_size=1,
|
| 221 |
+
shape=LATENT_DIMS,
|
| 222 |
+
verbose=False
|
| 223 |
+
)
|
| 224 |
timing['unet'] = time.perf_counter() - start
|
| 225 |
|
| 226 |
# Decoding
|