File size: 74,167 Bytes
49cbc74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1e781e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the project checkout; this absolute path is machine-specific.\n",
    "%cd /home/ubuntu/Qwen-Image-Edit-Angles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6192ee5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4941"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import glob\n",
    "from pathlib import Path\n",
    "\n",
    "base_data = Path(\"/data/regression_output\")\n",
    "\n",
    "# Path.glob yields files in arbitrary filesystem order; sort so that\n",
    "# all_reg[0] (used below) refers to the same dump on every run.\n",
    "all_reg = sorted(base_data.glob(\"*.pt\"))\n",
    "if not all_reg:\n",
    "    raise FileNotFoundError(f\"no *.pt files found in {base_data}\")\n",
    "# File stems are parsed as integers; report the largest one.\n",
    "max_ind = max(int(reg_pth.stem) for reg_pth in all_reg)\n",
    "\n",
    "max_ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b5124900",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prompt_embeds\n",
      "prompt_embeds_mask\n",
      "noise\n",
      "image_latents\n",
      "vae_image_sizes\n",
      "img_shapes\n",
      "txt_seq_lens\n",
      "t_0\n",
      "latents_0_start\n",
      "noise_pred_0\n",
      "t_1\n",
      "latents_1_start\n",
      "noise_pred_1\n",
      "t_2\n",
      "latents_2_start\n",
      "noise_pred_2\n",
      "t_3\n",
      "latents_3_start\n",
      "noise_pred_3\n",
      "t_4\n",
      "latents_4_start\n",
      "noise_pred_4\n",
      "t_5\n",
      "latents_5_start\n",
      "noise_pred_5\n",
      "t_6\n",
      "latents_6_start\n",
      "noise_pred_6\n",
      "t_7\n",
      "latents_7_start\n",
      "noise_pred_7\n",
      "t_8\n",
      "latents_8_start\n",
      "noise_pred_8\n",
      "t_9\n",
      "latents_9_start\n",
      "noise_pred_9\n",
      "t_10\n",
      "latents_10_start\n",
      "noise_pred_10\n",
      "t_11\n",
      "latents_11_start\n",
      "noise_pred_11\n",
      "t_12\n",
      "latents_12_start\n",
      "noise_pred_12\n",
      "t_13\n",
      "latents_13_start\n",
      "noise_pred_13\n",
      "t_14\n",
      "latents_14_start\n",
      "noise_pred_14\n",
      "t_15\n",
      "latents_15_start\n",
      "noise_pred_15\n",
      "t_16\n",
      "latents_16_start\n",
      "noise_pred_16\n",
      "t_17\n",
      "latents_17_start\n",
      "noise_pred_17\n",
      "t_18\n",
      "latents_18_start\n",
      "noise_pred_18\n",
      "t_19\n",
      "latents_19_start\n",
      "noise_pred_19\n",
      "t_20\n",
      "latents_20_start\n",
      "noise_pred_20\n",
      "t_21\n",
      "latents_21_start\n",
      "noise_pred_21\n",
      "t_22\n",
      "latents_22_start\n",
      "noise_pred_22\n",
      "t_23\n",
      "latents_23_start\n",
      "noise_pred_23\n",
      "t_24\n",
      "latents_24_start\n",
      "noise_pred_24\n",
      "t_25\n",
      "latents_25_start\n",
      "noise_pred_25\n",
      "t_26\n",
      "latents_26_start\n",
      "noise_pred_26\n",
      "t_27\n",
      "latents_27_start\n",
      "noise_pred_27\n",
      "t_28\n",
      "latents_28_start\n",
      "noise_pred_28\n",
      "t_29\n",
      "latents_29_start\n",
      "noise_pred_29\n",
      "t_30\n",
      "latents_30_start\n",
      "noise_pred_30\n",
      "t_31\n",
      "latents_31_start\n",
      "noise_pred_31\n",
      "t_32\n",
      "latents_32_start\n",
      "noise_pred_32\n",
      "t_33\n",
      "latents_33_start\n",
      "noise_pred_33\n",
      "t_34\n",
      "latents_34_start\n",
      "noise_pred_34\n",
      "t_35\n",
      "latents_35_start\n",
      "noise_pred_35\n",
      "t_36\n",
      "latents_36_start\n",
      "noise_pred_36\n",
      "t_37\n",
      "latents_37_start\n",
      "noise_pred_37\n",
      "t_38\n",
      "latents_38_start\n",
      "noise_pred_38\n",
      "t_39\n",
      "latents_39_start\n",
      "noise_pred_39\n",
      "t_40\n",
      "latents_40_start\n",
      "noise_pred_40\n",
      "t_41\n",
      "latents_41_start\n",
      "noise_pred_41\n",
      "t_42\n",
      "latents_42_start\n",
      "noise_pred_42\n",
      "t_43\n",
      "latents_43_start\n",
      "noise_pred_43\n",
      "t_44\n",
      "latents_44_start\n",
      "noise_pred_44\n",
      "t_45\n",
      "latents_45_start\n",
      "noise_pred_45\n",
      "t_46\n",
      "latents_46_start\n",
      "noise_pred_46\n",
      "t_47\n",
      "latents_47_start\n",
      "noise_pred_47\n",
      "t_48\n",
      "latents_48_start\n",
      "noise_pred_48\n",
      "t_49\n",
      "latents_49_start\n",
      "noise_pred_49\n",
      "output\n",
      "height\n",
      "width\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Print the top-level keys of one dump file.\n",
    "out = all_reg[0]\n",
    "out_dict = torch.load(out)\n",
    "for key in out_dict:\n",
    "    print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f693db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'003329'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da107d9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "69G\t/data/regression_output\n"
     ]
    }
   ],
   "source": [
    "# Disk usage of the dump directory (human-readable); the last line is its total.\n",
    "!du -h {base_data}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269c0bfb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5964bf2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RegressionSource:\n",
    "    \"\"\"Flat, index-addressable view over per-step regression dumps (WIP).\n",
    "\n",
    "    Every ``*.pt`` file in ``data_dir`` is read with ``torch.load`` and is\n",
    "    expected to hold entries ``t_{i}``, ``latents_{i}_start`` and\n",
    "    ``noise_pred_{i}`` for each step ``i`` in ``range(gen_steps)``, alongside\n",
    "    other non-step entries. Index ``idx`` maps to file ``idx // gen_steps``\n",
    "    and step ``idx % gen_steps``, so ``len(self) == gen_steps * n_files``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dir, gen_steps=50):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            data_dir: directory (str or Path) holding the ``*.pt`` dumps.\n",
    "            gen_steps: number of per-step entries stored in each file.\n",
    "        \"\"\"\n",
    "        data_dir = Path(data_dir)\n",
    "        # Path.glob yields entries in arbitrary filesystem order; sort so the\n",
    "        # idx -> file mapping is reproducible across runs and machines.\n",
    "        self.data_paths = sorted(data_dir.glob(\"*.pt\"))\n",
    "        self.gen_steps = gen_steps\n",
    "        self._len = gen_steps * len(self.data_paths)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self._len\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Load the sample for one (file, step) pair.\n",
    "\n",
    "        Returns the loaded dict with the selected step's ``t_{s}``,\n",
    "        ``latents_{s}_start`` and ``noise_pred_{s}`` entries renamed to\n",
    "        ``t``, ``latents_start`` and ``noise_pred``. All other keys,\n",
    "        including the remaining steps' entries, are returned as loaded.\n",
    "\n",
    "        Raises:\n",
    "            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.\n",
    "            KeyError: if the file has no entries for the selected step.\n",
    "        \"\"\"\n",
    "        data_idx = idx // self.gen_steps\n",
    "        step_idx = idx % self.gen_steps\n",
    "        # NOTE(review): the whole file is re-read from disk on every call.\n",
    "        out_dict = torch.load(self.data_paths[data_idx])\n",
    "        t = out_dict.pop(f\"t_{step_idx}\")\n",
    "        latents_start = out_dict.pop(f\"latents_{step_idx}_start\")\n",
    "        noise_pred = out_dict.pop(f\"noise_pred_{step_idx}\")\n",
    "        out_dict[\"t\"] = t\n",
    "        out_dict[\"latents_start\"] = latents_start\n",
    "        out_dict[\"noise_pred\"] = noise_pred\n",
    "        return out_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b62e7bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse base_data defined above rather than repeating the literal path.\n",
    "src = RegressionSource(base_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ee68ab3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9738e1d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'prompt_embeds': tensor([[[ 3.2188,  3.4375,  3.1719,  ...,  0.3535,  1.7812,  2.0312],\n",
       "          [ 3.0938,  1.9297,  0.7031,  ...,  2.0625, -0.2314,  1.2266],\n",
       "          [ 2.6250,  1.7031,  3.5625,  ...,  0.8828,  2.1719,  1.4766],\n",
       "          ...,\n",
       "          [ 4.7812,  0.1689,  4.4688,  ...,  5.0000, -1.8359, -0.7500],\n",
       "          [-0.0654,  2.1406, -1.4922,  ...,  0.7930,  3.9844,  1.6406],\n",
       "          [-2.7031,  1.5547,  2.6094,  ..., -0.0481,  0.1582,  0.7383]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'prompt_embeds_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
       " 'noise': tensor([[[ 1.9766, -0.8047,  0.6367,  ..., -1.7422,  1.0469,  0.3809],\n",
       "          [ 1.6562,  0.1147, -0.1562,  ...,  0.7539, -0.1768, -1.6953],\n",
       "          [ 0.3984,  0.3926,  0.1914,  ..., -0.9258, -1.3281, -2.3281],\n",
       "          ...,\n",
       "          [-1.4766,  0.2539,  1.3359,  ...,  0.1797, -0.6250,  0.7617],\n",
       "          [ 1.0391,  1.3672, -0.1572,  ...,  0.1152,  1.4688, -0.2852],\n",
       "          [ 0.4941, -1.1094,  2.3438,  ...,  0.8281, -0.8320,  0.4258]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'image_latents': tensor([[[ 0.1719,  0.0194,  0.0084,  ..., -0.1494,  0.0552,  0.2295],\n",
       "          [ 0.1777,  0.1406,  0.1592,  ...,  0.1260, -0.2412, -0.0041],\n",
       "          [ 0.1187,  0.2324,  0.1104,  ...,  0.0801,  0.3516,  0.4414],\n",
       "          ...,\n",
       "          [-0.0972, -0.3242, -0.3027,  ...,  0.3672,  0.1699,  0.4004],\n",
       "          [-0.1221, -0.0125, -0.3867,  ...,  0.7031,  0.8477,  0.8320],\n",
       "          [-0.1416, -0.1914, -0.3359,  ...,  0.9883,  1.3359,  0.7422]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'vae_image_sizes': [(448, 576)],\n",
       " 'img_shapes': [[(1, 36, 28), (1, 36, 28)]],\n",
       " 'txt_seq_lens': [228],\n",
       " 't_1': tensor([0.9883], dtype=torch.bfloat16),\n",
       " 'latents_1_start': tensor([[[ 1.9531, -0.7930,  0.6289,  ..., -1.7188,  1.0312,  0.3770],\n",
       "          [ 1.6406,  0.1143, -0.1533,  ...,  0.7461, -0.1748, -1.6719],\n",
       "          [ 0.3945,  0.3887,  0.1895,  ..., -0.9141, -1.3125, -2.2969],\n",
       "          ...,\n",
       "          [-1.4609,  0.2471,  1.3203,  ...,  0.1826, -0.6133,  0.7578],\n",
       "          [ 1.0234,  1.3516, -0.1582,  ...,  0.1226,  1.4609, -0.2715],\n",
       "          [ 0.4863, -1.1016,  2.3125,  ...,  0.8281, -0.8086,  0.4297]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_1': tensor([[[ 1.9062, -0.9102,  0.5742,  ..., -1.7422,  1.0625,  0.3359],\n",
       "          [ 1.5859,  0.0306, -0.2637,  ...,  0.7539, -0.1768, -1.7969],\n",
       "          [ 0.3184,  0.3066,  0.1592,  ..., -1.0391, -1.5391, -2.5625],\n",
       "          ...,\n",
       "          [-1.2734,  0.4941,  1.5781,  ..., -0.2344, -1.0156,  0.3477],\n",
       "          [ 1.2422,  1.5234,  0.0510,  ..., -0.5820,  0.9219, -1.0859],\n",
       "          [ 0.6172, -0.9336,  2.5781,  ..., -0.0801, -1.7734, -0.3730]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_2': tensor([0.9766], dtype=torch.bfloat16),\n",
       " 'latents_2_start': tensor([[[ 1.9297, -0.7812,  0.6211,  ..., -1.6953,  1.0156,  0.3730],\n",
       "          [ 1.6250,  0.1138, -0.1504,  ...,  0.7383, -0.1729, -1.6484],\n",
       "          [ 0.3906,  0.3848,  0.1875,  ..., -0.9023, -1.2969, -2.2656],\n",
       "          ...,\n",
       "          [-1.4453,  0.2412,  1.3047,  ...,  0.1855, -0.6016,  0.7539],\n",
       "          [ 1.0078,  1.3359, -0.1592,  ...,  0.1299,  1.4531, -0.2578],\n",
       "          [ 0.4785, -1.0938,  2.2812,  ...,  0.8281, -0.7891,  0.4336]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_2': tensor([[[ 1.8984, -0.9219,  0.5664,  ..., -1.7188,  1.0703,  0.3633],\n",
       "          [ 1.5859,  0.0256, -0.2539,  ...,  0.7578, -0.1719, -1.7656],\n",
       "          [ 0.3105,  0.3027,  0.1611,  ..., -1.0000, -1.4688, -2.4688],\n",
       "          ...,\n",
       "          [-1.2969,  0.4453,  1.5625,  ..., -0.1934, -0.9883,  0.4082],\n",
       "          [ 1.2188,  1.4844, -0.0028,  ..., -0.4492,  1.0312, -0.9180],\n",
       "          [ 0.5820, -1.0156,  2.5156,  ...,  0.1885, -1.5391, -0.1602]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_3': tensor([0.9648], dtype=torch.bfloat16),\n",
       " 'latents_3_start': tensor([[[ 1.9062, -0.7695,  0.6133,  ..., -1.6719,  1.0000,  0.3691],\n",
       "          [ 1.6094,  0.1133, -0.1475,  ...,  0.7305, -0.1709, -1.6250],\n",
       "          [ 0.3867,  0.3809,  0.1855,  ..., -0.8906, -1.2812, -2.2344],\n",
       "          ...,\n",
       "          [-1.4297,  0.2354,  1.2891,  ...,  0.1875, -0.5898,  0.7500],\n",
       "          [ 0.9922,  1.3203, -0.1592,  ...,  0.1357,  1.4375, -0.2461],\n",
       "          [ 0.4707, -1.0781,  2.2500,  ...,  0.8242, -0.7695,  0.4355]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_3': tensor([[[ 1.8984, -0.9180,  0.5430,  ..., -1.7031,  1.0938,  0.3691],\n",
       "          [ 1.5703,  0.0308, -0.2676,  ...,  0.7812, -0.1602, -1.7422],\n",
       "          [ 0.3164,  0.2949,  0.1514,  ..., -0.9922, -1.4609, -2.4531],\n",
       "          ...,\n",
       "          [-1.3203,  0.4277,  1.5234,  ..., -0.1611, -0.9688,  0.4434],\n",
       "          [ 1.1875,  1.4609, -0.0179,  ..., -0.4355,  1.0312, -0.8867],\n",
       "          [ 0.5547, -1.0234,  2.4844,  ...,  0.2344, -1.4844, -0.1025]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_4': tensor([0.9531], dtype=torch.bfloat16),\n",
       " 'latents_4_start': tensor([[[ 1.8828, -0.7578,  0.6055,  ..., -1.6484,  0.9844,  0.3652],\n",
       "          [ 1.5859,  0.1128, -0.1445,  ...,  0.7188, -0.1689, -1.6016],\n",
       "          [ 0.3828,  0.3770,  0.1836,  ..., -0.8789, -1.2656, -2.2031],\n",
       "          ...,\n",
       "          [-1.4141,  0.2305,  1.2734,  ...,  0.1895, -0.5781,  0.7461],\n",
       "          [ 0.9766,  1.3047, -0.1592,  ...,  0.1416,  1.4219, -0.2354],\n",
       "          [ 0.4629, -1.0625,  2.2188,  ...,  0.8203, -0.7500,  0.4375]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_4': tensor([[[ 1.8984, -0.9141,  0.5508,  ..., -1.7109,  1.0859,  0.3672],\n",
       "          [ 1.5625,  0.0238, -0.2754,  ...,  0.7656, -0.1768, -1.7578],\n",
       "          [ 0.3105,  0.2988,  0.1602,  ..., -1.0156, -1.4766, -2.4375],\n",
       "          ...,\n",
       "          [-1.3125,  0.4316,  1.5469,  ..., -0.1621, -0.9805,  0.4141],\n",
       "          [ 1.1953,  1.4844, -0.0118,  ..., -0.4590,  1.0078, -0.9492],\n",
       "          [ 0.5703, -1.0156,  2.5156,  ...,  0.1777, -1.5469, -0.1475]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_5': tensor([0.9414], dtype=torch.bfloat16),\n",
       " 'latents_5_start': tensor([[[ 1.8594, -0.7461,  0.5977,  ..., -1.6250,  0.9688,  0.3613],\n",
       "          [ 1.5625,  0.1123, -0.1406,  ...,  0.7109, -0.1670, -1.5781],\n",
       "          [ 0.3789,  0.3730,  0.1816,  ..., -0.8672, -1.2500, -2.1719],\n",
       "          ...,\n",
       "          [-1.3984,  0.2246,  1.2500,  ...,  0.1914, -0.5664,  0.7422],\n",
       "          [ 0.9609,  1.2891, -0.1592,  ...,  0.1475,  1.4062, -0.2236],\n",
       "          [ 0.4551, -1.0469,  2.1875,  ...,  0.8164, -0.7305,  0.4395]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_5': tensor([[[ 1.8906, -0.8984,  0.5586,  ..., -1.7031,  1.0391,  0.3516],\n",
       "          [ 1.5625,  0.0388, -0.2871,  ...,  0.8008, -0.1504, -1.7734],\n",
       "          [ 0.3008,  0.2949,  0.1777,  ..., -1.0312, -1.5781, -2.5469],\n",
       "          ...,\n",
       "          [-1.3047,  0.4688,  1.5938,  ..., -0.2188, -1.0781,  0.3945],\n",
       "          [ 1.2031,  1.4844,  0.0082,  ..., -0.5469,  0.9414, -1.1875],\n",
       "          [ 0.5781, -0.9336,  2.5625,  ..., -0.0903, -1.8047, -0.3828]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_6': tensor([0.9258], dtype=torch.bfloat16),\n",
       " 'latents_6_start': tensor([[[ 1.8359, -0.7344,  0.5898,  ..., -1.6016,  0.9570,  0.3574],\n",
       "          [ 1.5391,  0.1118, -0.1367,  ...,  0.6992, -0.1650, -1.5547],\n",
       "          [ 0.3750,  0.3691,  0.1797,  ..., -0.8555, -1.2266, -2.1406],\n",
       "          ...,\n",
       "          [-1.3828,  0.2188,  1.2266,  ...,  0.1943, -0.5508,  0.7383],\n",
       "          [ 0.9453,  1.2734, -0.1592,  ...,  0.1543,  1.3906, -0.2080],\n",
       "          [ 0.4473, -1.0312,  2.1562,  ...,  0.8164, -0.7070,  0.4453]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_6': tensor([[[ 1.9219, -0.8828,  0.5820,  ..., -1.7109,  1.0781,  0.3613],\n",
       "          [ 1.5703,  0.0359, -0.2812,  ...,  0.7773, -0.1865, -1.8203],\n",
       "          [ 0.3301,  0.2949,  0.1924,  ..., -1.0781, -1.6016, -2.5312],\n",
       "          ...,\n",
       "          [-1.2734,  0.4941,  1.6094,  ..., -0.2236, -1.0391,  0.3633],\n",
       "          [ 1.2656,  1.5469,  0.0796,  ..., -0.6797,  0.8672, -1.2656],\n",
       "          [ 0.6484, -0.9102,  2.5938,  ..., -0.1904, -1.8516, -0.4590]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_7': tensor([0.9102], dtype=torch.bfloat16),\n",
       " 'latents_7_start': tensor([[[ 1.8125, -0.7227,  0.5820,  ..., -1.5781,  0.9414,  0.3535],\n",
       "          [ 1.5156,  0.1113, -0.1328,  ...,  0.6875, -0.1621, -1.5312],\n",
       "          [ 0.3711,  0.3652,  0.1768,  ..., -0.8398, -1.2031, -2.1094],\n",
       "          ...,\n",
       "          [-1.3672,  0.2119,  1.2031,  ...,  0.1973, -0.5352,  0.7344],\n",
       "          [ 0.9297,  1.2500, -0.1602,  ...,  0.1631,  1.3828, -0.1914],\n",
       "          [ 0.4395, -1.0156,  2.1250,  ...,  0.8203, -0.6836,  0.4512]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_7': tensor([[[ 1.9531, -0.8906,  0.5938,  ..., -1.7266,  1.0938,  0.4180],\n",
       "          [ 1.5781,  0.0309, -0.3008,  ...,  0.7969, -0.1699, -1.8281],\n",
       "          [ 0.3262,  0.3008,  0.2314,  ..., -1.0781, -1.6797, -2.6250],\n",
       "          ...,\n",
       "          [-1.2812,  0.5039,  1.5938,  ..., -0.2314, -1.0547,  0.3828],\n",
       "          [ 1.2734,  1.5781,  0.0859,  ..., -0.7930,  0.8555, -1.3516],\n",
       "          [ 0.6914, -0.9062,  2.6250,  ..., -0.2598, -1.8516, -0.4902]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_8': tensor([0.8984], dtype=torch.bfloat16),\n",
       " 'latents_8_start': tensor([[[ 1.7891, -0.7109,  0.5742,  ..., -1.5547,  0.9258,  0.3477],\n",
       "          [ 1.4922,  0.1108, -0.1289,  ...,  0.6758, -0.1602, -1.5078],\n",
       "          [ 0.3672,  0.3613,  0.1738,  ..., -0.8242, -1.1797, -2.0781],\n",
       "          ...,\n",
       "          [-1.3516,  0.2051,  1.1797,  ...,  0.2002, -0.5195,  0.7305],\n",
       "          [ 0.9141,  1.2266, -0.1611,  ...,  0.1738,  1.3750, -0.1729],\n",
       "          [ 0.4297, -1.0000,  2.0938,  ...,  0.8242, -0.6602,  0.4570]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_8': tensor([[[ 1.9453, -0.8789,  0.6094,  ..., -1.7266,  1.0781,  0.4082],\n",
       "          [ 1.5703,  0.0396, -0.3047,  ...,  0.7891, -0.1826, -1.8516],\n",
       "          [ 0.3164,  0.2949,  0.2500,  ..., -1.0859, -1.7031, -2.6250],\n",
       "          ...,\n",
       "          [-1.2578,  0.5234,  1.5938,  ..., -0.2246, -1.0547,  0.3770],\n",
       "          [ 1.2734,  1.6016,  0.0884,  ..., -0.8828,  0.8086, -1.3672],\n",
       "          [ 0.7070, -0.8828,  2.6094,  ..., -0.2832, -1.8750, -0.5117]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_9': tensor([0.8828], dtype=torch.bfloat16),\n",
       " 'latents_9_start': tensor([[[ 1.7656, -0.6992,  0.5664,  ..., -1.5312,  0.9102,  0.3418],\n",
       "          [ 1.4688,  0.1104, -0.1245,  ...,  0.6641, -0.1572, -1.4844],\n",
       "          [ 0.3633,  0.3574,  0.1699,  ..., -0.8086, -1.1562, -2.0469],\n",
       "          ...,\n",
       "          [-1.3359,  0.1982,  1.1562,  ...,  0.2031, -0.5039,  0.7266],\n",
       "          [ 0.8984,  1.2031, -0.1621,  ...,  0.1855,  1.3672, -0.1543],\n",
       "          [ 0.4199, -0.9883,  2.0625,  ...,  0.8281, -0.6328,  0.4648]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_9': tensor([[[ 1.9531, -0.8828,  0.6172,  ..., -1.7188,  1.0391,  0.4141],\n",
       "          [ 1.5703,  0.0583, -0.3125,  ...,  0.7930, -0.1582, -1.8594],\n",
       "          [ 0.3203,  0.2910,  0.2598,  ..., -1.1016, -1.7500, -2.6562],\n",
       "          ...,\n",
       "          [-1.2422,  0.5273,  1.6094,  ..., -0.2168, -1.0391,  0.4121],\n",
       "          [ 1.2656,  1.6172,  0.1001,  ..., -0.8984,  0.8008, -1.4062],\n",
       "          [ 0.7383, -0.8750,  2.6250,  ..., -0.2891, -1.8672, -0.5312]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_10': tensor([0.8711], dtype=torch.bfloat16),\n",
       " 'latents_10_start': tensor([[[ 1.7344, -0.6875,  0.5586,  ..., -1.5078,  0.8945,  0.3359],\n",
       "          [ 1.4453,  0.1094, -0.1201,  ...,  0.6523, -0.1553, -1.4609],\n",
       "          [ 0.3594,  0.3535,  0.1660,  ..., -0.7930, -1.1328, -2.0156],\n",
       "          ...,\n",
       "          [-1.3203,  0.1904,  1.1328,  ...,  0.2061, -0.4902,  0.7227],\n",
       "          [ 0.8789,  1.1797, -0.1631,  ...,  0.1982,  1.3594, -0.1348],\n",
       "          [ 0.4102, -0.9766,  2.0312,  ...,  0.8320, -0.6055,  0.4727]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_10': tensor([[[ 1.9609, -0.8672,  0.6445,  ..., -1.7188,  1.0156,  0.4180],\n",
       "          [ 1.5781,  0.0728, -0.3125,  ...,  0.7812, -0.1602, -1.8828],\n",
       "          [ 0.3125,  0.2832,  0.2832,  ..., -1.1172, -1.7734, -2.6875],\n",
       "          ...,\n",
       "          [-1.2500,  0.5273,  1.6016,  ..., -0.2227, -1.0625,  0.4062],\n",
       "          [ 1.2500,  1.6094,  0.1016,  ..., -0.9297,  0.8086, -1.4219],\n",
       "          [ 0.7617, -0.8555,  2.6406,  ..., -0.3105, -1.8594, -0.5352]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_11': tensor([0.8555], dtype=torch.bfloat16),\n",
       " 'latents_11_start': tensor([[[ 1.7031, -0.6758,  0.5508,  ..., -1.4844,  0.8789,  0.3301],\n",
       "          [ 1.4219,  0.1084, -0.1157,  ...,  0.6406, -0.1533, -1.4375],\n",
       "          [ 0.3555,  0.3496,  0.1621,  ..., -0.7773, -1.1094, -1.9766],\n",
       "          ...,\n",
       "          [-1.3047,  0.1826,  1.1094,  ...,  0.2090, -0.4746,  0.7188],\n",
       "          [ 0.8594,  1.1562, -0.1641,  ...,  0.2119,  1.3516, -0.1143],\n",
       "          [ 0.3984, -0.9648,  1.9922,  ...,  0.8359, -0.5781,  0.4805]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_11': tensor([[[ 1.9688, -0.8516,  0.6484,  ..., -1.7266,  1.0000,  0.4082],\n",
       "          [ 1.5938,  0.0679, -0.3105,  ...,  0.8086, -0.1455, -1.8984],\n",
       "          [ 0.3203,  0.2812,  0.2949,  ..., -1.1094, -1.7812, -2.6719],\n",
       "          ...,\n",
       "          [-1.2500,  0.5273,  1.5938,  ..., -0.2119, -1.0625,  0.4102],\n",
       "          [ 1.2656,  1.6016,  0.1011,  ..., -0.9180,  0.8281, -1.4531],\n",
       "          [ 0.7695, -0.8320,  2.6562,  ..., -0.2891, -1.8516, -0.5234]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_12': tensor([0.8438], dtype=torch.bfloat16),\n",
       " 'latents_12_start': tensor([[[ 1.6719, -0.6641,  0.5430,  ..., -1.4609,  0.8633,  0.3242],\n",
       "          [ 1.3984,  0.1074, -0.1113,  ...,  0.6289, -0.1514, -1.4062],\n",
       "          [ 0.3516,  0.3457,  0.1582,  ..., -0.7617, -1.0859, -1.9375],\n",
       "          ...,\n",
       "          [-1.2891,  0.1748,  1.0859,  ...,  0.2119, -0.4590,  0.7109],\n",
       "          [ 0.8398,  1.1328, -0.1660,  ...,  0.2256,  1.3359, -0.0933],\n",
       "          [ 0.3867, -0.9531,  1.9531,  ...,  0.8398, -0.5508,  0.4883]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_12': tensor([[[ 1.9688, -0.8477,  0.6602,  ..., -1.7422,  0.9805,  0.3965],\n",
       "          [ 1.5938,  0.0845, -0.3066,  ...,  0.7891, -0.1816, -1.9062],\n",
       "          [ 0.3105,  0.2754,  0.3242,  ..., -1.1328, -1.8047, -2.6875],\n",
       "          ...,\n",
       "          [-1.2422,  0.5195,  1.5938,  ..., -0.2227, -1.0625,  0.4180],\n",
       "          [ 1.2500,  1.6172,  0.1138,  ..., -0.9492,  0.8281, -1.4609],\n",
       "          [ 0.7852, -0.8555,  2.6562,  ..., -0.3047, -1.8438, -0.5430]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_13': tensor([0.8281], dtype=torch.bfloat16),\n",
       " 'latents_13_start': tensor([[[ 1.6406, -0.6523,  0.5312,  ..., -1.4375,  0.8477,  0.3184],\n",
       "          [ 1.3750,  0.1060, -0.1069,  ...,  0.6172, -0.1484, -1.3750],\n",
       "          [ 0.3477,  0.3418,  0.1533,  ..., -0.7461, -1.0625, -1.8984],\n",
       "          ...,\n",
       "          [-1.2734,  0.1670,  1.0625,  ...,  0.2148, -0.4434,  0.7031],\n",
       "          [ 0.8203,  1.1094, -0.1680,  ...,  0.2393,  1.3203, -0.0718],\n",
       "          [ 0.3750, -0.9414,  1.9141,  ...,  0.8438, -0.5234,  0.4961]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_13': tensor([[[ 1.9688, -0.8438,  0.6562,  ..., -1.7500,  0.9805,  0.3906],\n",
       "          [ 1.5938,  0.0791, -0.3066,  ...,  0.7734, -0.1934, -1.9062],\n",
       "          [ 0.3145,  0.2734,  0.3203,  ..., -1.1250, -1.8359, -2.7188],\n",
       "          ...,\n",
       "          [-1.2422,  0.5156,  1.6094,  ..., -0.2021, -1.0312,  0.4609],\n",
       "          [ 1.2656,  1.6250,  0.1108,  ..., -0.9453,  0.8789, -1.4688],\n",
       "          [ 0.7930, -0.8594,  2.6719,  ..., -0.3105, -1.8359, -0.5469]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_14': tensor([0.8125], dtype=torch.bfloat16),\n",
       " 'latents_14_start': tensor([[[ 1.6094, -0.6406,  0.5195,  ..., -1.4141,  0.8320,  0.3125],\n",
       "          [ 1.3516,  0.1050, -0.1025,  ...,  0.6055, -0.1455, -1.3438],\n",
       "          [ 0.3438,  0.3379,  0.1484,  ..., -0.7305, -1.0312, -1.8594],\n",
       "          ...,\n",
       "          [-1.2578,  0.1592,  1.0391,  ...,  0.2178, -0.4277,  0.6953],\n",
       "          [ 0.8008,  1.0859, -0.1699,  ...,  0.2539,  1.3047, -0.0498],\n",
       "          [ 0.3633, -0.9297,  1.8750,  ...,  0.8477, -0.4961,  0.5039]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_14': tensor([[[ 1.9609, -0.8438,  0.6562,  ..., -1.7500,  0.9727,  0.3828],\n",
       "          [ 1.5938,  0.0840, -0.3203,  ...,  0.7695, -0.2061, -1.8906],\n",
       "          [ 0.3125,  0.2754,  0.3262,  ..., -1.1328, -1.8359, -2.7031],\n",
       "          ...,\n",
       "          [-1.2422,  0.5117,  1.6094,  ..., -0.2002, -1.0156,  0.4746],\n",
       "          [ 1.2656,  1.6172,  0.1108,  ..., -0.9219,  0.8945, -1.4609],\n",
       "          [ 0.7969, -0.8672,  2.6406,  ..., -0.3047, -1.7969, -0.5430]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_15': tensor([0.7969], dtype=torch.bfloat16),\n",
       " 'latents_15_start': tensor([[[ 1.5781, -0.6289,  0.5078,  ..., -1.3906,  0.8164,  0.3066],\n",
       "          [ 1.3281,  0.1035, -0.0977,  ...,  0.5938, -0.1426, -1.3125],\n",
       "          [ 0.3398,  0.3340,  0.1436,  ..., -0.7148, -1.0000, -1.8203],\n",
       "          ...,\n",
       "          [-1.2422,  0.1514,  1.0156,  ...,  0.2207, -0.4121,  0.6875],\n",
       "          [ 0.7812,  1.0625, -0.1719,  ...,  0.2676,  1.2891, -0.0275],\n",
       "          [ 0.3516, -0.9180,  1.8359,  ...,  0.8516, -0.4688,  0.5117]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_15': tensor([[[ 1.9531, -0.8320,  0.6641,  ..., -1.7578,  0.9570,  0.3789],\n",
       "          [ 1.5938,  0.0806, -0.3184,  ...,  0.7500, -0.2031, -1.8750],\n",
       "          [ 0.3242,  0.2656,  0.3301,  ..., -1.1406, -1.8359, -2.7031],\n",
       "          ...,\n",
       "          [-1.2578,  0.5195,  1.6328,  ..., -0.1875, -0.9883,  0.5117],\n",
       "          [ 1.2734,  1.6094,  0.1230,  ..., -0.9297,  0.9102, -1.4531],\n",
       "          [ 0.7930, -0.8633,  2.6406,  ..., -0.3027, -1.8203, -0.5312]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_16': tensor([0.7812], dtype=torch.bfloat16),\n",
       " 'latents_16_start': tensor([[[ 1.5469, -0.6172,  0.4980,  ..., -1.3594,  0.8008,  0.3008],\n",
       "          [ 1.3047,  0.1021, -0.0928,  ...,  0.5820, -0.1396, -1.2812],\n",
       "          [ 0.3340,  0.3301,  0.1387,  ..., -0.6953, -0.9727, -1.7812],\n",
       "          ...,\n",
       "          [-1.2188,  0.1436,  0.9883,  ...,  0.2236, -0.3965,  0.6797],\n",
       "          [ 0.7617,  1.0391, -0.1738,  ...,  0.2812,  1.2734, -0.0048],\n",
       "          [ 0.3398, -0.9062,  1.7969,  ...,  0.8555, -0.4395,  0.5195]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_16': tensor([[[ 1.9766, -0.8320,  0.6758,  ..., -1.7734,  0.9844,  0.3867],\n",
       "          [ 1.6094,  0.0923, -0.3164,  ...,  0.7617, -0.2148, -1.8984],\n",
       "          [ 0.3281,  0.2695,  0.3398,  ..., -1.1484, -1.8516, -2.7500],\n",
       "          ...,\n",
       "          [-1.2500,  0.5156,  1.6406,  ..., -0.1953, -0.9648,  0.5273],\n",
       "          [ 1.3125,  1.6250,  0.1113,  ..., -0.9102,  0.9414, -1.4609],\n",
       "          [ 0.7969, -0.8750,  2.6719,  ..., -0.2988, -1.7891, -0.5469]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_17': tensor([0.7656], dtype=torch.bfloat16),\n",
       " 'latents_17_start': tensor([[[ 1.5156, -0.6055,  0.4883,  ..., -1.3281,  0.7852,  0.2949],\n",
       "          [ 1.2812,  0.1006, -0.0879,  ...,  0.5703, -0.1367, -1.2500],\n",
       "          [ 0.3281,  0.3262,  0.1328,  ..., -0.6758, -0.9414, -1.7344],\n",
       "          ...,\n",
       "          [-1.1953,  0.1357,  0.9609,  ...,  0.2266, -0.3809,  0.6719],\n",
       "          [ 0.7422,  1.0156, -0.1758,  ...,  0.2949,  1.2578,  0.0184],\n",
       "          [ 0.3281, -0.8906,  1.7578,  ...,  0.8594, -0.4102,  0.5273]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_17': tensor([[[ 1.9688, -0.8242,  0.6719,  ..., -1.7578,  0.9688,  0.3691],\n",
       "          [ 1.6094,  0.0869, -0.3145,  ...,  0.7500, -0.2217, -1.8828],\n",
       "          [ 0.3203,  0.2754,  0.3457,  ..., -1.1406, -1.8516, -2.7500],\n",
       "          ...,\n",
       "          [-1.2266,  0.5156,  1.6250,  ..., -0.1904, -0.9492,  0.5273],\n",
       "          [ 1.3047,  1.6172,  0.1040,  ..., -0.9141,  0.9570, -1.4531],\n",
       "          [ 0.7852, -0.8633,  2.6562,  ..., -0.2949, -1.7969, -0.5430]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_18': tensor([0.7461], dtype=torch.bfloat16),\n",
       " 'latents_18_start': tensor([[[ 1.4844, -0.5938,  0.4766,  ..., -1.2969,  0.7695,  0.2891],\n",
       "          [ 1.2578,  0.0991, -0.0830,  ...,  0.5586, -0.1328, -1.2188],\n",
       "          [ 0.3223,  0.3223,  0.1270,  ..., -0.6562, -0.9102, -1.6875],\n",
       "          ...,\n",
       "          [-1.1719,  0.1270,  0.9336,  ...,  0.2295, -0.3652,  0.6641],\n",
       "          [ 0.7227,  0.9883, -0.1777,  ...,  0.3105,  1.2422,  0.0420],\n",
       "          [ 0.3145, -0.8750,  1.7109,  ...,  0.8633, -0.3809,  0.5352]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_18': tensor([[[ 1.9844, -0.8398,  0.6680,  ..., -1.7578,  0.9727,  0.3730],\n",
       "          [ 1.6172,  0.0752, -0.3184,  ...,  0.7578, -0.2148, -1.8828],\n",
       "          [ 0.3066,  0.2715,  0.3398,  ..., -1.1328, -1.8516, -2.7500],\n",
       "          ...,\n",
       "          [-1.2422,  0.5156,  1.6484,  ..., -0.1777, -0.9336,  0.5625],\n",
       "          [ 1.3047,  1.6328,  0.1147,  ..., -0.8906,  0.9883, -1.4688],\n",
       "          [ 0.7734, -0.8672,  2.6406,  ..., -0.2910, -1.7891, -0.5312]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_19': tensor([0.7305], dtype=torch.bfloat16),\n",
       " 'latents_19_start': tensor([[[ 1.4531, -0.5781,  0.4648,  ..., -1.2656,  0.7539,  0.2832],\n",
       "          [ 1.2344,  0.0977, -0.0776,  ...,  0.5469, -0.1289, -1.1875],\n",
       "          [ 0.3164,  0.3184,  0.1211,  ..., -0.6367, -0.8789, -1.6406],\n",
       "          ...,\n",
       "          [-1.1484,  0.1182,  0.9062,  ...,  0.2324, -0.3496,  0.6562],\n",
       "          [ 0.6992,  0.9609, -0.1797,  ...,  0.3262,  1.2266,  0.0664],\n",
       "          [ 0.3008, -0.8594,  1.6641,  ...,  0.8672, -0.3516,  0.5430]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_19': tensor([[[ 1.9844, -0.8516,  0.6641,  ..., -1.7656,  0.9688,  0.3770],\n",
       "          [ 1.6094,  0.0684, -0.3281,  ...,  0.7734, -0.2119, -1.8672],\n",
       "          [ 0.3086,  0.2695,  0.3301,  ..., -1.1484, -1.8438, -2.7188],\n",
       "          ...,\n",
       "          [-1.2422,  0.5117,  1.6484,  ..., -0.1738, -0.8984,  0.5781],\n",
       "          [ 1.3203,  1.6328,  0.1157,  ..., -0.8828,  1.0078, -1.4844],\n",
       "          [ 0.7617, -0.8672,  2.6719,  ..., -0.2480, -1.8125, -0.5273]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_20': tensor([0.7148], dtype=torch.bfloat16),\n",
       " 'latents_20_start': tensor([[[ 1.4219, -0.5625,  0.4531,  ..., -1.2344,  0.7383,  0.2773],\n",
       "          [ 1.2109,  0.0967, -0.0723,  ...,  0.5352, -0.1250, -1.1562],\n",
       "          [ 0.3105,  0.3145,  0.1157,  ..., -0.6172, -0.8477, -1.5938],\n",
       "          ...,\n",
       "          [-1.1250,  0.1094,  0.8789,  ...,  0.2354, -0.3340,  0.6484],\n",
       "          [ 0.6758,  0.9336, -0.1816,  ...,  0.3418,  1.2109,  0.0913],\n",
       "          [ 0.2871, -0.8438,  1.6172,  ...,  0.8711, -0.3203,  0.5508]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_20': tensor([[[ 1.9766, -0.8438,  0.6562,  ..., -1.7656,  0.9570,  0.3809],\n",
       "          [ 1.6094,  0.0713, -0.3340,  ...,  0.7734, -0.2246, -1.8594],\n",
       "          [ 0.2910,  0.2598,  0.3281,  ..., -1.1250, -1.8359, -2.7500],\n",
       "          ...,\n",
       "          [-1.2422,  0.5039,  1.6406,  ..., -0.1738, -0.8867,  0.6016],\n",
       "          [ 1.3125,  1.6172,  0.1187,  ..., -0.8672,  1.0156, -1.4922],\n",
       "          [ 0.7578, -0.8711,  2.6719,  ..., -0.2559, -1.7891, -0.5547]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_21': tensor([0.6992], dtype=torch.bfloat16),\n",
       " 'latents_21_start': tensor([[[ 1.3906, -0.5469,  0.4414,  ..., -1.2031,  0.7227,  0.2715],\n",
       "          [ 1.1797,  0.0952, -0.0664,  ...,  0.5234, -0.1211, -1.1250],\n",
       "          [ 0.3047,  0.3105,  0.1099,  ..., -0.5977, -0.8164, -1.5469],\n",
       "          ...,\n",
       "          [-1.1016,  0.1006,  0.8516,  ...,  0.2383, -0.3184,  0.6367],\n",
       "          [ 0.6523,  0.9062, -0.1836,  ...,  0.3574,  1.1953,  0.1172],\n",
       "          [ 0.2734, -0.8281,  1.5703,  ...,  0.8750, -0.2891,  0.5586]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_21': tensor([[[ 1.9844, -0.8281,  0.6680,  ..., -1.7578,  0.9375,  0.3457],\n",
       "          [ 1.5938,  0.0811, -0.3203,  ...,  0.7656, -0.2207, -1.8594],\n",
       "          [ 0.2773,  0.2559,  0.3340,  ..., -1.1328, -1.8438, -2.7344],\n",
       "          ...,\n",
       "          [-1.2266,  0.4961,  1.6406,  ..., -0.1953, -0.8750,  0.5859],\n",
       "          [ 1.2969,  1.6250,  0.1147,  ..., -0.8594,  1.0156, -1.4922],\n",
       "          [ 0.7578, -0.8711,  2.6719,  ..., -0.2617, -1.7891, -0.5508]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_22': tensor([0.6797], dtype=torch.bfloat16),\n",
       " 'latents_22_start': tensor([[[ 1.3594, -0.5312,  0.4297,  ..., -1.1719,  0.7070,  0.2656],\n",
       "          [ 1.1484,  0.0938, -0.0608,  ...,  0.5117, -0.1172, -1.0938],\n",
       "          [ 0.3008,  0.3066,  0.1040,  ..., -0.5781, -0.7852, -1.5000],\n",
       "          ...,\n",
       "          [-1.0781,  0.0918,  0.8242,  ...,  0.2422, -0.3027,  0.6250],\n",
       "          [ 0.6289,  0.8789, -0.1855,  ...,  0.3730,  1.1797,  0.1436],\n",
       "          [ 0.2598, -0.8125,  1.5234,  ...,  0.8789, -0.2578,  0.5664]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_22': tensor([[[ 1.9922, -0.8242,  0.6523,  ..., -1.7500,  0.9375,  0.3477],\n",
       "          [ 1.5859,  0.0757, -0.3379,  ...,  0.7578, -0.2178, -1.8438],\n",
       "          [ 0.2930,  0.2520,  0.3320,  ..., -1.1250, -1.8516, -2.7500],\n",
       "          ...,\n",
       "          [-1.2031,  0.5000,  1.6406,  ..., -0.2012, -0.8750,  0.5820],\n",
       "          [ 1.3047,  1.6094,  0.1309,  ..., -0.8555,  1.0234, -1.5078],\n",
       "          [ 0.7617, -0.8711,  2.6562,  ..., -0.2793, -1.7969, -0.5742]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_23': tensor([0.6641], dtype=torch.bfloat16),\n",
       " 'latents_23_start': tensor([[[ 1.3203, -0.5156,  0.4180,  ..., -1.1406,  0.6914,  0.2598],\n",
       "          [ 1.1172,  0.0923, -0.0547,  ...,  0.4980, -0.1133, -1.0625],\n",
       "          [ 0.2949,  0.3027,  0.0981,  ..., -0.5586, -0.7500, -1.4531],\n",
       "          ...,\n",
       "          [-1.0547,  0.0830,  0.7930,  ...,  0.2461, -0.2871,  0.6133],\n",
       "          [ 0.6055,  0.8516, -0.1875,  ...,  0.3887,  1.1641,  0.1709],\n",
       "          [ 0.2461, -0.7969,  1.4766,  ...,  0.8828, -0.2256,  0.5781]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_23': tensor([[[ 1.9688, -0.8203,  0.6562,  ..., -1.7422,  0.9336,  0.3359],\n",
       "          [ 1.5781,  0.0923, -0.3164,  ...,  0.7617, -0.2188, -1.8438],\n",
       "          [ 0.2969,  0.2617,  0.3203,  ..., -1.1328, -1.8594, -2.7500],\n",
       "          ...,\n",
       "          [-1.2109,  0.5117,  1.6406,  ..., -0.1914, -0.8711,  0.5938],\n",
       "          [ 1.2891,  1.6094,  0.1182,  ..., -0.8477,  1.0469, -1.5000],\n",
       "          [ 0.7461, -0.8945,  2.6562,  ..., -0.2852, -1.8047, -0.5586]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_24': tensor([0.6445], dtype=torch.bfloat16),\n",
       " 'latents_24_start': tensor([[[ 1.2812, -0.5000,  0.4062,  ..., -1.1094,  0.6758,  0.2539],\n",
       "          [ 1.0859,  0.0908, -0.0488,  ...,  0.4844, -0.1094, -1.0312],\n",
       "          [ 0.2891,  0.2988,  0.0923,  ..., -0.5391, -0.7148, -1.4062],\n",
       "          ...,\n",
       "          [-1.0312,  0.0737,  0.7617,  ...,  0.2500, -0.2715,  0.6016],\n",
       "          [ 0.5820,  0.8203, -0.1895,  ...,  0.4043,  1.1484,  0.1982],\n",
       "          [ 0.2324, -0.7812,  1.4297,  ...,  0.8867, -0.1924,  0.5898]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_24': tensor([[[ 1.9688, -0.8164,  0.6484,  ..., -1.7422,  0.9492,  0.3574],\n",
       "          [ 1.5703,  0.0918, -0.3262,  ...,  0.7734, -0.2207, -1.8438],\n",
       "          [ 0.2871,  0.2637,  0.3340,  ..., -1.1250, -1.8516, -2.7656],\n",
       "          ...,\n",
       "          [-1.2031,  0.4961,  1.6328,  ..., -0.1768, -0.8555,  0.6055],\n",
       "          [ 1.2891,  1.6172,  0.1211,  ..., -0.8516,  1.0625, -1.5000],\n",
       "          [ 0.7461, -0.8867,  2.6562,  ..., -0.2891, -1.8047, -0.5391]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_25': tensor([0.6289], dtype=torch.bfloat16),\n",
       " 'latents_25_start': tensor([[[ 1.2422, -0.4844,  0.3945,  ..., -1.0781,  0.6562,  0.2471],\n",
       "          [ 1.0547,  0.0889, -0.0427,  ...,  0.4707, -0.1055, -0.9961],\n",
       "          [ 0.2832,  0.2930,  0.0859,  ..., -0.5195, -0.6797, -1.3516],\n",
       "          ...,\n",
       "          [-1.0078,  0.0645,  0.7305,  ...,  0.2539, -0.2559,  0.5898],\n",
       "          [ 0.5586,  0.7891, -0.1914,  ...,  0.4199,  1.1250,  0.2266],\n",
       "          [ 0.2188, -0.7656,  1.3828,  ...,  0.8906, -0.1582,  0.6016]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_25': tensor([[[ 1.9609, -0.8242,  0.6523,  ..., -1.7422,  0.9258,  0.3418],\n",
       "          [ 1.5625,  0.0850, -0.3359,  ...,  0.7812, -0.2295, -1.8516],\n",
       "          [ 0.2871,  0.2520,  0.3184,  ..., -1.1250, -1.8359, -2.7500],\n",
       "          ...,\n",
       "          [-1.1875,  0.4863,  1.6250,  ..., -0.1924, -0.8633,  0.6055],\n",
       "          [ 1.2969,  1.6172,  0.1240,  ..., -0.8359,  1.0625, -1.5078],\n",
       "          [ 0.7305, -0.8789,  2.6562,  ..., -0.2969, -1.8203, -0.5469]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_26': tensor([0.6094], dtype=torch.bfloat16),\n",
       " 'latents_26_start': tensor([[[ 1.2031, -0.4688,  0.3828,  ..., -1.0469,  0.6406,  0.2402],\n",
       "          [ 1.0234,  0.0874, -0.0364,  ...,  0.4551, -0.1011, -0.9609],\n",
       "          [ 0.2773,  0.2891,  0.0801,  ..., -0.4980, -0.6445, -1.2969],\n",
       "          ...,\n",
       "          [-0.9844,  0.0552,  0.6992,  ...,  0.2578, -0.2393,  0.5781],\n",
       "          [ 0.5352,  0.7578, -0.1934,  ...,  0.4355,  1.1016,  0.2559],\n",
       "          [ 0.2051, -0.7500,  1.3359,  ...,  0.8945, -0.1235,  0.6133]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_26': tensor([[[ 1.9609, -0.8320,  0.6602,  ..., -1.7578,  0.9219,  0.3496],\n",
       "          [ 1.5703,  0.0801, -0.3359,  ...,  0.7812, -0.2266, -1.8281],\n",
       "          [ 0.2793,  0.2471,  0.3223,  ..., -1.1172, -1.8281, -2.7188],\n",
       "          ...,\n",
       "          [-1.1797,  0.4863,  1.6172,  ..., -0.2021, -0.8516,  0.6133],\n",
       "          [ 1.2969,  1.6016,  0.1226,  ..., -0.8359,  1.0625, -1.5000],\n",
       "          [ 0.7305, -0.8906,  2.6562,  ..., -0.2988, -1.8047, -0.5469]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_27': tensor([0.5898], dtype=torch.bfloat16),\n",
       " 'latents_27_start': tensor([[[ 1.1641, -0.4531,  0.3691,  ..., -1.0156,  0.6211,  0.2334],\n",
       "          [ 0.9922,  0.0859, -0.0298,  ...,  0.4395, -0.0967, -0.9258],\n",
       "          [ 0.2715,  0.2852,  0.0737,  ..., -0.4766, -0.6094, -1.2422],\n",
       "          ...,\n",
       "          [-0.9609,  0.0457,  0.6680,  ...,  0.2617, -0.2227,  0.5664],\n",
       "          [ 0.5078,  0.7266, -0.1953,  ...,  0.4512,  1.0781,  0.2852],\n",
       "          [ 0.1904, -0.7344,  1.2812,  ...,  0.8984, -0.0884,  0.6250]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_27': tensor([[[ 1.9453, -0.8398,  0.6562,  ..., -1.7578,  0.9141,  0.3398],\n",
       "          [ 1.5469,  0.0654, -0.3379,  ...,  0.7734, -0.2236, -1.8359],\n",
       "          [ 0.2773,  0.2354,  0.3223,  ..., -1.1094, -1.8359, -2.7188],\n",
       "          ...,\n",
       "          [-1.1797,  0.4844,  1.6328,  ..., -0.1992, -0.8359,  0.6172],\n",
       "          [ 1.2891,  1.5859,  0.1133,  ..., -0.8242,  1.0469, -1.4922],\n",
       "          [ 0.7148, -0.9180,  2.6406,  ..., -0.3047, -1.8203, -0.5508]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_28': tensor([0.5664], dtype=torch.bfloat16),\n",
       " 'latents_28_start': tensor([[[ 1.1250, -0.4355,  0.3555,  ..., -0.9805,  0.6016,  0.2266],\n",
       "          [ 0.9609,  0.0845, -0.0231,  ...,  0.4238, -0.0923, -0.8906],\n",
       "          [ 0.2656,  0.2812,  0.0674,  ..., -0.4551, -0.5742, -1.1875],\n",
       "          ...,\n",
       "          [-0.9375,  0.0361,  0.6367,  ...,  0.2656, -0.2061,  0.5547],\n",
       "          [ 0.4824,  0.6953, -0.1973,  ...,  0.4668,  1.0547,  0.3145],\n",
       "          [ 0.1758, -0.7148,  1.2266,  ...,  0.9062, -0.0522,  0.6367]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_28': tensor([[[ 1.9453, -0.8281,  0.6406,  ..., -1.7344,  0.9219,  0.3672],\n",
       "          [ 1.5547,  0.0684, -0.3496,  ...,  0.8047, -0.1953, -1.8281],\n",
       "          [ 0.2812,  0.2207,  0.3281,  ..., -1.1016, -1.8359, -2.7188],\n",
       "          ...,\n",
       "          [-1.1953,  0.4668,  1.6328,  ..., -0.1758, -0.8203,  0.6445],\n",
       "          [ 1.2734,  1.5781,  0.1045,  ..., -0.7969,  1.0547, -1.4922],\n",
       "          [ 0.6953, -0.9258,  2.6562,  ..., -0.3086, -1.7969, -0.5508]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_29': tensor([0.5469], dtype=torch.bfloat16),\n",
       " 'latents_29_start': tensor([[[ 1.0859, -0.4180,  0.3418,  ..., -0.9453,  0.5820,  0.2188],\n",
       "          [ 0.9297,  0.0830, -0.0159,  ...,  0.4082, -0.0884, -0.8516],\n",
       "          [ 0.2598,  0.2773,  0.0608,  ..., -0.4336, -0.5352, -1.1328],\n",
       "          ...,\n",
       "          [-0.9141,  0.0266,  0.6016,  ...,  0.2695, -0.1895,  0.5430],\n",
       "          [ 0.4570,  0.6641, -0.1992,  ...,  0.4824,  1.0312,  0.3457],\n",
       "          [ 0.1621, -0.6953,  1.1719,  ...,  0.9141, -0.0156,  0.6484]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_29': tensor([[[ 1.9688, -0.8320,  0.6445,  ..., -1.7734,  0.9219,  0.3672],\n",
       "          [ 1.5469,  0.0732, -0.3477,  ...,  0.7930, -0.2100, -1.8281],\n",
       "          [ 0.2793,  0.2354,  0.3262,  ..., -1.1250, -1.8359, -2.7188],\n",
       "          ...,\n",
       "          [-1.1953,  0.4746,  1.6172,  ..., -0.1738, -0.8086,  0.6484],\n",
       "          [ 1.2969,  1.5781,  0.0952,  ..., -0.8164,  1.0859, -1.4766],\n",
       "          [ 0.6797, -0.9180,  2.6562,  ..., -0.3145, -1.7812, -0.5508]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_30': tensor([0.5273], dtype=torch.bfloat16),\n",
       " 'latents_30_start': tensor([[[ 1.0469, -0.4004,  0.3281,  ..., -0.9102,  0.5625,  0.2109],\n",
       "          [ 0.8984,  0.0815, -0.0087,  ...,  0.3926, -0.0840, -0.8125],\n",
       "          [ 0.2539,  0.2734,  0.0540,  ..., -0.4102, -0.4961, -1.0781],\n",
       "          ...,\n",
       "          [-0.8906,  0.0168,  0.5664,  ...,  0.2734, -0.1729,  0.5312],\n",
       "          [ 0.4297,  0.6328, -0.2012,  ...,  0.5000,  1.0078,  0.3770],\n",
       "          [ 0.1484, -0.6758,  1.1172,  ...,  0.9219,  0.0212,  0.6602]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_30': tensor([[[ 1.9766, -0.8359,  0.6484,  ..., -1.7812,  0.9219,  0.3691],\n",
       "          [ 1.5469,  0.0693, -0.3359,  ...,  0.7969, -0.2090, -1.8203],\n",
       "          [ 0.2852,  0.2363,  0.3320,  ..., -1.1172, -1.8359, -2.7344],\n",
       "          ...,\n",
       "          [-1.1953,  0.4766,  1.6406,  ..., -0.1855, -0.8203,  0.6484],\n",
       "          [ 1.2891,  1.5781,  0.1064,  ..., -0.8203,  1.0859, -1.4922],\n",
       "          [ 0.6953, -0.9219,  2.6719,  ..., -0.3145, -1.7656, -0.5312]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_31': tensor([0.5078], dtype=torch.bfloat16),\n",
       " 'latents_31_start': tensor([[[ 1.0078, -0.3828,  0.3145,  ..., -0.8711,  0.5430,  0.2031],\n",
       "          [ 0.8672,  0.0801, -0.0015,  ...,  0.3750, -0.0796, -0.7734],\n",
       "          [ 0.2480,  0.2676,  0.0469,  ..., -0.3867, -0.4570, -1.0234],\n",
       "          ...,\n",
       "          [-0.8672,  0.0067,  0.5312,  ...,  0.2773, -0.1553,  0.5156],\n",
       "          [ 0.4023,  0.5977, -0.2031,  ...,  0.5156,  0.9844,  0.4082],\n",
       "          [ 0.1338, -0.6562,  1.0625,  ...,  0.9297,  0.0588,  0.6719]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_31': tensor([[[ 1.9688, -0.8242,  0.6523,  ..., -1.7656,  0.9297,  0.3555],\n",
       "          [ 1.5469,  0.0796, -0.3516,  ...,  0.7969, -0.2051, -1.8281],\n",
       "          [ 0.2754,  0.2344,  0.3262,  ..., -1.1172, -1.8359, -2.7344],\n",
       "          ...,\n",
       "          [-1.1953,  0.4766,  1.6328,  ..., -0.1895, -0.8125,  0.6289],\n",
       "          [ 1.2969,  1.5859,  0.0938,  ..., -0.8125,  1.0781, -1.4766],\n",
       "          [ 0.6836, -0.9375,  2.6562,  ..., -0.2949, -1.7734, -0.5234]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_32': tensor([0.4844], dtype=torch.bfloat16),\n",
       " 'latents_32_start': tensor([[[ 0.9648, -0.3652,  0.3008,  ..., -0.8320,  0.5234,  0.1953],\n",
       "          [ 0.8320,  0.0781,  0.0061,  ...,  0.3574, -0.0752, -0.7344],\n",
       "          [ 0.2422,  0.2617,  0.0398,  ..., -0.3633, -0.4180, -0.9648],\n",
       "          ...,\n",
       "          [-0.8398, -0.0037,  0.4961,  ...,  0.2812, -0.1377,  0.5000],\n",
       "          [ 0.3750,  0.5625, -0.2051,  ...,  0.5352,  0.9609,  0.4395],\n",
       "          [ 0.1191, -0.6367,  1.0078,  ...,  0.9375,  0.0977,  0.6836]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_32': tensor([[[ 1.9688, -0.8242,  0.6523,  ..., -1.7656,  0.9258,  0.3535],\n",
       "          [ 1.5391,  0.0728, -0.3457,  ...,  0.7891, -0.2021, -1.8203],\n",
       "          [ 0.2754,  0.2344,  0.3223,  ..., -1.1172, -1.8281, -2.7344],\n",
       "          ...,\n",
       "          [-1.1875,  0.4688,  1.6250,  ..., -0.1768, -0.8086,  0.6133],\n",
       "          [ 1.2812,  1.5703,  0.1079,  ..., -0.8320,  1.0703, -1.4844],\n",
       "          [ 0.6914, -0.9297,  2.6562,  ..., -0.3027, -1.7656, -0.5156]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_33': tensor([0.4629], dtype=torch.bfloat16),\n",
       " 'latents_33_start': tensor([[[ 0.9219, -0.3477,  0.2871,  ..., -0.7930,  0.5039,  0.1875],\n",
       "          [ 0.7969,  0.0767,  0.0138,  ...,  0.3398, -0.0708, -0.6953],\n",
       "          [ 0.2363,  0.2559,  0.0327,  ..., -0.3379, -0.3770, -0.9023],\n",
       "          ...,\n",
       "          [-0.8125, -0.0141,  0.4609,  ...,  0.2852, -0.1196,  0.4863],\n",
       "          [ 0.3457,  0.5273, -0.2070,  ...,  0.5547,  0.9375,  0.4727],\n",
       "          [ 0.1035, -0.6172,  0.9492,  ...,  0.9453,  0.1367,  0.6953]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_33': tensor([[[ 1.9609, -0.8242,  0.6484,  ..., -1.7578,  0.9258,  0.3477],\n",
       "          [ 1.5312,  0.0684, -0.3379,  ...,  0.7891, -0.2012, -1.8125],\n",
       "          [ 0.2812,  0.2158,  0.3164,  ..., -1.1172, -1.8125, -2.7188],\n",
       "          ...,\n",
       "          [-1.1797,  0.4570,  1.6250,  ..., -0.1826, -0.8086,  0.6055],\n",
       "          [ 1.2812,  1.5625,  0.1025,  ..., -0.8125,  1.0625, -1.4922],\n",
       "          [ 0.6797, -0.9297,  2.6250,  ..., -0.2949, -1.7656, -0.5117]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_34': tensor([0.4375], dtype=torch.bfloat16),\n",
       " 'latents_34_start': tensor([[[ 0.8789, -0.3281,  0.2715,  ..., -0.7539,  0.4824,  0.1797],\n",
       "          [ 0.7617,  0.0752,  0.0215,  ...,  0.3223, -0.0664, -0.6523],\n",
       "          [ 0.2295,  0.2500,  0.0255,  ..., -0.3125, -0.3359, -0.8398],\n",
       "          ...,\n",
       "          [-0.7852, -0.0245,  0.4238,  ...,  0.2891, -0.1011,  0.4727],\n",
       "          [ 0.3164,  0.4922, -0.2090,  ...,  0.5742,  0.9141,  0.5078],\n",
       "          [ 0.0879, -0.5977,  0.8906,  ...,  0.9531,  0.1768,  0.7070]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_34': tensor([[[ 1.9766, -0.8242,  0.6523,  ..., -1.7812,  0.9258,  0.3535],\n",
       "          [ 1.5312,  0.0640, -0.3457,  ...,  0.8086, -0.1934, -1.8047],\n",
       "          [ 0.2715,  0.1992,  0.3105,  ..., -1.0938, -1.8047, -2.7344],\n",
       "          ...,\n",
       "          [-1.1875,  0.4609,  1.6172,  ..., -0.1953, -0.8164,  0.6172],\n",
       "          [ 1.2812,  1.5625,  0.0942,  ..., -0.8242,  1.0781, -1.4922],\n",
       "          [ 0.6562, -0.9531,  2.6406,  ..., -0.3008, -1.7734, -0.4980]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_35': tensor([0.4160], dtype=torch.bfloat16),\n",
       " 'latents_35_start': tensor([[[ 0.8320, -0.3086,  0.2559,  ..., -0.7109,  0.4609,  0.1719],\n",
       "          [ 0.7266,  0.0737,  0.0295,  ...,  0.3027, -0.0620, -0.6094],\n",
       "          [ 0.2236,  0.2451,  0.0183,  ..., -0.2871, -0.2930, -0.7773],\n",
       "          ...,\n",
       "          [-0.7578, -0.0352,  0.3867,  ...,  0.2930, -0.0820,  0.4590],\n",
       "          [ 0.2871,  0.4551, -0.2109,  ...,  0.5938,  0.8906,  0.5430],\n",
       "          [ 0.0728, -0.5742,  0.8281,  ...,  0.9609,  0.2178,  0.7188]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_35': tensor([[[ 1.9766, -0.8242,  0.6484,  ..., -1.7734,  0.9258,  0.3496],\n",
       "          [ 1.5312,  0.0708, -0.3418,  ...,  0.8047, -0.1953, -1.7969],\n",
       "          [ 0.2832,  0.1875,  0.3086,  ..., -1.0938, -1.7891, -2.7344],\n",
       "          ...,\n",
       "          [-1.1719,  0.4551,  1.6172,  ..., -0.1953, -0.8047,  0.6055],\n",
       "          [ 1.2578,  1.5625,  0.0898,  ..., -0.8242,  1.0859, -1.4766],\n",
       "          [ 0.6602, -0.9414,  2.6406,  ..., -0.3164, -1.7656, -0.5078]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_36': tensor([0.3926], dtype=torch.bfloat16),\n",
       " 'latents_36_start': tensor([[[ 0.7852, -0.2891,  0.2402,  ..., -0.6680,  0.4395,  0.1641],\n",
       "          [ 0.6914,  0.0723,  0.0376,  ...,  0.2832, -0.0574, -0.5664],\n",
       "          [ 0.2168,  0.2402,  0.0110,  ..., -0.2617, -0.2500, -0.7109],\n",
       "          ...,\n",
       "          [-0.7305, -0.0459,  0.3477,  ...,  0.2969, -0.0630,  0.4453],\n",
       "          [ 0.2578,  0.4180, -0.2129,  ...,  0.6133,  0.8633,  0.5781],\n",
       "          [ 0.0571, -0.5508,  0.7656,  ...,  0.9688,  0.2598,  0.7305]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_36': tensor([[[ 1.9609, -0.8164,  0.6445,  ..., -1.7656,  0.9102,  0.3477],\n",
       "          [ 1.5234,  0.0654, -0.3320,  ...,  0.8164, -0.2041, -1.7812],\n",
       "          [ 0.2734,  0.1836,  0.3066,  ..., -1.0938, -1.7891, -2.7031],\n",
       "          ...,\n",
       "          [-1.1875,  0.4688,  1.6016,  ..., -0.2051, -0.8008,  0.5977],\n",
       "          [ 1.2500,  1.5234,  0.0747,  ..., -0.8438,  1.0781, -1.4922],\n",
       "          [ 0.6484, -0.9219,  2.6250,  ..., -0.3027, -1.7812, -0.5078]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_37': tensor([0.3652], dtype=torch.bfloat16),\n",
       " 'latents_37_start': tensor([[[ 0.7383, -0.2695,  0.2246,  ..., -0.6250,  0.4180,  0.1553],\n",
       "          [ 0.6562,  0.0708,  0.0457,  ...,  0.2637, -0.0525, -0.5234],\n",
       "          [ 0.2100,  0.2354,  0.0035,  ..., -0.2354, -0.2061, -0.6445],\n",
       "          ...,\n",
       "          [-0.7031, -0.0574,  0.3086,  ...,  0.3027, -0.0435,  0.4316],\n",
       "          [ 0.2275,  0.3809, -0.2148,  ...,  0.6328,  0.8359,  0.6133],\n",
       "          [ 0.0413, -0.5273,  0.7031,  ...,  0.9766,  0.3027,  0.7422]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_37': tensor([[[ 1.9922, -0.8320,  0.6523,  ..., -1.7812,  0.9336,  0.3613],\n",
       "          [ 1.5312,  0.0713, -0.3477,  ...,  0.8281, -0.1963, -1.7891],\n",
       "          [ 0.2754,  0.1777,  0.2930,  ..., -1.0859, -1.7734, -2.7031],\n",
       "          ...,\n",
       "          [-1.1875,  0.4590,  1.6172,  ..., -0.1855, -0.8086,  0.5820],\n",
       "          [ 1.2422,  1.5391,  0.0552,  ..., -0.8633,  1.0781, -1.5156],\n",
       "          [ 0.6602, -0.9219,  2.6250,  ..., -0.3047, -1.7734, -0.5000]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_38': tensor([0.3418], dtype=torch.bfloat16),\n",
       " 'latents_38_start': tensor([[[ 0.6875, -0.2490,  0.2080,  ..., -0.5820,  0.3945,  0.1465],\n",
       "          [ 0.6172,  0.0688,  0.0544,  ...,  0.2432, -0.0476, -0.4785],\n",
       "          [ 0.2031,  0.2305, -0.0038,  ..., -0.2080, -0.1621, -0.5781],\n",
       "          ...,\n",
       "          [-0.6719, -0.0688,  0.2676,  ...,  0.3066, -0.0232,  0.4180],\n",
       "          [ 0.1963,  0.3418, -0.2158,  ...,  0.6562,  0.8086,  0.6523],\n",
       "          [ 0.0248, -0.5039,  0.6367,  ...,  0.9844,  0.3477,  0.7539]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_38': tensor([[[ 1.9766, -0.8477,  0.6484,  ..., -1.7812,  0.9531,  0.3789],\n",
       "          [ 1.5234,  0.0781, -0.3496,  ...,  0.8281, -0.1973, -1.7734],\n",
       "          [ 0.2617,  0.1660,  0.2949,  ..., -1.0547, -1.7422, -2.6875],\n",
       "          ...,\n",
       "          [-1.1641,  0.4668,  1.6328,  ..., -0.1836, -0.8125,  0.5547],\n",
       "          [ 1.2344,  1.5312,  0.0752,  ..., -0.8867,  1.0234, -1.5156],\n",
       "          [ 0.6406, -0.9219,  2.6250,  ..., -0.2988, -1.7812, -0.5117]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_39': tensor([0.3164], dtype=torch.bfloat16),\n",
       " 'latents_39_start': tensor([[[ 0.6367, -0.2275,  0.1914,  ..., -0.5352,  0.3711,  0.1367],\n",
       "          [ 0.5781,  0.0669,  0.0635,  ...,  0.2217, -0.0425, -0.4336],\n",
       "          [ 0.1963,  0.2266, -0.0114,  ..., -0.1807, -0.1172, -0.5078],\n",
       "          ...,\n",
       "          [-0.6406, -0.0811,  0.2256,  ...,  0.3105, -0.0023,  0.4043],\n",
       "          [ 0.1641,  0.3027, -0.2178,  ...,  0.6797,  0.7812,  0.6914],\n",
       "          [ 0.0083, -0.4805,  0.5703,  ...,  0.9922,  0.3926,  0.7656]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_39': tensor([[[ 1.9531, -0.8320,  0.6523,  ..., -1.7500,  0.9336,  0.3965],\n",
       "          [ 1.5156,  0.0767, -0.3691,  ...,  0.8281, -0.1777, -1.7500],\n",
       "          [ 0.2637,  0.1729,  0.2734,  ..., -1.0234, -1.6797, -2.6250],\n",
       "          ...,\n",
       "          [-1.1406,  0.4531,  1.6016,  ..., -0.1973, -0.8164,  0.5234],\n",
       "          [ 1.2344,  1.5078,  0.0593,  ..., -0.8906,  1.0234, -1.4844],\n",
       "          [ 0.6367, -0.9023,  2.5938,  ..., -0.2910, -1.7500, -0.5078]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_40': tensor([0.2891], dtype=torch.bfloat16),\n",
       " 'latents_40_start': tensor([[[ 0.5859, -0.2061,  0.1738,  ..., -0.4883,  0.3457,  0.1260],\n",
       "          [ 0.5391,  0.0649,  0.0732,  ...,  0.2002, -0.0378, -0.3867],\n",
       "          [ 0.1895,  0.2217, -0.0186,  ..., -0.1543, -0.0732, -0.4395],\n",
       "          ...,\n",
       "          [-0.6094, -0.0928,  0.1836,  ...,  0.3164,  0.0192,  0.3906],\n",
       "          [ 0.1318,  0.2637, -0.2197,  ...,  0.7031,  0.7539,  0.7305],\n",
       "          [-0.0084, -0.4570,  0.5039,  ...,  1.0000,  0.4375,  0.7773]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_40': tensor([[[ 1.9766, -0.8555,  0.6445,  ..., -1.7500,  0.9414,  0.4004],\n",
       "          [ 1.5234,  0.0693, -0.3613,  ...,  0.8672, -0.1650, -1.7266],\n",
       "          [ 0.2598,  0.1660,  0.2734,  ..., -1.0156, -1.7031, -2.6250],\n",
       "          ...,\n",
       "          [-1.1641,  0.4531,  1.6016,  ..., -0.1611, -0.8125,  0.5039],\n",
       "          [ 1.2109,  1.5156,  0.0391,  ..., -0.8906,  0.9883, -1.4688],\n",
       "          [ 0.6250, -0.9102,  2.6094,  ..., -0.2930, -1.7578, -0.4922]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_41': tensor([0.2617], dtype=torch.bfloat16),\n",
       " 'latents_41_start': tensor([[[ 0.5312, -0.1826,  0.1562,  ..., -0.4414,  0.3203,  0.1152],\n",
       "          [ 0.4980,  0.0630,  0.0830,  ...,  0.1768, -0.0334, -0.3398],\n",
       "          [ 0.1826,  0.2168, -0.0259,  ..., -0.1270, -0.0273, -0.3691],\n",
       "          ...,\n",
       "          [-0.5781, -0.1050,  0.1406,  ...,  0.3203,  0.0410,  0.3770],\n",
       "          [ 0.0991,  0.2227, -0.2207,  ...,  0.7266,  0.7266,  0.7695],\n",
       "          [-0.0253, -0.4316,  0.4336,  ...,  1.0078,  0.4844,  0.7891]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_41': tensor([[[ 1.9375, -0.8555,  0.6367,  ..., -1.7266,  0.9531,  0.4297],\n",
       "          [ 1.5312,  0.0659, -0.3633,  ...,  0.8711, -0.1660, -1.7266],\n",
       "          [ 0.2520,  0.1826,  0.2676,  ..., -0.9844, -1.6328, -2.5781],\n",
       "          ...,\n",
       "          [-1.1484,  0.4453,  1.5859,  ..., -0.1562, -0.8281,  0.4922],\n",
       "          [ 1.1797,  1.5078,  0.0229,  ..., -0.9102,  0.9531, -1.4609],\n",
       "          [ 0.6211, -0.9102,  2.5938,  ..., -0.2832, -1.7266, -0.4727]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_42': tensor([0.2354], dtype=torch.bfloat16),\n",
       " 'latents_42_start': tensor([[[ 0.4785, -0.1592,  0.1387,  ..., -0.3945,  0.2949,  0.1035],\n",
       "          [ 0.4551,  0.0613,  0.0928,  ...,  0.1523, -0.0288, -0.2930],\n",
       "          [ 0.1758,  0.2119, -0.0332,  ..., -0.0996,  0.0178, -0.2969],\n",
       "          ...,\n",
       "          [-0.5469, -0.1172,  0.0967,  ...,  0.3242,  0.0640,  0.3633],\n",
       "          [ 0.0664,  0.1816, -0.2217,  ...,  0.7500,  0.6992,  0.8086],\n",
       "          [-0.0425, -0.4062,  0.3613,  ...,  1.0156,  0.5312,  0.8008]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_42': tensor([[[ 1.9297, -0.8594,  0.6289,  ..., -1.7031,  0.9609,  0.4297],\n",
       "          [ 1.5000,  0.0811, -0.3652,  ...,  0.8711, -0.1494, -1.7031],\n",
       "          [ 0.2246,  0.1543,  0.2695,  ..., -0.9492, -1.6328, -2.5312],\n",
       "          ...,\n",
       "          [-1.1328,  0.4395,  1.5781,  ..., -0.1680, -0.8398,  0.4453],\n",
       "          [ 1.1484,  1.4766,  0.0073,  ..., -0.9648,  0.8984, -1.4219],\n",
       "          [ 0.5938, -0.8672,  2.5312,  ..., -0.2949, -1.7031, -0.4766]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_43': tensor([0.2070], dtype=torch.bfloat16),\n",
       " 'latents_43_start': tensor([[[ 0.4238, -0.1348,  0.1211,  ..., -0.3457,  0.2676,  0.0913],\n",
       "          [ 0.4121,  0.0591,  0.1030,  ...,  0.1279, -0.0245, -0.2441],\n",
       "          [ 0.1699,  0.2080, -0.0408,  ..., -0.0728,  0.0640, -0.2246],\n",
       "          ...,\n",
       "          [-0.5156, -0.1299,  0.0520,  ...,  0.3281,  0.0879,  0.3516],\n",
       "          [ 0.0339,  0.1396, -0.2217,  ...,  0.7773,  0.6719,  0.8477],\n",
       "          [-0.0593, -0.3809,  0.2891,  ...,  1.0234,  0.5781,  0.8125]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_43': tensor([[[ 1.9141, -0.8711,  0.6328,  ..., -1.6484,  0.9883,  0.4453],\n",
       "          [ 1.4844,  0.0693, -0.4004,  ...,  0.8867, -0.1152, -1.6797],\n",
       "          [ 0.2432,  0.1484,  0.2090,  ..., -0.8867, -1.5938, -2.4531],\n",
       "          ...,\n",
       "          [-1.1094,  0.4453,  1.5703,  ..., -0.1865, -0.8594,  0.3906],\n",
       "          [ 1.0938,  1.4375, -0.0374,  ..., -0.9648,  0.8359, -1.4531],\n",
       "          [ 0.5898, -0.8516,  2.4688,  ..., -0.2773, -1.6484, -0.4746]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_44': tensor([0.1777], dtype=torch.bfloat16),\n",
       " 'latents_44_start': tensor([[[ 0.3672, -0.1094,  0.1025,  ..., -0.2969,  0.2393,  0.0781],\n",
       "          [ 0.3691,  0.0571,  0.1147,  ...,  0.1021, -0.0212, -0.1953],\n",
       "          [ 0.1631,  0.2041, -0.0469,  ..., -0.0469,  0.1104, -0.1533],\n",
       "          ...,\n",
       "          [-0.4844, -0.1426,  0.0063,  ...,  0.3340,  0.1128,  0.3398],\n",
       "          [ 0.0022,  0.0977, -0.2207,  ...,  0.8047,  0.6484,  0.8906],\n",
       "          [-0.0762, -0.3555,  0.2168,  ...,  1.0312,  0.6250,  0.8281]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_44': tensor([[[ 1.8984, -0.8984,  0.5977,  ..., -1.6094,  0.9805,  0.4551],\n",
       "          [ 1.4609,  0.0474, -0.4160,  ...,  0.9102, -0.1060, -1.6484],\n",
       "          [ 0.2119,  0.0791,  0.2021,  ..., -0.8359, -1.5391, -2.4219],\n",
       "          ...,\n",
       "          [-1.0859,  0.4395,  1.5391,  ..., -0.2002, -0.8516,  0.3359],\n",
       "          [ 1.0625,  1.4219, -0.0486,  ..., -0.9609,  0.7930, -1.4453],\n",
       "          [ 0.5898, -0.8008,  2.4219,  ..., -0.3223, -1.5703, -0.4609]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_45': tensor([0.1484], dtype=torch.bfloat16),\n",
       " 'latents_45_start': tensor([[[ 0.3105, -0.0825,  0.0850,  ..., -0.2490,  0.2100,  0.0645],\n",
       "          [ 0.3262,  0.0557,  0.1270,  ...,  0.0747, -0.0181, -0.1465],\n",
       "          [ 0.1562,  0.2021, -0.0530,  ..., -0.0219,  0.1562, -0.0811],\n",
       "          ...,\n",
       "          [-0.4512, -0.1553, -0.0398,  ...,  0.3398,  0.1387,  0.3301],\n",
       "          [-0.0295,  0.0552, -0.2197,  ...,  0.8320,  0.6250,  0.9336],\n",
       "          [-0.0938, -0.3320,  0.1445,  ...,  1.0391,  0.6719,  0.8438]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_45': tensor([[[ 1.8516, -0.8984,  0.5547,  ..., -1.5547,  0.9688,  0.4727],\n",
       "          [ 1.4219,  0.0500, -0.4258,  ...,  0.9141, -0.0588, -1.5703],\n",
       "          [ 0.1592,  0.0850,  0.1924,  ..., -0.7500, -1.4766, -2.3125],\n",
       "          ...,\n",
       "          [-1.0469,  0.4414,  1.4766,  ..., -0.1494, -0.8438,  0.3047],\n",
       "          [ 0.9883,  1.4062, -0.0874,  ..., -0.9844,  0.7422, -1.4297],\n",
       "          [ 0.5742, -0.7578,  2.3125,  ..., -0.3301, -1.3672, -0.4238]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_46': tensor([0.1172], dtype=torch.bfloat16),\n",
       " 'latents_46_start': tensor([[[ 0.2539, -0.0549,  0.0679,  ..., -0.2012,  0.1807,  0.0500],\n",
       "          [ 0.2832,  0.0542,  0.1396,  ...,  0.0469, -0.0162, -0.0986],\n",
       "          [ 0.1514,  0.1992, -0.0588,  ...,  0.0011,  0.2012, -0.0103],\n",
       "          ...,\n",
       "          [-0.4199, -0.1689, -0.0850,  ...,  0.3438,  0.1641,  0.3203],\n",
       "          [-0.0598,  0.0122, -0.2168,  ...,  0.8633,  0.6016,  0.9766],\n",
       "          [-0.1113, -0.3086,  0.0737,  ...,  1.0469,  0.7148,  0.8555]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_46': tensor([[[ 1.7734, -0.9492,  0.5430,  ..., -1.5312,  0.9727,  0.5273],\n",
       "          [ 1.4219,  0.0054, -0.4844,  ...,  0.9297,  0.0198, -1.4531],\n",
       "          [ 0.1377, -0.0330,  0.1299,  ..., -0.6914, -1.3906, -2.2188],\n",
       "          ...,\n",
       "          [-0.9727,  0.4297,  1.3906,  ..., -0.1172, -0.8164,  0.2295],\n",
       "          [ 0.8945,  1.3516, -0.1758,  ..., -0.9961,  0.6445, -1.5078],\n",
       "          [ 0.5234, -0.7422,  2.2031,  ..., -0.3848, -1.1641, -0.4883]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_47': tensor([0.0854], dtype=torch.bfloat16),\n",
       " 'latents_47_start': tensor([[[ 0.1982, -0.0250,  0.0508,  ..., -0.1523,  0.1504,  0.0334],\n",
       "          [ 0.2383,  0.0540,  0.1553,  ...,  0.0176, -0.0168, -0.0530],\n",
       "          [ 0.1475,  0.2002, -0.0630,  ...,  0.0228,  0.2451,  0.0596],\n",
       "          ...,\n",
       "          [-0.3887, -0.1826, -0.1289,  ...,  0.3477,  0.1895,  0.3125],\n",
       "          [-0.0879, -0.0303, -0.2109,  ...,  0.8945,  0.5820,  1.0234],\n",
       "          [-0.1279, -0.2852,  0.0044,  ...,  1.0625,  0.7500,  0.8711]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_47': tensor([[[ 1.6250, -0.9688,  0.4160,  ..., -1.4609,  0.9961,  0.5742],\n",
       "          [ 1.3594,  0.0728, -0.5430,  ...,  0.9062,  0.0530, -1.3438],\n",
       "          [ 0.1553, -0.1787,  0.0908,  ..., -0.5820, -1.1875, -1.9688],\n",
       "          ...,\n",
       "          [-0.8281,  0.4160,  1.2422,  ..., -0.0122, -0.7500,  0.1396],\n",
       "          [ 0.7734,  1.2812, -0.2295,  ..., -0.9883,  0.5039, -1.4844],\n",
       "          [ 0.4453, -0.6719,  1.9688,  ..., -0.4180, -0.9141, -0.6211]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_48': tensor([0.0532], dtype=torch.bfloat16),\n",
       " 'latents_48_start': tensor([[[ 0.1455,  0.0065,  0.0374,  ..., -0.1050,  0.1182,  0.0148],\n",
       "          [ 0.1943,  0.0515,  0.1729,  ..., -0.0118, -0.0186, -0.0093],\n",
       "          [ 0.1426,  0.2061, -0.0659,  ...,  0.0417,  0.2832,  0.1235],\n",
       "          ...,\n",
       "          [-0.3613, -0.1963, -0.1689,  ...,  0.3477,  0.2139,  0.3086],\n",
       "          [-0.1133, -0.0718, -0.2031,  ...,  0.9258,  0.5664,  1.0703],\n",
       "          [-0.1426, -0.2637, -0.0596,  ...,  1.0781,  0.7812,  0.8906]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred_48': tensor([[[ 1.2031, -0.9297,  0.2559,  ..., -1.4141,  0.8906,  0.4258],\n",
       "          [ 1.1016,  0.0625, -0.3242,  ...,  0.8047,  0.1318, -1.1094],\n",
       "          [ 0.1094, -0.2695,  0.2334,  ..., -0.5820, -1.1016, -1.5234],\n",
       "          ...,\n",
       "          [-0.5938,  0.4551,  1.0938,  ...,  0.0281, -0.6289,  0.1357],\n",
       "          [ 0.6523,  1.0625, -0.2275,  ..., -1.0938,  0.4297, -1.3984],\n",
       "          [ 0.3906, -0.4180,  1.5391,  ..., -0.5742, -0.6250, -0.7852]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 't_49': tensor([0.0200], dtype=torch.bfloat16),\n",
       " 'latents_49_start': tensor([[[ 1.0547e-01,  3.7354e-02,  2.8809e-02,  ..., -5.8105e-02,\n",
       "            8.8867e-02,  6.1035e-04],\n",
       "          [ 1.5820e-01,  4.9316e-02,  1.8359e-01,  ..., -3.8574e-02,\n",
       "           -2.2949e-02,  2.7588e-02],\n",
       "          [ 1.3867e-01,  2.1484e-01, -7.3730e-02,  ...,  6.1035e-02,\n",
       "            3.2031e-01,  1.7383e-01],\n",
       "          ...,\n",
       "          [-3.4180e-01, -2.1094e-01, -2.0508e-01,  ...,  3.4766e-01,\n",
       "            2.3438e-01,  3.0469e-01],\n",
       "          [-1.3477e-01, -1.0693e-01, -1.9531e-01,  ...,  9.6094e-01,\n",
       "            5.5078e-01,  1.1172e+00],\n",
       "          [-1.5527e-01, -2.5000e-01, -1.1035e-01,  ...,  1.0938e+00,\n",
       "            8.0078e-01,  9.1797e-01]]], dtype=torch.bfloat16),\n",
       " 'noise_pred_49': tensor([[[ 0.7461, -0.5586,  0.2197,  ..., -1.0469,  0.7109,  0.4902],\n",
       "          [ 0.6094,  0.0464, -0.1650,  ...,  0.4980,  0.2314, -0.9414],\n",
       "          [ 0.1064, -0.2109,  0.1846,  ..., -0.3633, -0.8086, -1.0234],\n",
       "          ...,\n",
       "          [-0.2559,  0.3711,  0.7461,  ..., -0.2217, -0.2988,  0.0339],\n",
       "          [ 0.4980,  0.5156, -0.0260,  ..., -1.1250,  0.1064, -1.1250],\n",
       "          [ 0.2471,  0.0179,  0.6875,  ..., -0.7188, -0.5898, -0.8672]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'output': tensor([[[ 0.0903,  0.0486,  0.0244,  ..., -0.0371,  0.0747, -0.0092],\n",
       "          [ 0.1465,  0.0483,  0.1865,  ..., -0.0486, -0.0276,  0.0464],\n",
       "          [ 0.1367,  0.2188, -0.0776,  ...,  0.0684,  0.3359,  0.1943],\n",
       "          ...,\n",
       "          [-0.3359, -0.2188, -0.2197,  ...,  0.3516,  0.2402,  0.3047],\n",
       "          [-0.1445, -0.1172, -0.1943,  ...,  0.9844,  0.5469,  1.1406],\n",
       "          [-0.1602, -0.2500, -0.1240,  ...,  1.1094,  0.8125,  0.9336]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'height': 576,\n",
       " 'width': 448,\n",
       " 't': tensor([1.], dtype=torch.bfloat16),\n",
       " 'latents_start': tensor([[[ 1.9766, -0.8047,  0.6367,  ..., -1.7422,  1.0469,  0.3809],\n",
       "          [ 1.6562,  0.1147, -0.1562,  ...,  0.7539, -0.1768, -1.6953],\n",
       "          [ 0.3984,  0.3926,  0.1914,  ..., -0.9258, -1.3281, -2.3281],\n",
       "          ...,\n",
       "          [-1.4766,  0.2539,  1.3359,  ...,  0.1797, -0.6250,  0.7617],\n",
       "          [ 1.0391,  1.3672, -0.1572,  ...,  0.1152,  1.4688, -0.2852],\n",
       "          [ 0.4941, -1.1094,  2.3438,  ...,  0.8281, -0.8320,  0.4258]]],\n",
       "        dtype=torch.bfloat16),\n",
       " 'noise_pred': tensor([[[ 1.8906, -0.8945,  0.5938,  ..., -1.7578,  1.0078,  0.2539],\n",
       "          [ 1.5781,  0.0278, -0.2793,  ...,  0.7305, -0.1553, -1.7969],\n",
       "          [ 0.3027,  0.2949,  0.1621,  ..., -1.0625, -1.5938, -2.6406],\n",
       "          ...,\n",
       "          [-1.2578,  0.5352,  1.5859,  ..., -0.2773, -1.0312,  0.3203],\n",
       "          [ 1.2734,  1.5312,  0.0728,  ..., -0.6211,  0.8984, -1.1562],\n",
       "          [ 0.6172, -0.9336,  2.6719,  ..., -0.1050, -1.8672, -0.3691]]],\n",
       "        dtype=torch.bfloat16)}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22f19ae9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}