Spaces:
Running
on
Zero
Running
on
Zero
File size: 74,167 Bytes
49cbc74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "e1e781e9",
"metadata": {},
"outputs": [],
"source": [
"# Change into the project directory (absolute, machine-specific path; adjust for your setup).\n",
"%cd /home/ubuntu/Qwen-Image-Edit-Angles"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6192ee5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4941"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import glob\n",
"from pathlib import Path\n",
"\n",
"base_data = Path(\"/data/regression_output\")\n",
"\n",
"# Sorted by filename so all_reg[0] (inspected below) is deterministic;\n",
"# Path.glob yields files in no guaranteed order.\n",
"all_reg = sorted(base_data.glob(\"*.pt\"))\n",
"assert all_reg, f\"no .pt files found in {base_data}\"\n",
"# Highest sample index; assumes every filename stem is an integer.\n",
"max_ind = max(int(reg_pth.stem) for reg_pth in all_reg)\n",
"\n",
"max_ind"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b5124900",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"prompt_embeds\n",
"prompt_embeds_mask\n",
"noise\n",
"image_latents\n",
"vae_image_sizes\n",
"img_shapes\n",
"txt_seq_lens\n",
"t_0\n",
"latents_0_start\n",
"noise_pred_0\n",
"t_1\n",
"latents_1_start\n",
"noise_pred_1\n",
"t_2\n",
"latents_2_start\n",
"noise_pred_2\n",
"t_3\n",
"latents_3_start\n",
"noise_pred_3\n",
"t_4\n",
"latents_4_start\n",
"noise_pred_4\n",
"t_5\n",
"latents_5_start\n",
"noise_pred_5\n",
"t_6\n",
"latents_6_start\n",
"noise_pred_6\n",
"t_7\n",
"latents_7_start\n",
"noise_pred_7\n",
"t_8\n",
"latents_8_start\n",
"noise_pred_8\n",
"t_9\n",
"latents_9_start\n",
"noise_pred_9\n",
"t_10\n",
"latents_10_start\n",
"noise_pred_10\n",
"t_11\n",
"latents_11_start\n",
"noise_pred_11\n",
"t_12\n",
"latents_12_start\n",
"noise_pred_12\n",
"t_13\n",
"latents_13_start\n",
"noise_pred_13\n",
"t_14\n",
"latents_14_start\n",
"noise_pred_14\n",
"t_15\n",
"latents_15_start\n",
"noise_pred_15\n",
"t_16\n",
"latents_16_start\n",
"noise_pred_16\n",
"t_17\n",
"latents_17_start\n",
"noise_pred_17\n",
"t_18\n",
"latents_18_start\n",
"noise_pred_18\n",
"t_19\n",
"latents_19_start\n",
"noise_pred_19\n",
"t_20\n",
"latents_20_start\n",
"noise_pred_20\n",
"t_21\n",
"latents_21_start\n",
"noise_pred_21\n",
"t_22\n",
"latents_22_start\n",
"noise_pred_22\n",
"t_23\n",
"latents_23_start\n",
"noise_pred_23\n",
"t_24\n",
"latents_24_start\n",
"noise_pred_24\n",
"t_25\n",
"latents_25_start\n",
"noise_pred_25\n",
"t_26\n",
"latents_26_start\n",
"noise_pred_26\n",
"t_27\n",
"latents_27_start\n",
"noise_pred_27\n",
"t_28\n",
"latents_28_start\n",
"noise_pred_28\n",
"t_29\n",
"latents_29_start\n",
"noise_pred_29\n",
"t_30\n",
"latents_30_start\n",
"noise_pred_30\n",
"t_31\n",
"latents_31_start\n",
"noise_pred_31\n",
"t_32\n",
"latents_32_start\n",
"noise_pred_32\n",
"t_33\n",
"latents_33_start\n",
"noise_pred_33\n",
"t_34\n",
"latents_34_start\n",
"noise_pred_34\n",
"t_35\n",
"latents_35_start\n",
"noise_pred_35\n",
"t_36\n",
"latents_36_start\n",
"noise_pred_36\n",
"t_37\n",
"latents_37_start\n",
"noise_pred_37\n",
"t_38\n",
"latents_38_start\n",
"noise_pred_38\n",
"t_39\n",
"latents_39_start\n",
"noise_pred_39\n",
"t_40\n",
"latents_40_start\n",
"noise_pred_40\n",
"t_41\n",
"latents_41_start\n",
"noise_pred_41\n",
"t_42\n",
"latents_42_start\n",
"noise_pred_42\n",
"t_43\n",
"latents_43_start\n",
"noise_pred_43\n",
"t_44\n",
"latents_44_start\n",
"noise_pred_44\n",
"t_45\n",
"latents_45_start\n",
"noise_pred_45\n",
"t_46\n",
"latents_46_start\n",
"noise_pred_46\n",
"t_47\n",
"latents_47_start\n",
"noise_pred_47\n",
"t_48\n",
"latents_48_start\n",
"noise_pred_48\n",
"t_49\n",
"latents_49_start\n",
"noise_pred_49\n",
"output\n",
"height\n",
"width\n"
]
}
],
"source": [
"import torch\n",
"\n",
"# Inspect the key layout of one saved trajectory file (keys in stored order).\n",
"out = all_reg[0]\n",
"out_dict = torch.load(out)\n",
"for k in out_dict:\n",
"    print(k)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74f693db",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'003329'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "da107d9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"69G\t/data/regression_output\n"
]
}
],
"source": [
"# Report the on-disk size of the regression dataset directory.\n",
"!du -h {base_data}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "269c0bfb",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 16,
"id": "5964bf2b",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"\n",
"class RegressionSource:\n",
"    \"\"\"Map-style dataset over saved denoising trajectories (WIP).\n",
"\n",
"    Each ``*.pt`` file in ``data_dir`` holds a dict of per-sample entries\n",
"    (e.g. ``prompt_embeds``, ``noise``, ``image_latents``) plus, for every\n",
"    step ``i``, the keys ``t_{i}``, ``latents_{i}_start`` and\n",
"    ``noise_pred_{i}``. Flat index ``idx`` addresses file\n",
"    ``idx // gen_steps`` at step ``idx % gen_steps``.\n",
"\n",
"    Args:\n",
"        data_dir: directory containing the ``*.pt`` files (str or Path).\n",
"        gen_steps: number of denoising steps stored per file.\n",
"    \"\"\"\n",
"\n",
"    # Matches the per-step keys (of any step) stored in each file.\n",
"    _STEP_KEY_RE = re.compile(r\"t_\\d+|latents_\\d+_start|noise_pred_\\d+\")\n",
"\n",
"    def __init__(self, data_dir, gen_steps=50):\n",
"        # Sorted so the flat-index -> file mapping is stable across runs;\n",
"        # Path.glob yields files in no guaranteed order.\n",
"        self.data_paths = sorted(Path(data_dir).glob(\"*.pt\"))\n",
"        self.gen_steps = gen_steps\n",
"        self._len = gen_steps * len(self.data_paths)\n",
"\n",
"    def __len__(self):\n",
"        return self._len\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return one file's shared entries plus the tensors of one step.\n",
"\n",
"        The selected step's entries are exposed as ``t``, ``latents_start``\n",
"        and ``noise_pred``. Per-step entries of every other step are\n",
"        dropped, so all items share one key set (as dict batching expects)\n",
"        and an item does not carry the whole trajectory.\n",
"\n",
"        Raises:\n",
"            IndexError: if ``idx`` is out of range.\n",
"            KeyError: if the file lacks the requested step.\n",
"        \"\"\"\n",
"        data_idx, step_idx = divmod(idx, self.gen_steps)\n",
"        out_dict = torch.load(self.data_paths[data_idx])\n",
"        item = {\n",
"            key: value\n",
"            for key, value in out_dict.items()\n",
"            if not self._STEP_KEY_RE.fullmatch(key)\n",
"        }\n",
"        item[\"t\"] = out_dict[f\"t_{step_idx}\"]\n",
"        item[\"latents_start\"] = out_dict[f\"latents_{step_idx}_start\"]\n",
"        item[\"noise_pred\"] = out_dict[f\"noise_pred_{step_idx}\"]\n",
"        return item"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b62e7bec",
"metadata": {},
"outputs": [],
"source": [
"# Reuse base_data rather than repeating the dataset path literal.\n",
"src = RegressionSource(base_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ee68ab3",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"id": "9738e1d4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'prompt_embeds': tensor([[[ 3.2188, 3.4375, 3.1719, ..., 0.3535, 1.7812, 2.0312],\n",
" [ 3.0938, 1.9297, 0.7031, ..., 2.0625, -0.2314, 1.2266],\n",
" [ 2.6250, 1.7031, 3.5625, ..., 0.8828, 2.1719, 1.4766],\n",
" ...,\n",
" [ 4.7812, 0.1689, 4.4688, ..., 5.0000, -1.8359, -0.7500],\n",
" [-0.0654, 2.1406, -1.4922, ..., 0.7930, 3.9844, 1.6406],\n",
" [-2.7031, 1.5547, 2.6094, ..., -0.0481, 0.1582, 0.7383]]],\n",
" dtype=torch.bfloat16),\n",
" 'prompt_embeds_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
" 'noise': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n",
" [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n",
" [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n",
" ...,\n",
" [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n",
" [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n",
" [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n",
" dtype=torch.bfloat16),\n",
" 'image_latents': tensor([[[ 0.1719, 0.0194, 0.0084, ..., -0.1494, 0.0552, 0.2295],\n",
" [ 0.1777, 0.1406, 0.1592, ..., 0.1260, -0.2412, -0.0041],\n",
" [ 0.1187, 0.2324, 0.1104, ..., 0.0801, 0.3516, 0.4414],\n",
" ...,\n",
" [-0.0972, -0.3242, -0.3027, ..., 0.3672, 0.1699, 0.4004],\n",
" [-0.1221, -0.0125, -0.3867, ..., 0.7031, 0.8477, 0.8320],\n",
" [-0.1416, -0.1914, -0.3359, ..., 0.9883, 1.3359, 0.7422]]],\n",
" dtype=torch.bfloat16),\n",
" 'vae_image_sizes': [(448, 576)],\n",
" 'img_shapes': [[(1, 36, 28), (1, 36, 28)]],\n",
" 'txt_seq_lens': [228],\n",
" 't_1': tensor([0.9883], dtype=torch.bfloat16),\n",
" 'latents_1_start': tensor([[[ 1.9531, -0.7930, 0.6289, ..., -1.7188, 1.0312, 0.3770],\n",
" [ 1.6406, 0.1143, -0.1533, ..., 0.7461, -0.1748, -1.6719],\n",
" [ 0.3945, 0.3887, 0.1895, ..., -0.9141, -1.3125, -2.2969],\n",
" ...,\n",
" [-1.4609, 0.2471, 1.3203, ..., 0.1826, -0.6133, 0.7578],\n",
" [ 1.0234, 1.3516, -0.1582, ..., 0.1226, 1.4609, -0.2715],\n",
" [ 0.4863, -1.1016, 2.3125, ..., 0.8281, -0.8086, 0.4297]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_1': tensor([[[ 1.9062, -0.9102, 0.5742, ..., -1.7422, 1.0625, 0.3359],\n",
" [ 1.5859, 0.0306, -0.2637, ..., 0.7539, -0.1768, -1.7969],\n",
" [ 0.3184, 0.3066, 0.1592, ..., -1.0391, -1.5391, -2.5625],\n",
" ...,\n",
" [-1.2734, 0.4941, 1.5781, ..., -0.2344, -1.0156, 0.3477],\n",
" [ 1.2422, 1.5234, 0.0510, ..., -0.5820, 0.9219, -1.0859],\n",
" [ 0.6172, -0.9336, 2.5781, ..., -0.0801, -1.7734, -0.3730]]],\n",
" dtype=torch.bfloat16),\n",
" 't_2': tensor([0.9766], dtype=torch.bfloat16),\n",
" 'latents_2_start': tensor([[[ 1.9297, -0.7812, 0.6211, ..., -1.6953, 1.0156, 0.3730],\n",
" [ 1.6250, 0.1138, -0.1504, ..., 0.7383, -0.1729, -1.6484],\n",
" [ 0.3906, 0.3848, 0.1875, ..., -0.9023, -1.2969, -2.2656],\n",
" ...,\n",
" [-1.4453, 0.2412, 1.3047, ..., 0.1855, -0.6016, 0.7539],\n",
" [ 1.0078, 1.3359, -0.1592, ..., 0.1299, 1.4531, -0.2578],\n",
" [ 0.4785, -1.0938, 2.2812, ..., 0.8281, -0.7891, 0.4336]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_2': tensor([[[ 1.8984, -0.9219, 0.5664, ..., -1.7188, 1.0703, 0.3633],\n",
" [ 1.5859, 0.0256, -0.2539, ..., 0.7578, -0.1719, -1.7656],\n",
" [ 0.3105, 0.3027, 0.1611, ..., -1.0000, -1.4688, -2.4688],\n",
" ...,\n",
" [-1.2969, 0.4453, 1.5625, ..., -0.1934, -0.9883, 0.4082],\n",
" [ 1.2188, 1.4844, -0.0028, ..., -0.4492, 1.0312, -0.9180],\n",
" [ 0.5820, -1.0156, 2.5156, ..., 0.1885, -1.5391, -0.1602]]],\n",
" dtype=torch.bfloat16),\n",
" 't_3': tensor([0.9648], dtype=torch.bfloat16),\n",
" 'latents_3_start': tensor([[[ 1.9062, -0.7695, 0.6133, ..., -1.6719, 1.0000, 0.3691],\n",
" [ 1.6094, 0.1133, -0.1475, ..., 0.7305, -0.1709, -1.6250],\n",
" [ 0.3867, 0.3809, 0.1855, ..., -0.8906, -1.2812, -2.2344],\n",
" ...,\n",
" [-1.4297, 0.2354, 1.2891, ..., 0.1875, -0.5898, 0.7500],\n",
" [ 0.9922, 1.3203, -0.1592, ..., 0.1357, 1.4375, -0.2461],\n",
" [ 0.4707, -1.0781, 2.2500, ..., 0.8242, -0.7695, 0.4355]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_3': tensor([[[ 1.8984, -0.9180, 0.5430, ..., -1.7031, 1.0938, 0.3691],\n",
" [ 1.5703, 0.0308, -0.2676, ..., 0.7812, -0.1602, -1.7422],\n",
" [ 0.3164, 0.2949, 0.1514, ..., -0.9922, -1.4609, -2.4531],\n",
" ...,\n",
" [-1.3203, 0.4277, 1.5234, ..., -0.1611, -0.9688, 0.4434],\n",
" [ 1.1875, 1.4609, -0.0179, ..., -0.4355, 1.0312, -0.8867],\n",
" [ 0.5547, -1.0234, 2.4844, ..., 0.2344, -1.4844, -0.1025]]],\n",
" dtype=torch.bfloat16),\n",
" 't_4': tensor([0.9531], dtype=torch.bfloat16),\n",
" 'latents_4_start': tensor([[[ 1.8828, -0.7578, 0.6055, ..., -1.6484, 0.9844, 0.3652],\n",
" [ 1.5859, 0.1128, -0.1445, ..., 0.7188, -0.1689, -1.6016],\n",
" [ 0.3828, 0.3770, 0.1836, ..., -0.8789, -1.2656, -2.2031],\n",
" ...,\n",
" [-1.4141, 0.2305, 1.2734, ..., 0.1895, -0.5781, 0.7461],\n",
" [ 0.9766, 1.3047, -0.1592, ..., 0.1416, 1.4219, -0.2354],\n",
" [ 0.4629, -1.0625, 2.2188, ..., 0.8203, -0.7500, 0.4375]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_4': tensor([[[ 1.8984, -0.9141, 0.5508, ..., -1.7109, 1.0859, 0.3672],\n",
" [ 1.5625, 0.0238, -0.2754, ..., 0.7656, -0.1768, -1.7578],\n",
" [ 0.3105, 0.2988, 0.1602, ..., -1.0156, -1.4766, -2.4375],\n",
" ...,\n",
" [-1.3125, 0.4316, 1.5469, ..., -0.1621, -0.9805, 0.4141],\n",
" [ 1.1953, 1.4844, -0.0118, ..., -0.4590, 1.0078, -0.9492],\n",
" [ 0.5703, -1.0156, 2.5156, ..., 0.1777, -1.5469, -0.1475]]],\n",
" dtype=torch.bfloat16),\n",
" 't_5': tensor([0.9414], dtype=torch.bfloat16),\n",
" 'latents_5_start': tensor([[[ 1.8594, -0.7461, 0.5977, ..., -1.6250, 0.9688, 0.3613],\n",
" [ 1.5625, 0.1123, -0.1406, ..., 0.7109, -0.1670, -1.5781],\n",
" [ 0.3789, 0.3730, 0.1816, ..., -0.8672, -1.2500, -2.1719],\n",
" ...,\n",
" [-1.3984, 0.2246, 1.2500, ..., 0.1914, -0.5664, 0.7422],\n",
" [ 0.9609, 1.2891, -0.1592, ..., 0.1475, 1.4062, -0.2236],\n",
" [ 0.4551, -1.0469, 2.1875, ..., 0.8164, -0.7305, 0.4395]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_5': tensor([[[ 1.8906, -0.8984, 0.5586, ..., -1.7031, 1.0391, 0.3516],\n",
" [ 1.5625, 0.0388, -0.2871, ..., 0.8008, -0.1504, -1.7734],\n",
" [ 0.3008, 0.2949, 0.1777, ..., -1.0312, -1.5781, -2.5469],\n",
" ...,\n",
" [-1.3047, 0.4688, 1.5938, ..., -0.2188, -1.0781, 0.3945],\n",
" [ 1.2031, 1.4844, 0.0082, ..., -0.5469, 0.9414, -1.1875],\n",
" [ 0.5781, -0.9336, 2.5625, ..., -0.0903, -1.8047, -0.3828]]],\n",
" dtype=torch.bfloat16),\n",
" 't_6': tensor([0.9258], dtype=torch.bfloat16),\n",
" 'latents_6_start': tensor([[[ 1.8359, -0.7344, 0.5898, ..., -1.6016, 0.9570, 0.3574],\n",
" [ 1.5391, 0.1118, -0.1367, ..., 0.6992, -0.1650, -1.5547],\n",
" [ 0.3750, 0.3691, 0.1797, ..., -0.8555, -1.2266, -2.1406],\n",
" ...,\n",
" [-1.3828, 0.2188, 1.2266, ..., 0.1943, -0.5508, 0.7383],\n",
" [ 0.9453, 1.2734, -0.1592, ..., 0.1543, 1.3906, -0.2080],\n",
" [ 0.4473, -1.0312, 2.1562, ..., 0.8164, -0.7070, 0.4453]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_6': tensor([[[ 1.9219, -0.8828, 0.5820, ..., -1.7109, 1.0781, 0.3613],\n",
" [ 1.5703, 0.0359, -0.2812, ..., 0.7773, -0.1865, -1.8203],\n",
" [ 0.3301, 0.2949, 0.1924, ..., -1.0781, -1.6016, -2.5312],\n",
" ...,\n",
" [-1.2734, 0.4941, 1.6094, ..., -0.2236, -1.0391, 0.3633],\n",
" [ 1.2656, 1.5469, 0.0796, ..., -0.6797, 0.8672, -1.2656],\n",
" [ 0.6484, -0.9102, 2.5938, ..., -0.1904, -1.8516, -0.4590]]],\n",
" dtype=torch.bfloat16),\n",
" 't_7': tensor([0.9102], dtype=torch.bfloat16),\n",
" 'latents_7_start': tensor([[[ 1.8125, -0.7227, 0.5820, ..., -1.5781, 0.9414, 0.3535],\n",
" [ 1.5156, 0.1113, -0.1328, ..., 0.6875, -0.1621, -1.5312],\n",
" [ 0.3711, 0.3652, 0.1768, ..., -0.8398, -1.2031, -2.1094],\n",
" ...,\n",
" [-1.3672, 0.2119, 1.2031, ..., 0.1973, -0.5352, 0.7344],\n",
" [ 0.9297, 1.2500, -0.1602, ..., 0.1631, 1.3828, -0.1914],\n",
" [ 0.4395, -1.0156, 2.1250, ..., 0.8203, -0.6836, 0.4512]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_7': tensor([[[ 1.9531, -0.8906, 0.5938, ..., -1.7266, 1.0938, 0.4180],\n",
" [ 1.5781, 0.0309, -0.3008, ..., 0.7969, -0.1699, -1.8281],\n",
" [ 0.3262, 0.3008, 0.2314, ..., -1.0781, -1.6797, -2.6250],\n",
" ...,\n",
" [-1.2812, 0.5039, 1.5938, ..., -0.2314, -1.0547, 0.3828],\n",
" [ 1.2734, 1.5781, 0.0859, ..., -0.7930, 0.8555, -1.3516],\n",
" [ 0.6914, -0.9062, 2.6250, ..., -0.2598, -1.8516, -0.4902]]],\n",
" dtype=torch.bfloat16),\n",
" 't_8': tensor([0.8984], dtype=torch.bfloat16),\n",
" 'latents_8_start': tensor([[[ 1.7891, -0.7109, 0.5742, ..., -1.5547, 0.9258, 0.3477],\n",
" [ 1.4922, 0.1108, -0.1289, ..., 0.6758, -0.1602, -1.5078],\n",
" [ 0.3672, 0.3613, 0.1738, ..., -0.8242, -1.1797, -2.0781],\n",
" ...,\n",
" [-1.3516, 0.2051, 1.1797, ..., 0.2002, -0.5195, 0.7305],\n",
" [ 0.9141, 1.2266, -0.1611, ..., 0.1738, 1.3750, -0.1729],\n",
" [ 0.4297, -1.0000, 2.0938, ..., 0.8242, -0.6602, 0.4570]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_8': tensor([[[ 1.9453, -0.8789, 0.6094, ..., -1.7266, 1.0781, 0.4082],\n",
" [ 1.5703, 0.0396, -0.3047, ..., 0.7891, -0.1826, -1.8516],\n",
" [ 0.3164, 0.2949, 0.2500, ..., -1.0859, -1.7031, -2.6250],\n",
" ...,\n",
" [-1.2578, 0.5234, 1.5938, ..., -0.2246, -1.0547, 0.3770],\n",
" [ 1.2734, 1.6016, 0.0884, ..., -0.8828, 0.8086, -1.3672],\n",
" [ 0.7070, -0.8828, 2.6094, ..., -0.2832, -1.8750, -0.5117]]],\n",
" dtype=torch.bfloat16),\n",
" 't_9': tensor([0.8828], dtype=torch.bfloat16),\n",
" 'latents_9_start': tensor([[[ 1.7656, -0.6992, 0.5664, ..., -1.5312, 0.9102, 0.3418],\n",
" [ 1.4688, 0.1104, -0.1245, ..., 0.6641, -0.1572, -1.4844],\n",
" [ 0.3633, 0.3574, 0.1699, ..., -0.8086, -1.1562, -2.0469],\n",
" ...,\n",
" [-1.3359, 0.1982, 1.1562, ..., 0.2031, -0.5039, 0.7266],\n",
" [ 0.8984, 1.2031, -0.1621, ..., 0.1855, 1.3672, -0.1543],\n",
" [ 0.4199, -0.9883, 2.0625, ..., 0.8281, -0.6328, 0.4648]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_9': tensor([[[ 1.9531, -0.8828, 0.6172, ..., -1.7188, 1.0391, 0.4141],\n",
" [ 1.5703, 0.0583, -0.3125, ..., 0.7930, -0.1582, -1.8594],\n",
" [ 0.3203, 0.2910, 0.2598, ..., -1.1016, -1.7500, -2.6562],\n",
" ...,\n",
" [-1.2422, 0.5273, 1.6094, ..., -0.2168, -1.0391, 0.4121],\n",
" [ 1.2656, 1.6172, 0.1001, ..., -0.8984, 0.8008, -1.4062],\n",
" [ 0.7383, -0.8750, 2.6250, ..., -0.2891, -1.8672, -0.5312]]],\n",
" dtype=torch.bfloat16),\n",
" 't_10': tensor([0.8711], dtype=torch.bfloat16),\n",
" 'latents_10_start': tensor([[[ 1.7344, -0.6875, 0.5586, ..., -1.5078, 0.8945, 0.3359],\n",
" [ 1.4453, 0.1094, -0.1201, ..., 0.6523, -0.1553, -1.4609],\n",
" [ 0.3594, 0.3535, 0.1660, ..., -0.7930, -1.1328, -2.0156],\n",
" ...,\n",
" [-1.3203, 0.1904, 1.1328, ..., 0.2061, -0.4902, 0.7227],\n",
" [ 0.8789, 1.1797, -0.1631, ..., 0.1982, 1.3594, -0.1348],\n",
" [ 0.4102, -0.9766, 2.0312, ..., 0.8320, -0.6055, 0.4727]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_10': tensor([[[ 1.9609, -0.8672, 0.6445, ..., -1.7188, 1.0156, 0.4180],\n",
" [ 1.5781, 0.0728, -0.3125, ..., 0.7812, -0.1602, -1.8828],\n",
" [ 0.3125, 0.2832, 0.2832, ..., -1.1172, -1.7734, -2.6875],\n",
" ...,\n",
" [-1.2500, 0.5273, 1.6016, ..., -0.2227, -1.0625, 0.4062],\n",
" [ 1.2500, 1.6094, 0.1016, ..., -0.9297, 0.8086, -1.4219],\n",
" [ 0.7617, -0.8555, 2.6406, ..., -0.3105, -1.8594, -0.5352]]],\n",
" dtype=torch.bfloat16),\n",
" 't_11': tensor([0.8555], dtype=torch.bfloat16),\n",
" 'latents_11_start': tensor([[[ 1.7031, -0.6758, 0.5508, ..., -1.4844, 0.8789, 0.3301],\n",
" [ 1.4219, 0.1084, -0.1157, ..., 0.6406, -0.1533, -1.4375],\n",
" [ 0.3555, 0.3496, 0.1621, ..., -0.7773, -1.1094, -1.9766],\n",
" ...,\n",
" [-1.3047, 0.1826, 1.1094, ..., 0.2090, -0.4746, 0.7188],\n",
" [ 0.8594, 1.1562, -0.1641, ..., 0.2119, 1.3516, -0.1143],\n",
" [ 0.3984, -0.9648, 1.9922, ..., 0.8359, -0.5781, 0.4805]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_11': tensor([[[ 1.9688, -0.8516, 0.6484, ..., -1.7266, 1.0000, 0.4082],\n",
" [ 1.5938, 0.0679, -0.3105, ..., 0.8086, -0.1455, -1.8984],\n",
" [ 0.3203, 0.2812, 0.2949, ..., -1.1094, -1.7812, -2.6719],\n",
" ...,\n",
" [-1.2500, 0.5273, 1.5938, ..., -0.2119, -1.0625, 0.4102],\n",
" [ 1.2656, 1.6016, 0.1011, ..., -0.9180, 0.8281, -1.4531],\n",
" [ 0.7695, -0.8320, 2.6562, ..., -0.2891, -1.8516, -0.5234]]],\n",
" dtype=torch.bfloat16),\n",
" 't_12': tensor([0.8438], dtype=torch.bfloat16),\n",
" 'latents_12_start': tensor([[[ 1.6719, -0.6641, 0.5430, ..., -1.4609, 0.8633, 0.3242],\n",
" [ 1.3984, 0.1074, -0.1113, ..., 0.6289, -0.1514, -1.4062],\n",
" [ 0.3516, 0.3457, 0.1582, ..., -0.7617, -1.0859, -1.9375],\n",
" ...,\n",
" [-1.2891, 0.1748, 1.0859, ..., 0.2119, -0.4590, 0.7109],\n",
" [ 0.8398, 1.1328, -0.1660, ..., 0.2256, 1.3359, -0.0933],\n",
" [ 0.3867, -0.9531, 1.9531, ..., 0.8398, -0.5508, 0.4883]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_12': tensor([[[ 1.9688, -0.8477, 0.6602, ..., -1.7422, 0.9805, 0.3965],\n",
" [ 1.5938, 0.0845, -0.3066, ..., 0.7891, -0.1816, -1.9062],\n",
" [ 0.3105, 0.2754, 0.3242, ..., -1.1328, -1.8047, -2.6875],\n",
" ...,\n",
" [-1.2422, 0.5195, 1.5938, ..., -0.2227, -1.0625, 0.4180],\n",
" [ 1.2500, 1.6172, 0.1138, ..., -0.9492, 0.8281, -1.4609],\n",
" [ 0.7852, -0.8555, 2.6562, ..., -0.3047, -1.8438, -0.5430]]],\n",
" dtype=torch.bfloat16),\n",
" 't_13': tensor([0.8281], dtype=torch.bfloat16),\n",
" 'latents_13_start': tensor([[[ 1.6406, -0.6523, 0.5312, ..., -1.4375, 0.8477, 0.3184],\n",
" [ 1.3750, 0.1060, -0.1069, ..., 0.6172, -0.1484, -1.3750],\n",
" [ 0.3477, 0.3418, 0.1533, ..., -0.7461, -1.0625, -1.8984],\n",
" ...,\n",
" [-1.2734, 0.1670, 1.0625, ..., 0.2148, -0.4434, 0.7031],\n",
" [ 0.8203, 1.1094, -0.1680, ..., 0.2393, 1.3203, -0.0718],\n",
" [ 0.3750, -0.9414, 1.9141, ..., 0.8438, -0.5234, 0.4961]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_13': tensor([[[ 1.9688, -0.8438, 0.6562, ..., -1.7500, 0.9805, 0.3906],\n",
" [ 1.5938, 0.0791, -0.3066, ..., 0.7734, -0.1934, -1.9062],\n",
" [ 0.3145, 0.2734, 0.3203, ..., -1.1250, -1.8359, -2.7188],\n",
" ...,\n",
" [-1.2422, 0.5156, 1.6094, ..., -0.2021, -1.0312, 0.4609],\n",
" [ 1.2656, 1.6250, 0.1108, ..., -0.9453, 0.8789, -1.4688],\n",
" [ 0.7930, -0.8594, 2.6719, ..., -0.3105, -1.8359, -0.5469]]],\n",
" dtype=torch.bfloat16),\n",
" 't_14': tensor([0.8125], dtype=torch.bfloat16),\n",
" 'latents_14_start': tensor([[[ 1.6094, -0.6406, 0.5195, ..., -1.4141, 0.8320, 0.3125],\n",
" [ 1.3516, 0.1050, -0.1025, ..., 0.6055, -0.1455, -1.3438],\n",
" [ 0.3438, 0.3379, 0.1484, ..., -0.7305, -1.0312, -1.8594],\n",
" ...,\n",
" [-1.2578, 0.1592, 1.0391, ..., 0.2178, -0.4277, 0.6953],\n",
" [ 0.8008, 1.0859, -0.1699, ..., 0.2539, 1.3047, -0.0498],\n",
" [ 0.3633, -0.9297, 1.8750, ..., 0.8477, -0.4961, 0.5039]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_14': tensor([[[ 1.9609, -0.8438, 0.6562, ..., -1.7500, 0.9727, 0.3828],\n",
" [ 1.5938, 0.0840, -0.3203, ..., 0.7695, -0.2061, -1.8906],\n",
" [ 0.3125, 0.2754, 0.3262, ..., -1.1328, -1.8359, -2.7031],\n",
" ...,\n",
" [-1.2422, 0.5117, 1.6094, ..., -0.2002, -1.0156, 0.4746],\n",
" [ 1.2656, 1.6172, 0.1108, ..., -0.9219, 0.8945, -1.4609],\n",
" [ 0.7969, -0.8672, 2.6406, ..., -0.3047, -1.7969, -0.5430]]],\n",
" dtype=torch.bfloat16),\n",
" 't_15': tensor([0.7969], dtype=torch.bfloat16),\n",
" 'latents_15_start': tensor([[[ 1.5781, -0.6289, 0.5078, ..., -1.3906, 0.8164, 0.3066],\n",
" [ 1.3281, 0.1035, -0.0977, ..., 0.5938, -0.1426, -1.3125],\n",
" [ 0.3398, 0.3340, 0.1436, ..., -0.7148, -1.0000, -1.8203],\n",
" ...,\n",
" [-1.2422, 0.1514, 1.0156, ..., 0.2207, -0.4121, 0.6875],\n",
" [ 0.7812, 1.0625, -0.1719, ..., 0.2676, 1.2891, -0.0275],\n",
" [ 0.3516, -0.9180, 1.8359, ..., 0.8516, -0.4688, 0.5117]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_15': tensor([[[ 1.9531, -0.8320, 0.6641, ..., -1.7578, 0.9570, 0.3789],\n",
" [ 1.5938, 0.0806, -0.3184, ..., 0.7500, -0.2031, -1.8750],\n",
" [ 0.3242, 0.2656, 0.3301, ..., -1.1406, -1.8359, -2.7031],\n",
" ...,\n",
" [-1.2578, 0.5195, 1.6328, ..., -0.1875, -0.9883, 0.5117],\n",
" [ 1.2734, 1.6094, 0.1230, ..., -0.9297, 0.9102, -1.4531],\n",
" [ 0.7930, -0.8633, 2.6406, ..., -0.3027, -1.8203, -0.5312]]],\n",
" dtype=torch.bfloat16),\n",
" 't_16': tensor([0.7812], dtype=torch.bfloat16),\n",
" 'latents_16_start': tensor([[[ 1.5469, -0.6172, 0.4980, ..., -1.3594, 0.8008, 0.3008],\n",
" [ 1.3047, 0.1021, -0.0928, ..., 0.5820, -0.1396, -1.2812],\n",
" [ 0.3340, 0.3301, 0.1387, ..., -0.6953, -0.9727, -1.7812],\n",
" ...,\n",
" [-1.2188, 0.1436, 0.9883, ..., 0.2236, -0.3965, 0.6797],\n",
" [ 0.7617, 1.0391, -0.1738, ..., 0.2812, 1.2734, -0.0048],\n",
" [ 0.3398, -0.9062, 1.7969, ..., 0.8555, -0.4395, 0.5195]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_16': tensor([[[ 1.9766, -0.8320, 0.6758, ..., -1.7734, 0.9844, 0.3867],\n",
" [ 1.6094, 0.0923, -0.3164, ..., 0.7617, -0.2148, -1.8984],\n",
" [ 0.3281, 0.2695, 0.3398, ..., -1.1484, -1.8516, -2.7500],\n",
" ...,\n",
" [-1.2500, 0.5156, 1.6406, ..., -0.1953, -0.9648, 0.5273],\n",
" [ 1.3125, 1.6250, 0.1113, ..., -0.9102, 0.9414, -1.4609],\n",
" [ 0.7969, -0.8750, 2.6719, ..., -0.2988, -1.7891, -0.5469]]],\n",
" dtype=torch.bfloat16),\n",
" 't_17': tensor([0.7656], dtype=torch.bfloat16),\n",
" 'latents_17_start': tensor([[[ 1.5156, -0.6055, 0.4883, ..., -1.3281, 0.7852, 0.2949],\n",
" [ 1.2812, 0.1006, -0.0879, ..., 0.5703, -0.1367, -1.2500],\n",
" [ 0.3281, 0.3262, 0.1328, ..., -0.6758, -0.9414, -1.7344],\n",
" ...,\n",
" [-1.1953, 0.1357, 0.9609, ..., 0.2266, -0.3809, 0.6719],\n",
" [ 0.7422, 1.0156, -0.1758, ..., 0.2949, 1.2578, 0.0184],\n",
" [ 0.3281, -0.8906, 1.7578, ..., 0.8594, -0.4102, 0.5273]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_17': tensor([[[ 1.9688, -0.8242, 0.6719, ..., -1.7578, 0.9688, 0.3691],\n",
" [ 1.6094, 0.0869, -0.3145, ..., 0.7500, -0.2217, -1.8828],\n",
" [ 0.3203, 0.2754, 0.3457, ..., -1.1406, -1.8516, -2.7500],\n",
" ...,\n",
" [-1.2266, 0.5156, 1.6250, ..., -0.1904, -0.9492, 0.5273],\n",
" [ 1.3047, 1.6172, 0.1040, ..., -0.9141, 0.9570, -1.4531],\n",
" [ 0.7852, -0.8633, 2.6562, ..., -0.2949, -1.7969, -0.5430]]],\n",
" dtype=torch.bfloat16),\n",
" 't_18': tensor([0.7461], dtype=torch.bfloat16),\n",
" 'latents_18_start': tensor([[[ 1.4844, -0.5938, 0.4766, ..., -1.2969, 0.7695, 0.2891],\n",
" [ 1.2578, 0.0991, -0.0830, ..., 0.5586, -0.1328, -1.2188],\n",
" [ 0.3223, 0.3223, 0.1270, ..., -0.6562, -0.9102, -1.6875],\n",
" ...,\n",
" [-1.1719, 0.1270, 0.9336, ..., 0.2295, -0.3652, 0.6641],\n",
" [ 0.7227, 0.9883, -0.1777, ..., 0.3105, 1.2422, 0.0420],\n",
" [ 0.3145, -0.8750, 1.7109, ..., 0.8633, -0.3809, 0.5352]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_18': tensor([[[ 1.9844, -0.8398, 0.6680, ..., -1.7578, 0.9727, 0.3730],\n",
" [ 1.6172, 0.0752, -0.3184, ..., 0.7578, -0.2148, -1.8828],\n",
" [ 0.3066, 0.2715, 0.3398, ..., -1.1328, -1.8516, -2.7500],\n",
" ...,\n",
" [-1.2422, 0.5156, 1.6484, ..., -0.1777, -0.9336, 0.5625],\n",
" [ 1.3047, 1.6328, 0.1147, ..., -0.8906, 0.9883, -1.4688],\n",
" [ 0.7734, -0.8672, 2.6406, ..., -0.2910, -1.7891, -0.5312]]],\n",
" dtype=torch.bfloat16),\n",
" 't_19': tensor([0.7305], dtype=torch.bfloat16),\n",
" 'latents_19_start': tensor([[[ 1.4531, -0.5781, 0.4648, ..., -1.2656, 0.7539, 0.2832],\n",
" [ 1.2344, 0.0977, -0.0776, ..., 0.5469, -0.1289, -1.1875],\n",
" [ 0.3164, 0.3184, 0.1211, ..., -0.6367, -0.8789, -1.6406],\n",
" ...,\n",
" [-1.1484, 0.1182, 0.9062, ..., 0.2324, -0.3496, 0.6562],\n",
" [ 0.6992, 0.9609, -0.1797, ..., 0.3262, 1.2266, 0.0664],\n",
" [ 0.3008, -0.8594, 1.6641, ..., 0.8672, -0.3516, 0.5430]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_19': tensor([[[ 1.9844, -0.8516, 0.6641, ..., -1.7656, 0.9688, 0.3770],\n",
" [ 1.6094, 0.0684, -0.3281, ..., 0.7734, -0.2119, -1.8672],\n",
" [ 0.3086, 0.2695, 0.3301, ..., -1.1484, -1.8438, -2.7188],\n",
" ...,\n",
" [-1.2422, 0.5117, 1.6484, ..., -0.1738, -0.8984, 0.5781],\n",
" [ 1.3203, 1.6328, 0.1157, ..., -0.8828, 1.0078, -1.4844],\n",
" [ 0.7617, -0.8672, 2.6719, ..., -0.2480, -1.8125, -0.5273]]],\n",
" dtype=torch.bfloat16),\n",
" 't_20': tensor([0.7148], dtype=torch.bfloat16),\n",
" 'latents_20_start': tensor([[[ 1.4219, -0.5625, 0.4531, ..., -1.2344, 0.7383, 0.2773],\n",
" [ 1.2109, 0.0967, -0.0723, ..., 0.5352, -0.1250, -1.1562],\n",
" [ 0.3105, 0.3145, 0.1157, ..., -0.6172, -0.8477, -1.5938],\n",
" ...,\n",
" [-1.1250, 0.1094, 0.8789, ..., 0.2354, -0.3340, 0.6484],\n",
" [ 0.6758, 0.9336, -0.1816, ..., 0.3418, 1.2109, 0.0913],\n",
" [ 0.2871, -0.8438, 1.6172, ..., 0.8711, -0.3203, 0.5508]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_20': tensor([[[ 1.9766, -0.8438, 0.6562, ..., -1.7656, 0.9570, 0.3809],\n",
" [ 1.6094, 0.0713, -0.3340, ..., 0.7734, -0.2246, -1.8594],\n",
" [ 0.2910, 0.2598, 0.3281, ..., -1.1250, -1.8359, -2.7500],\n",
" ...,\n",
" [-1.2422, 0.5039, 1.6406, ..., -0.1738, -0.8867, 0.6016],\n",
" [ 1.3125, 1.6172, 0.1187, ..., -0.8672, 1.0156, -1.4922],\n",
" [ 0.7578, -0.8711, 2.6719, ..., -0.2559, -1.7891, -0.5547]]],\n",
" dtype=torch.bfloat16),\n",
" 't_21': tensor([0.6992], dtype=torch.bfloat16),\n",
" 'latents_21_start': tensor([[[ 1.3906, -0.5469, 0.4414, ..., -1.2031, 0.7227, 0.2715],\n",
" [ 1.1797, 0.0952, -0.0664, ..., 0.5234, -0.1211, -1.1250],\n",
" [ 0.3047, 0.3105, 0.1099, ..., -0.5977, -0.8164, -1.5469],\n",
" ...,\n",
" [-1.1016, 0.1006, 0.8516, ..., 0.2383, -0.3184, 0.6367],\n",
" [ 0.6523, 0.9062, -0.1836, ..., 0.3574, 1.1953, 0.1172],\n",
" [ 0.2734, -0.8281, 1.5703, ..., 0.8750, -0.2891, 0.5586]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_21': tensor([[[ 1.9844, -0.8281, 0.6680, ..., -1.7578, 0.9375, 0.3457],\n",
" [ 1.5938, 0.0811, -0.3203, ..., 0.7656, -0.2207, -1.8594],\n",
" [ 0.2773, 0.2559, 0.3340, ..., -1.1328, -1.8438, -2.7344],\n",
" ...,\n",
" [-1.2266, 0.4961, 1.6406, ..., -0.1953, -0.8750, 0.5859],\n",
" [ 1.2969, 1.6250, 0.1147, ..., -0.8594, 1.0156, -1.4922],\n",
" [ 0.7578, -0.8711, 2.6719, ..., -0.2617, -1.7891, -0.5508]]],\n",
" dtype=torch.bfloat16),\n",
" 't_22': tensor([0.6797], dtype=torch.bfloat16),\n",
" 'latents_22_start': tensor([[[ 1.3594, -0.5312, 0.4297, ..., -1.1719, 0.7070, 0.2656],\n",
" [ 1.1484, 0.0938, -0.0608, ..., 0.5117, -0.1172, -1.0938],\n",
" [ 0.3008, 0.3066, 0.1040, ..., -0.5781, -0.7852, -1.5000],\n",
" ...,\n",
" [-1.0781, 0.0918, 0.8242, ..., 0.2422, -0.3027, 0.6250],\n",
" [ 0.6289, 0.8789, -0.1855, ..., 0.3730, 1.1797, 0.1436],\n",
" [ 0.2598, -0.8125, 1.5234, ..., 0.8789, -0.2578, 0.5664]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_22': tensor([[[ 1.9922, -0.8242, 0.6523, ..., -1.7500, 0.9375, 0.3477],\n",
" [ 1.5859, 0.0757, -0.3379, ..., 0.7578, -0.2178, -1.8438],\n",
" [ 0.2930, 0.2520, 0.3320, ..., -1.1250, -1.8516, -2.7500],\n",
" ...,\n",
" [-1.2031, 0.5000, 1.6406, ..., -0.2012, -0.8750, 0.5820],\n",
" [ 1.3047, 1.6094, 0.1309, ..., -0.8555, 1.0234, -1.5078],\n",
" [ 0.7617, -0.8711, 2.6562, ..., -0.2793, -1.7969, -0.5742]]],\n",
" dtype=torch.bfloat16),\n",
" 't_23': tensor([0.6641], dtype=torch.bfloat16),\n",
" 'latents_23_start': tensor([[[ 1.3203, -0.5156, 0.4180, ..., -1.1406, 0.6914, 0.2598],\n",
" [ 1.1172, 0.0923, -0.0547, ..., 0.4980, -0.1133, -1.0625],\n",
" [ 0.2949, 0.3027, 0.0981, ..., -0.5586, -0.7500, -1.4531],\n",
" ...,\n",
" [-1.0547, 0.0830, 0.7930, ..., 0.2461, -0.2871, 0.6133],\n",
" [ 0.6055, 0.8516, -0.1875, ..., 0.3887, 1.1641, 0.1709],\n",
" [ 0.2461, -0.7969, 1.4766, ..., 0.8828, -0.2256, 0.5781]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_23': tensor([[[ 1.9688, -0.8203, 0.6562, ..., -1.7422, 0.9336, 0.3359],\n",
" [ 1.5781, 0.0923, -0.3164, ..., 0.7617, -0.2188, -1.8438],\n",
" [ 0.2969, 0.2617, 0.3203, ..., -1.1328, -1.8594, -2.7500],\n",
" ...,\n",
" [-1.2109, 0.5117, 1.6406, ..., -0.1914, -0.8711, 0.5938],\n",
" [ 1.2891, 1.6094, 0.1182, ..., -0.8477, 1.0469, -1.5000],\n",
" [ 0.7461, -0.8945, 2.6562, ..., -0.2852, -1.8047, -0.5586]]],\n",
" dtype=torch.bfloat16),\n",
" 't_24': tensor([0.6445], dtype=torch.bfloat16),\n",
" 'latents_24_start': tensor([[[ 1.2812, -0.5000, 0.4062, ..., -1.1094, 0.6758, 0.2539],\n",
" [ 1.0859, 0.0908, -0.0488, ..., 0.4844, -0.1094, -1.0312],\n",
" [ 0.2891, 0.2988, 0.0923, ..., -0.5391, -0.7148, -1.4062],\n",
" ...,\n",
" [-1.0312, 0.0737, 0.7617, ..., 0.2500, -0.2715, 0.6016],\n",
" [ 0.5820, 0.8203, -0.1895, ..., 0.4043, 1.1484, 0.1982],\n",
" [ 0.2324, -0.7812, 1.4297, ..., 0.8867, -0.1924, 0.5898]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_24': tensor([[[ 1.9688, -0.8164, 0.6484, ..., -1.7422, 0.9492, 0.3574],\n",
" [ 1.5703, 0.0918, -0.3262, ..., 0.7734, -0.2207, -1.8438],\n",
" [ 0.2871, 0.2637, 0.3340, ..., -1.1250, -1.8516, -2.7656],\n",
" ...,\n",
" [-1.2031, 0.4961, 1.6328, ..., -0.1768, -0.8555, 0.6055],\n",
" [ 1.2891, 1.6172, 0.1211, ..., -0.8516, 1.0625, -1.5000],\n",
" [ 0.7461, -0.8867, 2.6562, ..., -0.2891, -1.8047, -0.5391]]],\n",
" dtype=torch.bfloat16),\n",
" 't_25': tensor([0.6289], dtype=torch.bfloat16),\n",
" 'latents_25_start': tensor([[[ 1.2422, -0.4844, 0.3945, ..., -1.0781, 0.6562, 0.2471],\n",
" [ 1.0547, 0.0889, -0.0427, ..., 0.4707, -0.1055, -0.9961],\n",
" [ 0.2832, 0.2930, 0.0859, ..., -0.5195, -0.6797, -1.3516],\n",
" ...,\n",
" [-1.0078, 0.0645, 0.7305, ..., 0.2539, -0.2559, 0.5898],\n",
" [ 0.5586, 0.7891, -0.1914, ..., 0.4199, 1.1250, 0.2266],\n",
" [ 0.2188, -0.7656, 1.3828, ..., 0.8906, -0.1582, 0.6016]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_25': tensor([[[ 1.9609, -0.8242, 0.6523, ..., -1.7422, 0.9258, 0.3418],\n",
" [ 1.5625, 0.0850, -0.3359, ..., 0.7812, -0.2295, -1.8516],\n",
" [ 0.2871, 0.2520, 0.3184, ..., -1.1250, -1.8359, -2.7500],\n",
" ...,\n",
" [-1.1875, 0.4863, 1.6250, ..., -0.1924, -0.8633, 0.6055],\n",
" [ 1.2969, 1.6172, 0.1240, ..., -0.8359, 1.0625, -1.5078],\n",
" [ 0.7305, -0.8789, 2.6562, ..., -0.2969, -1.8203, -0.5469]]],\n",
" dtype=torch.bfloat16),\n",
" 't_26': tensor([0.6094], dtype=torch.bfloat16),\n",
" 'latents_26_start': tensor([[[ 1.2031, -0.4688, 0.3828, ..., -1.0469, 0.6406, 0.2402],\n",
" [ 1.0234, 0.0874, -0.0364, ..., 0.4551, -0.1011, -0.9609],\n",
" [ 0.2773, 0.2891, 0.0801, ..., -0.4980, -0.6445, -1.2969],\n",
" ...,\n",
" [-0.9844, 0.0552, 0.6992, ..., 0.2578, -0.2393, 0.5781],\n",
" [ 0.5352, 0.7578, -0.1934, ..., 0.4355, 1.1016, 0.2559],\n",
" [ 0.2051, -0.7500, 1.3359, ..., 0.8945, -0.1235, 0.6133]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_26': tensor([[[ 1.9609, -0.8320, 0.6602, ..., -1.7578, 0.9219, 0.3496],\n",
" [ 1.5703, 0.0801, -0.3359, ..., 0.7812, -0.2266, -1.8281],\n",
" [ 0.2793, 0.2471, 0.3223, ..., -1.1172, -1.8281, -2.7188],\n",
" ...,\n",
" [-1.1797, 0.4863, 1.6172, ..., -0.2021, -0.8516, 0.6133],\n",
" [ 1.2969, 1.6016, 0.1226, ..., -0.8359, 1.0625, -1.5000],\n",
" [ 0.7305, -0.8906, 2.6562, ..., -0.2988, -1.8047, -0.5469]]],\n",
" dtype=torch.bfloat16),\n",
" 't_27': tensor([0.5898], dtype=torch.bfloat16),\n",
" 'latents_27_start': tensor([[[ 1.1641, -0.4531, 0.3691, ..., -1.0156, 0.6211, 0.2334],\n",
" [ 0.9922, 0.0859, -0.0298, ..., 0.4395, -0.0967, -0.9258],\n",
" [ 0.2715, 0.2852, 0.0737, ..., -0.4766, -0.6094, -1.2422],\n",
" ...,\n",
" [-0.9609, 0.0457, 0.6680, ..., 0.2617, -0.2227, 0.5664],\n",
" [ 0.5078, 0.7266, -0.1953, ..., 0.4512, 1.0781, 0.2852],\n",
" [ 0.1904, -0.7344, 1.2812, ..., 0.8984, -0.0884, 0.6250]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_27': tensor([[[ 1.9453, -0.8398, 0.6562, ..., -1.7578, 0.9141, 0.3398],\n",
" [ 1.5469, 0.0654, -0.3379, ..., 0.7734, -0.2236, -1.8359],\n",
" [ 0.2773, 0.2354, 0.3223, ..., -1.1094, -1.8359, -2.7188],\n",
" ...,\n",
" [-1.1797, 0.4844, 1.6328, ..., -0.1992, -0.8359, 0.6172],\n",
" [ 1.2891, 1.5859, 0.1133, ..., -0.8242, 1.0469, -1.4922],\n",
" [ 0.7148, -0.9180, 2.6406, ..., -0.3047, -1.8203, -0.5508]]],\n",
" dtype=torch.bfloat16),\n",
" 't_28': tensor([0.5664], dtype=torch.bfloat16),\n",
" 'latents_28_start': tensor([[[ 1.1250, -0.4355, 0.3555, ..., -0.9805, 0.6016, 0.2266],\n",
" [ 0.9609, 0.0845, -0.0231, ..., 0.4238, -0.0923, -0.8906],\n",
" [ 0.2656, 0.2812, 0.0674, ..., -0.4551, -0.5742, -1.1875],\n",
" ...,\n",
" [-0.9375, 0.0361, 0.6367, ..., 0.2656, -0.2061, 0.5547],\n",
" [ 0.4824, 0.6953, -0.1973, ..., 0.4668, 1.0547, 0.3145],\n",
" [ 0.1758, -0.7148, 1.2266, ..., 0.9062, -0.0522, 0.6367]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_28': tensor([[[ 1.9453, -0.8281, 0.6406, ..., -1.7344, 0.9219, 0.3672],\n",
" [ 1.5547, 0.0684, -0.3496, ..., 0.8047, -0.1953, -1.8281],\n",
" [ 0.2812, 0.2207, 0.3281, ..., -1.1016, -1.8359, -2.7188],\n",
" ...,\n",
" [-1.1953, 0.4668, 1.6328, ..., -0.1758, -0.8203, 0.6445],\n",
" [ 1.2734, 1.5781, 0.1045, ..., -0.7969, 1.0547, -1.4922],\n",
" [ 0.6953, -0.9258, 2.6562, ..., -0.3086, -1.7969, -0.5508]]],\n",
" dtype=torch.bfloat16),\n",
" 't_29': tensor([0.5469], dtype=torch.bfloat16),\n",
" 'latents_29_start': tensor([[[ 1.0859, -0.4180, 0.3418, ..., -0.9453, 0.5820, 0.2188],\n",
" [ 0.9297, 0.0830, -0.0159, ..., 0.4082, -0.0884, -0.8516],\n",
" [ 0.2598, 0.2773, 0.0608, ..., -0.4336, -0.5352, -1.1328],\n",
" ...,\n",
" [-0.9141, 0.0266, 0.6016, ..., 0.2695, -0.1895, 0.5430],\n",
" [ 0.4570, 0.6641, -0.1992, ..., 0.4824, 1.0312, 0.3457],\n",
" [ 0.1621, -0.6953, 1.1719, ..., 0.9141, -0.0156, 0.6484]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_29': tensor([[[ 1.9688, -0.8320, 0.6445, ..., -1.7734, 0.9219, 0.3672],\n",
" [ 1.5469, 0.0732, -0.3477, ..., 0.7930, -0.2100, -1.8281],\n",
" [ 0.2793, 0.2354, 0.3262, ..., -1.1250, -1.8359, -2.7188],\n",
" ...,\n",
" [-1.1953, 0.4746, 1.6172, ..., -0.1738, -0.8086, 0.6484],\n",
" [ 1.2969, 1.5781, 0.0952, ..., -0.8164, 1.0859, -1.4766],\n",
" [ 0.6797, -0.9180, 2.6562, ..., -0.3145, -1.7812, -0.5508]]],\n",
" dtype=torch.bfloat16),\n",
" 't_30': tensor([0.5273], dtype=torch.bfloat16),\n",
" 'latents_30_start': tensor([[[ 1.0469, -0.4004, 0.3281, ..., -0.9102, 0.5625, 0.2109],\n",
" [ 0.8984, 0.0815, -0.0087, ..., 0.3926, -0.0840, -0.8125],\n",
" [ 0.2539, 0.2734, 0.0540, ..., -0.4102, -0.4961, -1.0781],\n",
" ...,\n",
" [-0.8906, 0.0168, 0.5664, ..., 0.2734, -0.1729, 0.5312],\n",
" [ 0.4297, 0.6328, -0.2012, ..., 0.5000, 1.0078, 0.3770],\n",
" [ 0.1484, -0.6758, 1.1172, ..., 0.9219, 0.0212, 0.6602]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_30': tensor([[[ 1.9766, -0.8359, 0.6484, ..., -1.7812, 0.9219, 0.3691],\n",
" [ 1.5469, 0.0693, -0.3359, ..., 0.7969, -0.2090, -1.8203],\n",
" [ 0.2852, 0.2363, 0.3320, ..., -1.1172, -1.8359, -2.7344],\n",
" ...,\n",
" [-1.1953, 0.4766, 1.6406, ..., -0.1855, -0.8203, 0.6484],\n",
" [ 1.2891, 1.5781, 0.1064, ..., -0.8203, 1.0859, -1.4922],\n",
" [ 0.6953, -0.9219, 2.6719, ..., -0.3145, -1.7656, -0.5312]]],\n",
" dtype=torch.bfloat16),\n",
" 't_31': tensor([0.5078], dtype=torch.bfloat16),\n",
" 'latents_31_start': tensor([[[ 1.0078, -0.3828, 0.3145, ..., -0.8711, 0.5430, 0.2031],\n",
" [ 0.8672, 0.0801, -0.0015, ..., 0.3750, -0.0796, -0.7734],\n",
" [ 0.2480, 0.2676, 0.0469, ..., -0.3867, -0.4570, -1.0234],\n",
" ...,\n",
" [-0.8672, 0.0067, 0.5312, ..., 0.2773, -0.1553, 0.5156],\n",
" [ 0.4023, 0.5977, -0.2031, ..., 0.5156, 0.9844, 0.4082],\n",
" [ 0.1338, -0.6562, 1.0625, ..., 0.9297, 0.0588, 0.6719]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_31': tensor([[[ 1.9688, -0.8242, 0.6523, ..., -1.7656, 0.9297, 0.3555],\n",
" [ 1.5469, 0.0796, -0.3516, ..., 0.7969, -0.2051, -1.8281],\n",
" [ 0.2754, 0.2344, 0.3262, ..., -1.1172, -1.8359, -2.7344],\n",
" ...,\n",
" [-1.1953, 0.4766, 1.6328, ..., -0.1895, -0.8125, 0.6289],\n",
" [ 1.2969, 1.5859, 0.0938, ..., -0.8125, 1.0781, -1.4766],\n",
" [ 0.6836, -0.9375, 2.6562, ..., -0.2949, -1.7734, -0.5234]]],\n",
" dtype=torch.bfloat16),\n",
" 't_32': tensor([0.4844], dtype=torch.bfloat16),\n",
" 'latents_32_start': tensor([[[ 0.9648, -0.3652, 0.3008, ..., -0.8320, 0.5234, 0.1953],\n",
" [ 0.8320, 0.0781, 0.0061, ..., 0.3574, -0.0752, -0.7344],\n",
" [ 0.2422, 0.2617, 0.0398, ..., -0.3633, -0.4180, -0.9648],\n",
" ...,\n",
" [-0.8398, -0.0037, 0.4961, ..., 0.2812, -0.1377, 0.5000],\n",
" [ 0.3750, 0.5625, -0.2051, ..., 0.5352, 0.9609, 0.4395],\n",
" [ 0.1191, -0.6367, 1.0078, ..., 0.9375, 0.0977, 0.6836]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_32': tensor([[[ 1.9688, -0.8242, 0.6523, ..., -1.7656, 0.9258, 0.3535],\n",
" [ 1.5391, 0.0728, -0.3457, ..., 0.7891, -0.2021, -1.8203],\n",
" [ 0.2754, 0.2344, 0.3223, ..., -1.1172, -1.8281, -2.7344],\n",
" ...,\n",
" [-1.1875, 0.4688, 1.6250, ..., -0.1768, -0.8086, 0.6133],\n",
" [ 1.2812, 1.5703, 0.1079, ..., -0.8320, 1.0703, -1.4844],\n",
" [ 0.6914, -0.9297, 2.6562, ..., -0.3027, -1.7656, -0.5156]]],\n",
" dtype=torch.bfloat16),\n",
" 't_33': tensor([0.4629], dtype=torch.bfloat16),\n",
" 'latents_33_start': tensor([[[ 0.9219, -0.3477, 0.2871, ..., -0.7930, 0.5039, 0.1875],\n",
" [ 0.7969, 0.0767, 0.0138, ..., 0.3398, -0.0708, -0.6953],\n",
" [ 0.2363, 0.2559, 0.0327, ..., -0.3379, -0.3770, -0.9023],\n",
" ...,\n",
" [-0.8125, -0.0141, 0.4609, ..., 0.2852, -0.1196, 0.4863],\n",
" [ 0.3457, 0.5273, -0.2070, ..., 0.5547, 0.9375, 0.4727],\n",
" [ 0.1035, -0.6172, 0.9492, ..., 0.9453, 0.1367, 0.6953]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_33': tensor([[[ 1.9609, -0.8242, 0.6484, ..., -1.7578, 0.9258, 0.3477],\n",
" [ 1.5312, 0.0684, -0.3379, ..., 0.7891, -0.2012, -1.8125],\n",
" [ 0.2812, 0.2158, 0.3164, ..., -1.1172, -1.8125, -2.7188],\n",
" ...,\n",
" [-1.1797, 0.4570, 1.6250, ..., -0.1826, -0.8086, 0.6055],\n",
" [ 1.2812, 1.5625, 0.1025, ..., -0.8125, 1.0625, -1.4922],\n",
" [ 0.6797, -0.9297, 2.6250, ..., -0.2949, -1.7656, -0.5117]]],\n",
" dtype=torch.bfloat16),\n",
" 't_34': tensor([0.4375], dtype=torch.bfloat16),\n",
" 'latents_34_start': tensor([[[ 0.8789, -0.3281, 0.2715, ..., -0.7539, 0.4824, 0.1797],\n",
" [ 0.7617, 0.0752, 0.0215, ..., 0.3223, -0.0664, -0.6523],\n",
" [ 0.2295, 0.2500, 0.0255, ..., -0.3125, -0.3359, -0.8398],\n",
" ...,\n",
" [-0.7852, -0.0245, 0.4238, ..., 0.2891, -0.1011, 0.4727],\n",
" [ 0.3164, 0.4922, -0.2090, ..., 0.5742, 0.9141, 0.5078],\n",
" [ 0.0879, -0.5977, 0.8906, ..., 0.9531, 0.1768, 0.7070]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_34': tensor([[[ 1.9766, -0.8242, 0.6523, ..., -1.7812, 0.9258, 0.3535],\n",
" [ 1.5312, 0.0640, -0.3457, ..., 0.8086, -0.1934, -1.8047],\n",
" [ 0.2715, 0.1992, 0.3105, ..., -1.0938, -1.8047, -2.7344],\n",
" ...,\n",
" [-1.1875, 0.4609, 1.6172, ..., -0.1953, -0.8164, 0.6172],\n",
" [ 1.2812, 1.5625, 0.0942, ..., -0.8242, 1.0781, -1.4922],\n",
" [ 0.6562, -0.9531, 2.6406, ..., -0.3008, -1.7734, -0.4980]]],\n",
" dtype=torch.bfloat16),\n",
" 't_35': tensor([0.4160], dtype=torch.bfloat16),\n",
" 'latents_35_start': tensor([[[ 0.8320, -0.3086, 0.2559, ..., -0.7109, 0.4609, 0.1719],\n",
" [ 0.7266, 0.0737, 0.0295, ..., 0.3027, -0.0620, -0.6094],\n",
" [ 0.2236, 0.2451, 0.0183, ..., -0.2871, -0.2930, -0.7773],\n",
" ...,\n",
" [-0.7578, -0.0352, 0.3867, ..., 0.2930, -0.0820, 0.4590],\n",
" [ 0.2871, 0.4551, -0.2109, ..., 0.5938, 0.8906, 0.5430],\n",
" [ 0.0728, -0.5742, 0.8281, ..., 0.9609, 0.2178, 0.7188]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_35': tensor([[[ 1.9766, -0.8242, 0.6484, ..., -1.7734, 0.9258, 0.3496],\n",
" [ 1.5312, 0.0708, -0.3418, ..., 0.8047, -0.1953, -1.7969],\n",
" [ 0.2832, 0.1875, 0.3086, ..., -1.0938, -1.7891, -2.7344],\n",
" ...,\n",
" [-1.1719, 0.4551, 1.6172, ..., -0.1953, -0.8047, 0.6055],\n",
" [ 1.2578, 1.5625, 0.0898, ..., -0.8242, 1.0859, -1.4766],\n",
" [ 0.6602, -0.9414, 2.6406, ..., -0.3164, -1.7656, -0.5078]]],\n",
" dtype=torch.bfloat16),\n",
" 't_36': tensor([0.3926], dtype=torch.bfloat16),\n",
" 'latents_36_start': tensor([[[ 0.7852, -0.2891, 0.2402, ..., -0.6680, 0.4395, 0.1641],\n",
" [ 0.6914, 0.0723, 0.0376, ..., 0.2832, -0.0574, -0.5664],\n",
" [ 0.2168, 0.2402, 0.0110, ..., -0.2617, -0.2500, -0.7109],\n",
" ...,\n",
" [-0.7305, -0.0459, 0.3477, ..., 0.2969, -0.0630, 0.4453],\n",
" [ 0.2578, 0.4180, -0.2129, ..., 0.6133, 0.8633, 0.5781],\n",
" [ 0.0571, -0.5508, 0.7656, ..., 0.9688, 0.2598, 0.7305]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_36': tensor([[[ 1.9609, -0.8164, 0.6445, ..., -1.7656, 0.9102, 0.3477],\n",
" [ 1.5234, 0.0654, -0.3320, ..., 0.8164, -0.2041, -1.7812],\n",
" [ 0.2734, 0.1836, 0.3066, ..., -1.0938, -1.7891, -2.7031],\n",
" ...,\n",
" [-1.1875, 0.4688, 1.6016, ..., -0.2051, -0.8008, 0.5977],\n",
" [ 1.2500, 1.5234, 0.0747, ..., -0.8438, 1.0781, -1.4922],\n",
" [ 0.6484, -0.9219, 2.6250, ..., -0.3027, -1.7812, -0.5078]]],\n",
" dtype=torch.bfloat16),\n",
" 't_37': tensor([0.3652], dtype=torch.bfloat16),\n",
" 'latents_37_start': tensor([[[ 0.7383, -0.2695, 0.2246, ..., -0.6250, 0.4180, 0.1553],\n",
" [ 0.6562, 0.0708, 0.0457, ..., 0.2637, -0.0525, -0.5234],\n",
" [ 0.2100, 0.2354, 0.0035, ..., -0.2354, -0.2061, -0.6445],\n",
" ...,\n",
" [-0.7031, -0.0574, 0.3086, ..., 0.3027, -0.0435, 0.4316],\n",
" [ 0.2275, 0.3809, -0.2148, ..., 0.6328, 0.8359, 0.6133],\n",
" [ 0.0413, -0.5273, 0.7031, ..., 0.9766, 0.3027, 0.7422]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_37': tensor([[[ 1.9922, -0.8320, 0.6523, ..., -1.7812, 0.9336, 0.3613],\n",
" [ 1.5312, 0.0713, -0.3477, ..., 0.8281, -0.1963, -1.7891],\n",
" [ 0.2754, 0.1777, 0.2930, ..., -1.0859, -1.7734, -2.7031],\n",
" ...,\n",
" [-1.1875, 0.4590, 1.6172, ..., -0.1855, -0.8086, 0.5820],\n",
" [ 1.2422, 1.5391, 0.0552, ..., -0.8633, 1.0781, -1.5156],\n",
" [ 0.6602, -0.9219, 2.6250, ..., -0.3047, -1.7734, -0.5000]]],\n",
" dtype=torch.bfloat16),\n",
" 't_38': tensor([0.3418], dtype=torch.bfloat16),\n",
" 'latents_38_start': tensor([[[ 0.6875, -0.2490, 0.2080, ..., -0.5820, 0.3945, 0.1465],\n",
" [ 0.6172, 0.0688, 0.0544, ..., 0.2432, -0.0476, -0.4785],\n",
" [ 0.2031, 0.2305, -0.0038, ..., -0.2080, -0.1621, -0.5781],\n",
" ...,\n",
" [-0.6719, -0.0688, 0.2676, ..., 0.3066, -0.0232, 0.4180],\n",
" [ 0.1963, 0.3418, -0.2158, ..., 0.6562, 0.8086, 0.6523],\n",
" [ 0.0248, -0.5039, 0.6367, ..., 0.9844, 0.3477, 0.7539]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_38': tensor([[[ 1.9766, -0.8477, 0.6484, ..., -1.7812, 0.9531, 0.3789],\n",
" [ 1.5234, 0.0781, -0.3496, ..., 0.8281, -0.1973, -1.7734],\n",
" [ 0.2617, 0.1660, 0.2949, ..., -1.0547, -1.7422, -2.6875],\n",
" ...,\n",
" [-1.1641, 0.4668, 1.6328, ..., -0.1836, -0.8125, 0.5547],\n",
" [ 1.2344, 1.5312, 0.0752, ..., -0.8867, 1.0234, -1.5156],\n",
" [ 0.6406, -0.9219, 2.6250, ..., -0.2988, -1.7812, -0.5117]]],\n",
" dtype=torch.bfloat16),\n",
" 't_39': tensor([0.3164], dtype=torch.bfloat16),\n",
" 'latents_39_start': tensor([[[ 0.6367, -0.2275, 0.1914, ..., -0.5352, 0.3711, 0.1367],\n",
" [ 0.5781, 0.0669, 0.0635, ..., 0.2217, -0.0425, -0.4336],\n",
" [ 0.1963, 0.2266, -0.0114, ..., -0.1807, -0.1172, -0.5078],\n",
" ...,\n",
" [-0.6406, -0.0811, 0.2256, ..., 0.3105, -0.0023, 0.4043],\n",
" [ 0.1641, 0.3027, -0.2178, ..., 0.6797, 0.7812, 0.6914],\n",
" [ 0.0083, -0.4805, 0.5703, ..., 0.9922, 0.3926, 0.7656]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_39': tensor([[[ 1.9531, -0.8320, 0.6523, ..., -1.7500, 0.9336, 0.3965],\n",
" [ 1.5156, 0.0767, -0.3691, ..., 0.8281, -0.1777, -1.7500],\n",
" [ 0.2637, 0.1729, 0.2734, ..., -1.0234, -1.6797, -2.6250],\n",
" ...,\n",
" [-1.1406, 0.4531, 1.6016, ..., -0.1973, -0.8164, 0.5234],\n",
" [ 1.2344, 1.5078, 0.0593, ..., -0.8906, 1.0234, -1.4844],\n",
" [ 0.6367, -0.9023, 2.5938, ..., -0.2910, -1.7500, -0.5078]]],\n",
" dtype=torch.bfloat16),\n",
" 't_40': tensor([0.2891], dtype=torch.bfloat16),\n",
" 'latents_40_start': tensor([[[ 0.5859, -0.2061, 0.1738, ..., -0.4883, 0.3457, 0.1260],\n",
" [ 0.5391, 0.0649, 0.0732, ..., 0.2002, -0.0378, -0.3867],\n",
" [ 0.1895, 0.2217, -0.0186, ..., -0.1543, -0.0732, -0.4395],\n",
" ...,\n",
" [-0.6094, -0.0928, 0.1836, ..., 0.3164, 0.0192, 0.3906],\n",
" [ 0.1318, 0.2637, -0.2197, ..., 0.7031, 0.7539, 0.7305],\n",
" [-0.0084, -0.4570, 0.5039, ..., 1.0000, 0.4375, 0.7773]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_40': tensor([[[ 1.9766, -0.8555, 0.6445, ..., -1.7500, 0.9414, 0.4004],\n",
" [ 1.5234, 0.0693, -0.3613, ..., 0.8672, -0.1650, -1.7266],\n",
" [ 0.2598, 0.1660, 0.2734, ..., -1.0156, -1.7031, -2.6250],\n",
" ...,\n",
" [-1.1641, 0.4531, 1.6016, ..., -0.1611, -0.8125, 0.5039],\n",
" [ 1.2109, 1.5156, 0.0391, ..., -0.8906, 0.9883, -1.4688],\n",
" [ 0.6250, -0.9102, 2.6094, ..., -0.2930, -1.7578, -0.4922]]],\n",
" dtype=torch.bfloat16),\n",
" 't_41': tensor([0.2617], dtype=torch.bfloat16),\n",
" 'latents_41_start': tensor([[[ 0.5312, -0.1826, 0.1562, ..., -0.4414, 0.3203, 0.1152],\n",
" [ 0.4980, 0.0630, 0.0830, ..., 0.1768, -0.0334, -0.3398],\n",
" [ 0.1826, 0.2168, -0.0259, ..., -0.1270, -0.0273, -0.3691],\n",
" ...,\n",
" [-0.5781, -0.1050, 0.1406, ..., 0.3203, 0.0410, 0.3770],\n",
" [ 0.0991, 0.2227, -0.2207, ..., 0.7266, 0.7266, 0.7695],\n",
" [-0.0253, -0.4316, 0.4336, ..., 1.0078, 0.4844, 0.7891]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_41': tensor([[[ 1.9375, -0.8555, 0.6367, ..., -1.7266, 0.9531, 0.4297],\n",
" [ 1.5312, 0.0659, -0.3633, ..., 0.8711, -0.1660, -1.7266],\n",
" [ 0.2520, 0.1826, 0.2676, ..., -0.9844, -1.6328, -2.5781],\n",
" ...,\n",
" [-1.1484, 0.4453, 1.5859, ..., -0.1562, -0.8281, 0.4922],\n",
" [ 1.1797, 1.5078, 0.0229, ..., -0.9102, 0.9531, -1.4609],\n",
" [ 0.6211, -0.9102, 2.5938, ..., -0.2832, -1.7266, -0.4727]]],\n",
" dtype=torch.bfloat16),\n",
" 't_42': tensor([0.2354], dtype=torch.bfloat16),\n",
" 'latents_42_start': tensor([[[ 0.4785, -0.1592, 0.1387, ..., -0.3945, 0.2949, 0.1035],\n",
" [ 0.4551, 0.0613, 0.0928, ..., 0.1523, -0.0288, -0.2930],\n",
" [ 0.1758, 0.2119, -0.0332, ..., -0.0996, 0.0178, -0.2969],\n",
" ...,\n",
" [-0.5469, -0.1172, 0.0967, ..., 0.3242, 0.0640, 0.3633],\n",
" [ 0.0664, 0.1816, -0.2217, ..., 0.7500, 0.6992, 0.8086],\n",
" [-0.0425, -0.4062, 0.3613, ..., 1.0156, 0.5312, 0.8008]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_42': tensor([[[ 1.9297, -0.8594, 0.6289, ..., -1.7031, 0.9609, 0.4297],\n",
" [ 1.5000, 0.0811, -0.3652, ..., 0.8711, -0.1494, -1.7031],\n",
" [ 0.2246, 0.1543, 0.2695, ..., -0.9492, -1.6328, -2.5312],\n",
" ...,\n",
" [-1.1328, 0.4395, 1.5781, ..., -0.1680, -0.8398, 0.4453],\n",
" [ 1.1484, 1.4766, 0.0073, ..., -0.9648, 0.8984, -1.4219],\n",
" [ 0.5938, -0.8672, 2.5312, ..., -0.2949, -1.7031, -0.4766]]],\n",
" dtype=torch.bfloat16),\n",
" 't_43': tensor([0.2070], dtype=torch.bfloat16),\n",
" 'latents_43_start': tensor([[[ 0.4238, -0.1348, 0.1211, ..., -0.3457, 0.2676, 0.0913],\n",
" [ 0.4121, 0.0591, 0.1030, ..., 0.1279, -0.0245, -0.2441],\n",
" [ 0.1699, 0.2080, -0.0408, ..., -0.0728, 0.0640, -0.2246],\n",
" ...,\n",
" [-0.5156, -0.1299, 0.0520, ..., 0.3281, 0.0879, 0.3516],\n",
" [ 0.0339, 0.1396, -0.2217, ..., 0.7773, 0.6719, 0.8477],\n",
" [-0.0593, -0.3809, 0.2891, ..., 1.0234, 0.5781, 0.8125]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_43': tensor([[[ 1.9141, -0.8711, 0.6328, ..., -1.6484, 0.9883, 0.4453],\n",
" [ 1.4844, 0.0693, -0.4004, ..., 0.8867, -0.1152, -1.6797],\n",
" [ 0.2432, 0.1484, 0.2090, ..., -0.8867, -1.5938, -2.4531],\n",
" ...,\n",
" [-1.1094, 0.4453, 1.5703, ..., -0.1865, -0.8594, 0.3906],\n",
" [ 1.0938, 1.4375, -0.0374, ..., -0.9648, 0.8359, -1.4531],\n",
" [ 0.5898, -0.8516, 2.4688, ..., -0.2773, -1.6484, -0.4746]]],\n",
" dtype=torch.bfloat16),\n",
" 't_44': tensor([0.1777], dtype=torch.bfloat16),\n",
" 'latents_44_start': tensor([[[ 0.3672, -0.1094, 0.1025, ..., -0.2969, 0.2393, 0.0781],\n",
" [ 0.3691, 0.0571, 0.1147, ..., 0.1021, -0.0212, -0.1953],\n",
" [ 0.1631, 0.2041, -0.0469, ..., -0.0469, 0.1104, -0.1533],\n",
" ...,\n",
" [-0.4844, -0.1426, 0.0063, ..., 0.3340, 0.1128, 0.3398],\n",
" [ 0.0022, 0.0977, -0.2207, ..., 0.8047, 0.6484, 0.8906],\n",
" [-0.0762, -0.3555, 0.2168, ..., 1.0312, 0.6250, 0.8281]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_44': tensor([[[ 1.8984, -0.8984, 0.5977, ..., -1.6094, 0.9805, 0.4551],\n",
" [ 1.4609, 0.0474, -0.4160, ..., 0.9102, -0.1060, -1.6484],\n",
" [ 0.2119, 0.0791, 0.2021, ..., -0.8359, -1.5391, -2.4219],\n",
" ...,\n",
" [-1.0859, 0.4395, 1.5391, ..., -0.2002, -0.8516, 0.3359],\n",
" [ 1.0625, 1.4219, -0.0486, ..., -0.9609, 0.7930, -1.4453],\n",
" [ 0.5898, -0.8008, 2.4219, ..., -0.3223, -1.5703, -0.4609]]],\n",
" dtype=torch.bfloat16),\n",
" 't_45': tensor([0.1484], dtype=torch.bfloat16),\n",
" 'latents_45_start': tensor([[[ 0.3105, -0.0825, 0.0850, ..., -0.2490, 0.2100, 0.0645],\n",
" [ 0.3262, 0.0557, 0.1270, ..., 0.0747, -0.0181, -0.1465],\n",
" [ 0.1562, 0.2021, -0.0530, ..., -0.0219, 0.1562, -0.0811],\n",
" ...,\n",
" [-0.4512, -0.1553, -0.0398, ..., 0.3398, 0.1387, 0.3301],\n",
" [-0.0295, 0.0552, -0.2197, ..., 0.8320, 0.6250, 0.9336],\n",
" [-0.0938, -0.3320, 0.1445, ..., 1.0391, 0.6719, 0.8438]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_45': tensor([[[ 1.8516, -0.8984, 0.5547, ..., -1.5547, 0.9688, 0.4727],\n",
" [ 1.4219, 0.0500, -0.4258, ..., 0.9141, -0.0588, -1.5703],\n",
" [ 0.1592, 0.0850, 0.1924, ..., -0.7500, -1.4766, -2.3125],\n",
" ...,\n",
" [-1.0469, 0.4414, 1.4766, ..., -0.1494, -0.8438, 0.3047],\n",
" [ 0.9883, 1.4062, -0.0874, ..., -0.9844, 0.7422, -1.4297],\n",
" [ 0.5742, -0.7578, 2.3125, ..., -0.3301, -1.3672, -0.4238]]],\n",
" dtype=torch.bfloat16),\n",
" 't_46': tensor([0.1172], dtype=torch.bfloat16),\n",
" 'latents_46_start': tensor([[[ 0.2539, -0.0549, 0.0679, ..., -0.2012, 0.1807, 0.0500],\n",
" [ 0.2832, 0.0542, 0.1396, ..., 0.0469, -0.0162, -0.0986],\n",
" [ 0.1514, 0.1992, -0.0588, ..., 0.0011, 0.2012, -0.0103],\n",
" ...,\n",
" [-0.4199, -0.1689, -0.0850, ..., 0.3438, 0.1641, 0.3203],\n",
" [-0.0598, 0.0122, -0.2168, ..., 0.8633, 0.6016, 0.9766],\n",
" [-0.1113, -0.3086, 0.0737, ..., 1.0469, 0.7148, 0.8555]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_46': tensor([[[ 1.7734, -0.9492, 0.5430, ..., -1.5312, 0.9727, 0.5273],\n",
" [ 1.4219, 0.0054, -0.4844, ..., 0.9297, 0.0198, -1.4531],\n",
" [ 0.1377, -0.0330, 0.1299, ..., -0.6914, -1.3906, -2.2188],\n",
" ...,\n",
" [-0.9727, 0.4297, 1.3906, ..., -0.1172, -0.8164, 0.2295],\n",
" [ 0.8945, 1.3516, -0.1758, ..., -0.9961, 0.6445, -1.5078],\n",
" [ 0.5234, -0.7422, 2.2031, ..., -0.3848, -1.1641, -0.4883]]],\n",
" dtype=torch.bfloat16),\n",
" 't_47': tensor([0.0854], dtype=torch.bfloat16),\n",
" 'latents_47_start': tensor([[[ 0.1982, -0.0250, 0.0508, ..., -0.1523, 0.1504, 0.0334],\n",
" [ 0.2383, 0.0540, 0.1553, ..., 0.0176, -0.0168, -0.0530],\n",
" [ 0.1475, 0.2002, -0.0630, ..., 0.0228, 0.2451, 0.0596],\n",
" ...,\n",
" [-0.3887, -0.1826, -0.1289, ..., 0.3477, 0.1895, 0.3125],\n",
" [-0.0879, -0.0303, -0.2109, ..., 0.8945, 0.5820, 1.0234],\n",
" [-0.1279, -0.2852, 0.0044, ..., 1.0625, 0.7500, 0.8711]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_47': tensor([[[ 1.6250, -0.9688, 0.4160, ..., -1.4609, 0.9961, 0.5742],\n",
" [ 1.3594, 0.0728, -0.5430, ..., 0.9062, 0.0530, -1.3438],\n",
" [ 0.1553, -0.1787, 0.0908, ..., -0.5820, -1.1875, -1.9688],\n",
" ...,\n",
" [-0.8281, 0.4160, 1.2422, ..., -0.0122, -0.7500, 0.1396],\n",
" [ 0.7734, 1.2812, -0.2295, ..., -0.9883, 0.5039, -1.4844],\n",
" [ 0.4453, -0.6719, 1.9688, ..., -0.4180, -0.9141, -0.6211]]],\n",
" dtype=torch.bfloat16),\n",
" 't_48': tensor([0.0532], dtype=torch.bfloat16),\n",
" 'latents_48_start': tensor([[[ 0.1455, 0.0065, 0.0374, ..., -0.1050, 0.1182, 0.0148],\n",
" [ 0.1943, 0.0515, 0.1729, ..., -0.0118, -0.0186, -0.0093],\n",
" [ 0.1426, 0.2061, -0.0659, ..., 0.0417, 0.2832, 0.1235],\n",
" ...,\n",
" [-0.3613, -0.1963, -0.1689, ..., 0.3477, 0.2139, 0.3086],\n",
" [-0.1133, -0.0718, -0.2031, ..., 0.9258, 0.5664, 1.0703],\n",
" [-0.1426, -0.2637, -0.0596, ..., 1.0781, 0.7812, 0.8906]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred_48': tensor([[[ 1.2031, -0.9297, 0.2559, ..., -1.4141, 0.8906, 0.4258],\n",
" [ 1.1016, 0.0625, -0.3242, ..., 0.8047, 0.1318, -1.1094],\n",
" [ 0.1094, -0.2695, 0.2334, ..., -0.5820, -1.1016, -1.5234],\n",
" ...,\n",
" [-0.5938, 0.4551, 1.0938, ..., 0.0281, -0.6289, 0.1357],\n",
" [ 0.6523, 1.0625, -0.2275, ..., -1.0938, 0.4297, -1.3984],\n",
" [ 0.3906, -0.4180, 1.5391, ..., -0.5742, -0.6250, -0.7852]]],\n",
" dtype=torch.bfloat16),\n",
" 't_49': tensor([0.0200], dtype=torch.bfloat16),\n",
" 'latents_49_start': tensor([[[ 1.0547e-01, 3.7354e-02, 2.8809e-02, ..., -5.8105e-02,\n",
" 8.8867e-02, 6.1035e-04],\n",
" [ 1.5820e-01, 4.9316e-02, 1.8359e-01, ..., -3.8574e-02,\n",
" -2.2949e-02, 2.7588e-02],\n",
" [ 1.3867e-01, 2.1484e-01, -7.3730e-02, ..., 6.1035e-02,\n",
" 3.2031e-01, 1.7383e-01],\n",
" ...,\n",
" [-3.4180e-01, -2.1094e-01, -2.0508e-01, ..., 3.4766e-01,\n",
" 2.3438e-01, 3.0469e-01],\n",
" [-1.3477e-01, -1.0693e-01, -1.9531e-01, ..., 9.6094e-01,\n",
" 5.5078e-01, 1.1172e+00],\n",
" [-1.5527e-01, -2.5000e-01, -1.1035e-01, ..., 1.0938e+00,\n",
" 8.0078e-01, 9.1797e-01]]], dtype=torch.bfloat16),\n",
" 'noise_pred_49': tensor([[[ 0.7461, -0.5586, 0.2197, ..., -1.0469, 0.7109, 0.4902],\n",
" [ 0.6094, 0.0464, -0.1650, ..., 0.4980, 0.2314, -0.9414],\n",
" [ 0.1064, -0.2109, 0.1846, ..., -0.3633, -0.8086, -1.0234],\n",
" ...,\n",
" [-0.2559, 0.3711, 0.7461, ..., -0.2217, -0.2988, 0.0339],\n",
" [ 0.4980, 0.5156, -0.0260, ..., -1.1250, 0.1064, -1.1250],\n",
" [ 0.2471, 0.0179, 0.6875, ..., -0.7188, -0.5898, -0.8672]]],\n",
" dtype=torch.bfloat16),\n",
" 'output': tensor([[[ 0.0903, 0.0486, 0.0244, ..., -0.0371, 0.0747, -0.0092],\n",
" [ 0.1465, 0.0483, 0.1865, ..., -0.0486, -0.0276, 0.0464],\n",
" [ 0.1367, 0.2188, -0.0776, ..., 0.0684, 0.3359, 0.1943],\n",
" ...,\n",
" [-0.3359, -0.2188, -0.2197, ..., 0.3516, 0.2402, 0.3047],\n",
" [-0.1445, -0.1172, -0.1943, ..., 0.9844, 0.5469, 1.1406],\n",
" [-0.1602, -0.2500, -0.1240, ..., 1.1094, 0.8125, 0.9336]]],\n",
" dtype=torch.bfloat16),\n",
" 'height': 576,\n",
" 'width': 448,\n",
" 't': tensor([1.], dtype=torch.bfloat16),\n",
" 'latents_start': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n",
" [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n",
" [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n",
" ...,\n",
" [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n",
" [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n",
" [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n",
" dtype=torch.bfloat16),\n",
" 'noise_pred': tensor([[[ 1.8906, -0.8945, 0.5938, ..., -1.7578, 1.0078, 0.2539],\n",
" [ 1.5781, 0.0278, -0.2793, ..., 0.7305, -0.1553, -1.7969],\n",
" [ 0.3027, 0.2949, 0.1621, ..., -1.0625, -1.5938, -2.6406],\n",
" ...,\n",
" [-1.2578, 0.5352, 1.5859, ..., -0.2773, -1.0312, 0.3203],\n",
" [ 1.2734, 1.5312, 0.0728, ..., -0.6211, 0.8984, -1.1562],\n",
" [ 0.6172, -0.9336, 2.6719, ..., -0.1050, -1.8672, -0.3691]]],\n",
" dtype=torch.bfloat16)}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"src[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22f19ae9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|