{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "e1e781e9", "metadata": {}, "outputs": [], "source": [ "%cd /home/ubuntu/Qwen-Image-Edit-Angles" ] }, { "cell_type": "code", "execution_count": null, "id": "d6192ee5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4941" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import glob\n", "from pathlib import Path\n", "\n", "base_data = Path(\"/data/regression_output\")\n", "\n", "all_reg = list(base_data.glob(\"*.pt\"))\n", "max_ind = max([int(reg_pth.stem) for reg_pth in all_reg])\n", "\n", "max_ind" ] }, { "cell_type": "code", "execution_count": 14, "id": "b5124900", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "prompt_embeds\n", "prompt_embeds_mask\n", "noise\n", "image_latents\n", "vae_image_sizes\n", "img_shapes\n", "txt_seq_lens\n", "t_0\n", "latents_0_start\n", "noise_pred_0\n", "t_1\n", "latents_1_start\n", "noise_pred_1\n", "t_2\n", "latents_2_start\n", "noise_pred_2\n", "t_3\n", "latents_3_start\n", "noise_pred_3\n", "t_4\n", "latents_4_start\n", "noise_pred_4\n", "t_5\n", "latents_5_start\n", "noise_pred_5\n", "t_6\n", "latents_6_start\n", "noise_pred_6\n", "t_7\n", "latents_7_start\n", "noise_pred_7\n", "t_8\n", "latents_8_start\n", "noise_pred_8\n", "t_9\n", "latents_9_start\n", "noise_pred_9\n", "t_10\n", "latents_10_start\n", "noise_pred_10\n", "t_11\n", "latents_11_start\n", "noise_pred_11\n", "t_12\n", "latents_12_start\n", "noise_pred_12\n", "t_13\n", "latents_13_start\n", "noise_pred_13\n", "t_14\n", "latents_14_start\n", "noise_pred_14\n", "t_15\n", "latents_15_start\n", "noise_pred_15\n", "t_16\n", "latents_16_start\n", "noise_pred_16\n", "t_17\n", "latents_17_start\n", "noise_pred_17\n", "t_18\n", "latents_18_start\n", "noise_pred_18\n", "t_19\n", "latents_19_start\n", "noise_pred_19\n", "t_20\n", "latents_20_start\n", "noise_pred_20\n", "t_21\n", "latents_21_start\n", "noise_pred_21\n", "t_22\n", "latents_22_start\n", "noise_pred_22\n", "t_23\n", "latents_23_start\n", "noise_pred_23\n", "t_24\n", "latents_24_start\n", "noise_pred_24\n", "t_25\n", "latents_25_start\n", "noise_pred_25\n", "t_26\n", "latents_26_start\n", "noise_pred_26\n", "t_27\n", "latents_27_start\n", "noise_pred_27\n", "t_28\n", "latents_28_start\n", "noise_pred_28\n", "t_29\n", "latents_29_start\n", "noise_pred_29\n", "t_30\n", "latents_30_start\n", "noise_pred_30\n", "t_31\n", "latents_31_start\n", "noise_pred_31\n", "t_32\n", "latents_32_start\n", "noise_pred_32\n", "t_33\n", "latents_33_start\n", "noise_pred_33\n", "t_34\n", "latents_34_start\n", "noise_pred_34\n", "t_35\n", "latents_35_start\n", "noise_pred_35\n", "t_36\n", "latents_36_start\n", "noise_pred_36\n", "t_37\n", "latents_37_start\n", "noise_pred_37\n", "t_38\n", "latents_38_start\n", "noise_pred_38\n", "t_39\n", "latents_39_start\n", "noise_pred_39\n", "t_40\n", "latents_40_start\n", "noise_pred_40\n", "t_41\n", "latents_41_start\n", "noise_pred_41\n", "t_42\n", "latents_42_start\n", "noise_pred_42\n", "t_43\n", "latents_43_start\n", "noise_pred_43\n", "t_44\n", "latents_44_start\n", "noise_pred_44\n", "t_45\n", "latents_45_start\n", "noise_pred_45\n", "t_46\n", "latents_46_start\n", "noise_pred_46\n", "t_47\n", "latents_47_start\n", "noise_pred_47\n", "t_48\n", "latents_48_start\n", "noise_pred_48\n", "t_49\n", "latents_49_start\n", "noise_pred_49\n", "output\n", "height\n", "width\n" ] } ], "source": [ "import torch\n", "\n", "out = all_reg[0]\n", "out_dict = torch.load(out)\n", "for k in out_dict.keys():\n", " print(k)" ] }, { "cell_type": "code", "execution_count": null, "id": "74f693db", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'003329'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [] }, { "cell_type": "code", "execution_count": 7, "id": "da107d9f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "69G\t/data/regression_output\n" ] } ], "source": [ "!du -h {base_data}" ] }, { "cell_type": "code", "execution_count": null, "id": "269c0bfb", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 16, "id": "5964bf2b", "metadata": {}, "outputs": [], "source": [ "class RegressionSource:\n", " # WIP\n", "\n", " def __init__(self, data_dir, gen_steps=50):\n", " if not isinstance(data_dir, Path):\n", " data_dir = Path(data_dir)\n", " self.data_paths = list(data_dir.glob(\"*.pt\"))\n", " self.gen_steps = gen_steps\n", " self._len = gen_steps * len(self.data_paths)\n", " \n", " def __len__(self):\n", " return self._len\n", " \n", " def __getitem__(self, idx):\n", " data_idx = idx // self.gen_steps\n", " step_idx = idx % self.gen_steps\n", " out_dict = torch.load(self.data_paths[data_idx])\n", " t = out_dict.pop(f\"t_{step_idx}\")\n", " latents_start = out_dict.pop(f\"latents_{step_idx}_start\")\n", " noise_pred = out_dict.pop(f\"noise_pred_{step_idx}\")\n", " out_dict[\"t\"] = t\n", " out_dict[\"latents_start\"] = latents_start\n", " out_dict[\"noise_pred\"] = noise_pred\n", " return out_dict\n", "\n", " \n" ] }, { "cell_type": "code", "execution_count": 17, "id": "b62e7bec", "metadata": {}, "outputs": [], "source": [ "src = RegressionSource(\"/data/regression_output\")" ] }, { "cell_type": "code", "execution_count": null, "id": "4ee68ab3", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 18, "id": "9738e1d4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'prompt_embeds': tensor([[[ 3.2188, 3.4375, 3.1719, ..., 0.3535, 1.7812, 2.0312],\n", " [ 3.0938, 1.9297, 0.7031, ..., 2.0625, -0.2314, 1.2266],\n", " [ 2.6250, 1.7031, 3.5625, ..., 0.8828, 2.1719, 1.4766],\n", " ...,\n", " [ 4.7812, 0.1689, 4.4688, ..., 5.0000, -1.8359, -0.7500],\n", " [-0.0654, 2.1406, -1.4922, ..., 0.7930, 3.9844, 1.6406],\n", " [-2.7031, 1.5547, 2.6094, ..., -0.0481, 0.1582, 0.7383]]],\n", " dtype=torch.bfloat16),\n", " 'prompt_embeds_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n", " 'noise': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n", " [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n", " [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n", " ...,\n", " [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n", " [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n", " [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n", " dtype=torch.bfloat16),\n", " 'image_latents': tensor([[[ 0.1719, 0.0194, 0.0084, ..., -0.1494, 0.0552, 0.2295],\n", " [ 0.1777, 0.1406, 0.1592, ..., 0.1260, -0.2412, -0.0041],\n", " [ 0.1187, 0.2324, 0.1104, ..., 0.0801, 0.3516, 0.4414],\n", " ...,\n", " [-0.0972, -0.3242, -0.3027, ..., 0.3672, 0.1699, 0.4004],\n", " [-0.1221, -0.0125, -0.3867, ..., 0.7031, 0.8477, 0.8320],\n", " [-0.1416, -0.1914, -0.3359, ..., 0.9883, 1.3359, 0.7422]]],\n", " dtype=torch.bfloat16),\n", " 'vae_image_sizes': [(448, 576)],\n", " 'img_shapes': [[(1, 36, 28), (1, 36, 28)]],\n", " 'txt_seq_lens': [228],\n", " 't_1': tensor([0.9883], dtype=torch.bfloat16),\n", " 'latents_1_start': tensor([[[ 1.9531, -0.7930, 0.6289, ..., -1.7188, 1.0312, 0.3770],\n", " [ 1.6406, 0.1143, -0.1533, ..., 0.7461, -0.1748, -1.6719],\n", " [ 0.3945, 0.3887, 0.1895, ..., -0.9141, -1.3125, -2.2969],\n", " ...,\n", " [-1.4609, 0.2471, 1.3203, ..., 0.1826, -0.6133, 0.7578],\n", " [ 1.0234, 1.3516, -0.1582, ..., 0.1226, 1.4609, -0.2715],\n", " [ 0.4863, -1.1016, 2.3125, ..., 0.8281, -0.8086, 0.4297]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_1': tensor([[[ 1.9062, -0.9102, 0.5742, ..., -1.7422, 1.0625, 0.3359],\n", " [ 1.5859, 0.0306, -0.2637, ..., 0.7539, -0.1768, -1.7969],\n", " [ 0.3184, 0.3066, 0.1592, ..., -1.0391, -1.5391, -2.5625],\n", " ...,\n", " [-1.2734, 0.4941, 1.5781, ..., -0.2344, -1.0156, 0.3477],\n", " [ 1.2422, 1.5234, 0.0510, ..., -0.5820, 0.9219, -1.0859],\n", " [ 0.6172, -0.9336, 2.5781, ..., -0.0801, -1.7734, -0.3730]]],\n", " dtype=torch.bfloat16),\n", " 't_2': tensor([0.9766], dtype=torch.bfloat16),\n", " 'latents_2_start': tensor([[[ 1.9297, -0.7812, 0.6211, ..., -1.6953, 1.0156, 0.3730],\n", " [ 1.6250, 0.1138, -0.1504, ..., 0.7383, -0.1729, -1.6484],\n", " [ 0.3906, 0.3848, 0.1875, ..., -0.9023, -1.2969, -2.2656],\n", " ...,\n", " [-1.4453, 0.2412, 1.3047, ..., 0.1855, -0.6016, 0.7539],\n", " [ 1.0078, 1.3359, -0.1592, ..., 0.1299, 1.4531, -0.2578],\n", " [ 0.4785, -1.0938, 2.2812, ..., 0.8281, -0.7891, 0.4336]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_2': tensor([[[ 1.8984, -0.9219, 0.5664, ..., -1.7188, 1.0703, 0.3633],\n", " [ 1.5859, 0.0256, -0.2539, ..., 0.7578, -0.1719, -1.7656],\n", " [ 0.3105, 0.3027, 0.1611, ..., -1.0000, -1.4688, -2.4688],\n", " ...,\n", " [-1.2969, 0.4453, 1.5625, ..., -0.1934, -0.9883, 0.4082],\n", " [ 1.2188, 1.4844, -0.0028, ..., -0.4492, 1.0312, -0.9180],\n", " [ 0.5820, -1.0156, 2.5156, ..., 0.1885, -1.5391, -0.1602]]],\n", " dtype=torch.bfloat16),\n", " 't_3': tensor([0.9648], dtype=torch.bfloat16),\n", " 'latents_3_start': tensor([[[ 1.9062, -0.7695, 0.6133, ..., -1.6719, 1.0000, 0.3691],\n", " [ 1.6094, 0.1133, -0.1475, ..., 0.7305, -0.1709, -1.6250],\n", " [ 0.3867, 0.3809, 0.1855, ..., -0.8906, -1.2812, -2.2344],\n", " ...,\n", " [-1.4297, 0.2354, 1.2891, ..., 0.1875, -0.5898, 0.7500],\n", " [ 0.9922, 1.3203, -0.1592, ..., 0.1357, 1.4375, -0.2461],\n", " [ 0.4707, -1.0781, 2.2500, ..., 0.8242, -0.7695, 0.4355]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_3': tensor([[[ 1.8984, -0.9180, 0.5430, ..., -1.7031, 1.0938, 0.3691],\n", " [ 1.5703, 0.0308, -0.2676, ..., 0.7812, -0.1602, -1.7422],\n", " [ 0.3164, 0.2949, 0.1514, ..., -0.9922, -1.4609, -2.4531],\n", " ...,\n", " [-1.3203, 0.4277, 1.5234, ..., -0.1611, -0.9688, 0.4434],\n", " [ 1.1875, 1.4609, -0.0179, ..., -0.4355, 1.0312, -0.8867],\n", " [ 0.5547, -1.0234, 2.4844, ..., 0.2344, -1.4844, -0.1025]]],\n", " dtype=torch.bfloat16),\n", " 't_4': tensor([0.9531], dtype=torch.bfloat16),\n", " 'latents_4_start': tensor([[[ 1.8828, -0.7578, 0.6055, ..., -1.6484, 0.9844, 0.3652],\n", " [ 1.5859, 0.1128, -0.1445, ..., 0.7188, -0.1689, -1.6016],\n", " [ 0.3828, 0.3770, 0.1836, ..., -0.8789, -1.2656, -2.2031],\n", " ...,\n", " [-1.4141, 0.2305, 1.2734, ..., 0.1895, -0.5781, 0.7461],\n", " [ 0.9766, 1.3047, -0.1592, ..., 0.1416, 1.4219, -0.2354],\n", " [ 0.4629, -1.0625, 2.2188, ..., 0.8203, -0.7500, 0.4375]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_4': tensor([[[ 1.8984, -0.9141, 0.5508, ..., -1.7109, 1.0859, 0.3672],\n", " [ 1.5625, 0.0238, -0.2754, ..., 0.7656, -0.1768, -1.7578],\n", " [ 0.3105, 0.2988, 0.1602, ..., -1.0156, -1.4766, -2.4375],\n", " ...,\n", " [-1.3125, 0.4316, 1.5469, ..., -0.1621, -0.9805, 0.4141],\n", " [ 1.1953, 1.4844, -0.0118, ..., -0.4590, 1.0078, -0.9492],\n", " [ 0.5703, -1.0156, 2.5156, ..., 0.1777, -1.5469, -0.1475]]],\n", " dtype=torch.bfloat16),\n", " 't_5': tensor([0.9414], dtype=torch.bfloat16),\n", " 'latents_5_start': tensor([[[ 1.8594, -0.7461, 0.5977, ..., -1.6250, 0.9688, 0.3613],\n", " [ 1.5625, 0.1123, -0.1406, ..., 0.7109, -0.1670, -1.5781],\n", " [ 0.3789, 0.3730, 0.1816, ..., -0.8672, -1.2500, -2.1719],\n", " ...,\n", " [-1.3984, 0.2246, 1.2500, ..., 0.1914, -0.5664, 0.7422],\n", " [ 0.9609, 1.2891, -0.1592, ..., 0.1475, 1.4062, -0.2236],\n", " [ 0.4551, -1.0469, 2.1875, ..., 0.8164, -0.7305, 0.4395]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_5': tensor([[[ 1.8906, -0.8984, 0.5586, ..., -1.7031, 1.0391, 0.3516],\n", " [ 1.5625, 0.0388, -0.2871, ..., 0.8008, -0.1504, -1.7734],\n", " [ 0.3008, 0.2949, 0.1777, ..., -1.0312, -1.5781, -2.5469],\n", " ...,\n", " [-1.3047, 0.4688, 1.5938, ..., -0.2188, -1.0781, 0.3945],\n", " [ 1.2031, 1.4844, 0.0082, ..., -0.5469, 0.9414, -1.1875],\n", " [ 0.5781, -0.9336, 2.5625, ..., -0.0903, -1.8047, -0.3828]]],\n", " dtype=torch.bfloat16),\n", " 't_6': tensor([0.9258], dtype=torch.bfloat16),\n", " 'latents_6_start': tensor([[[ 1.8359, -0.7344, 0.5898, ..., -1.6016, 0.9570, 0.3574],\n", " [ 1.5391, 0.1118, -0.1367, ..., 0.6992, -0.1650, -1.5547],\n", " [ 0.3750, 0.3691, 0.1797, ..., -0.8555, -1.2266, -2.1406],\n", " ...,\n", " [-1.3828, 0.2188, 1.2266, ..., 0.1943, -0.5508, 0.7383],\n", " [ 0.9453, 1.2734, -0.1592, ..., 0.1543, 1.3906, -0.2080],\n", " [ 0.4473, -1.0312, 2.1562, ..., 0.8164, -0.7070, 0.4453]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_6': tensor([[[ 1.9219, -0.8828, 0.5820, ..., -1.7109, 1.0781, 0.3613],\n", " [ 1.5703, 0.0359, -0.2812, ..., 0.7773, -0.1865, -1.8203],\n", " [ 0.3301, 0.2949, 0.1924, ..., -1.0781, -1.6016, -2.5312],\n", " ...,\n", " [-1.2734, 0.4941, 1.6094, ..., -0.2236, -1.0391, 0.3633],\n", " [ 1.2656, 1.5469, 0.0796, ..., -0.6797, 0.8672, -1.2656],\n", " [ 0.6484, -0.9102, 2.5938, ..., -0.1904, -1.8516, -0.4590]]],\n", " dtype=torch.bfloat16),\n", " 't_7': tensor([0.9102], dtype=torch.bfloat16),\n", " 'latents_7_start': tensor([[[ 1.8125, -0.7227, 0.5820, ..., -1.5781, 0.9414, 0.3535],\n", " [ 1.5156, 0.1113, -0.1328, ..., 0.6875, -0.1621, -1.5312],\n", " [ 0.3711, 0.3652, 0.1768, ..., -0.8398, -1.2031, -2.1094],\n", " ...,\n", " [-1.3672, 0.2119, 1.2031, ..., 0.1973, -0.5352, 0.7344],\n", " [ 0.9297, 1.2500, -0.1602, ..., 0.1631, 1.3828, -0.1914],\n", " [ 0.4395, -1.0156, 2.1250, ..., 0.8203, -0.6836, 0.4512]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_7': tensor([[[ 1.9531, -0.8906, 0.5938, ..., -1.7266, 1.0938, 0.4180],\n", " [ 1.5781, 0.0309, -0.3008, ..., 0.7969, -0.1699, -1.8281],\n", " [ 0.3262, 0.3008, 0.2314, ..., -1.0781, -1.6797, -2.6250],\n", " ...,\n", " [-1.2812, 0.5039, 1.5938, ..., -0.2314, -1.0547, 0.3828],\n", " [ 1.2734, 1.5781, 0.0859, ..., -0.7930, 0.8555, -1.3516],\n", " [ 0.6914, -0.9062, 2.6250, ..., -0.2598, -1.8516, -0.4902]]],\n", " dtype=torch.bfloat16),\n", " 't_8': tensor([0.8984], dtype=torch.bfloat16),\n", " 'latents_8_start': tensor([[[ 1.7891, -0.7109, 0.5742, ..., -1.5547, 0.9258, 0.3477],\n", " [ 1.4922, 0.1108, -0.1289, ..., 0.6758, -0.1602, -1.5078],\n", " [ 0.3672, 0.3613, 0.1738, ..., -0.8242, -1.1797, -2.0781],\n", " ...,\n", " [-1.3516, 0.2051, 1.1797, ..., 0.2002, -0.5195, 0.7305],\n", " [ 0.9141, 1.2266, -0.1611, ..., 0.1738, 1.3750, -0.1729],\n", " [ 0.4297, -1.0000, 2.0938, ..., 0.8242, -0.6602, 0.4570]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_8': tensor([[[ 1.9453, -0.8789, 0.6094, ..., -1.7266, 1.0781, 0.4082],\n", " [ 1.5703, 0.0396, -0.3047, ..., 0.7891, -0.1826, -1.8516],\n", " [ 0.3164, 0.2949, 0.2500, ..., -1.0859, -1.7031, -2.6250],\n", " ...,\n", " [-1.2578, 0.5234, 1.5938, ..., -0.2246, -1.0547, 0.3770],\n", " [ 1.2734, 1.6016, 0.0884, ..., -0.8828, 0.8086, -1.3672],\n", " [ 0.7070, -0.8828, 2.6094, ..., -0.2832, -1.8750, -0.5117]]],\n", " dtype=torch.bfloat16),\n", " 't_9': tensor([0.8828], dtype=torch.bfloat16),\n", " 'latents_9_start': tensor([[[ 1.7656, -0.6992, 0.5664, ..., -1.5312, 0.9102, 0.3418],\n", " [ 1.4688, 0.1104, -0.1245, ..., 0.6641, -0.1572, -1.4844],\n", " [ 0.3633, 0.3574, 0.1699, ..., -0.8086, -1.1562, -2.0469],\n", " ...,\n", " [-1.3359, 0.1982, 1.1562, ..., 0.2031, -0.5039, 0.7266],\n", " [ 0.8984, 1.2031, -0.1621, ..., 0.1855, 1.3672, -0.1543],\n", " [ 0.4199, -0.9883, 2.0625, ..., 0.8281, -0.6328, 0.4648]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_9': tensor([[[ 1.9531, -0.8828, 0.6172, ..., -1.7188, 1.0391, 0.4141],\n", " [ 1.5703, 0.0583, -0.3125, ..., 0.7930, -0.1582, -1.8594],\n", " [ 0.3203, 0.2910, 0.2598, ..., -1.1016, -1.7500, -2.6562],\n", " ...,\n", " [-1.2422, 0.5273, 1.6094, ..., -0.2168, -1.0391, 0.4121],\n", " [ 1.2656, 1.6172, 0.1001, ..., -0.8984, 0.8008, -1.4062],\n", " [ 0.7383, -0.8750, 2.6250, ..., -0.2891, -1.8672, -0.5312]]],\n", " dtype=torch.bfloat16),\n", " 't_10': tensor([0.8711], dtype=torch.bfloat16),\n", " 'latents_10_start': tensor([[[ 1.7344, -0.6875, 0.5586, ..., -1.5078, 0.8945, 0.3359],\n", " [ 1.4453, 0.1094, -0.1201, ..., 0.6523, -0.1553, -1.4609],\n", " [ 0.3594, 0.3535, 0.1660, ..., -0.7930, -1.1328, -2.0156],\n", " ...,\n", " [-1.3203, 0.1904, 1.1328, ..., 0.2061, -0.4902, 0.7227],\n", " [ 0.8789, 1.1797, -0.1631, ..., 0.1982, 1.3594, -0.1348],\n", " [ 0.4102, -0.9766, 2.0312, ..., 0.8320, -0.6055, 0.4727]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_10': tensor([[[ 1.9609, -0.8672, 0.6445, ..., -1.7188, 1.0156, 0.4180],\n", " [ 1.5781, 0.0728, -0.3125, ..., 0.7812, -0.1602, -1.8828],\n", " [ 0.3125, 0.2832, 0.2832, ..., -1.1172, -1.7734, -2.6875],\n", " ...,\n", " [-1.2500, 0.5273, 1.6016, ..., -0.2227, -1.0625, 0.4062],\n", " [ 1.2500, 1.6094, 0.1016, ..., -0.9297, 0.8086, -1.4219],\n", " [ 0.7617, -0.8555, 2.6406, ..., -0.3105, -1.8594, -0.5352]]],\n", " dtype=torch.bfloat16),\n", " 't_11': tensor([0.8555], dtype=torch.bfloat16),\n", " 'latents_11_start': tensor([[[ 1.7031, -0.6758, 0.5508, ..., -1.4844, 0.8789, 0.3301],\n", " [ 1.4219, 0.1084, -0.1157, ..., 0.6406, -0.1533, -1.4375],\n", " [ 0.3555, 0.3496, 0.1621, ..., -0.7773, -1.1094, -1.9766],\n", " ...,\n", " [-1.3047, 0.1826, 1.1094, ..., 0.2090, -0.4746, 0.7188],\n", " [ 0.8594, 1.1562, -0.1641, ..., 0.2119, 1.3516, -0.1143],\n", " [ 0.3984, -0.9648, 1.9922, ..., 0.8359, -0.5781, 0.4805]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_11': tensor([[[ 1.9688, -0.8516, 0.6484, ..., -1.7266, 1.0000, 0.4082],\n", " [ 1.5938, 0.0679, -0.3105, ..., 0.8086, -0.1455, -1.8984],\n", " [ 0.3203, 0.2812, 0.2949, ..., -1.1094, -1.7812, -2.6719],\n", " ...,\n", " [-1.2500, 0.5273, 1.5938, ..., -0.2119, -1.0625, 0.4102],\n", " [ 1.2656, 1.6016, 0.1011, ..., -0.9180, 0.8281, -1.4531],\n", " [ 0.7695, -0.8320, 2.6562, ..., -0.2891, -1.8516, -0.5234]]],\n", " dtype=torch.bfloat16),\n", " 't_12': tensor([0.8438], dtype=torch.bfloat16),\n", " 'latents_12_start': tensor([[[ 1.6719, -0.6641, 0.5430, ..., -1.4609, 0.8633, 0.3242],\n", " [ 1.3984, 0.1074, -0.1113, ..., 0.6289, -0.1514, -1.4062],\n", " [ 0.3516, 0.3457, 0.1582, ..., -0.7617, -1.0859, -1.9375],\n", " ...,\n", " [-1.2891, 0.1748, 1.0859, ..., 0.2119, -0.4590, 0.7109],\n", " [ 0.8398, 1.1328, -0.1660, ..., 0.2256, 1.3359, -0.0933],\n", " [ 0.3867, -0.9531, 1.9531, ..., 0.8398, -0.5508, 0.4883]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_12': tensor([[[ 1.9688, -0.8477, 0.6602, ..., -1.7422, 0.9805, 0.3965],\n", " [ 1.5938, 0.0845, -0.3066, ..., 0.7891, -0.1816, -1.9062],\n", " [ 0.3105, 0.2754, 0.3242, ..., -1.1328, -1.8047, -2.6875],\n", " ...,\n", " [-1.2422, 0.5195, 1.5938, ..., -0.2227, -1.0625, 0.4180],\n", " [ 1.2500, 1.6172, 0.1138, ..., -0.9492, 0.8281, -1.4609],\n", " [ 0.7852, -0.8555, 2.6562, ..., -0.3047, -1.8438, -0.5430]]],\n", " dtype=torch.bfloat16),\n", " 't_13': tensor([0.8281], dtype=torch.bfloat16),\n", " 'latents_13_start': tensor([[[ 1.6406, -0.6523, 0.5312, ..., -1.4375, 0.8477, 0.3184],\n", " [ 1.3750, 0.1060, -0.1069, ..., 0.6172, -0.1484, -1.3750],\n", " [ 0.3477, 0.3418, 0.1533, ..., -0.7461, -1.0625, -1.8984],\n", " ...,\n", " [-1.2734, 0.1670, 1.0625, ..., 0.2148, -0.4434, 0.7031],\n", " [ 0.8203, 1.1094, -0.1680, ..., 0.2393, 1.3203, -0.0718],\n", " [ 0.3750, -0.9414, 1.9141, ..., 0.8438, -0.5234, 0.4961]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_13': tensor([[[ 1.9688, -0.8438, 0.6562, ..., -1.7500, 0.9805, 0.3906],\n", " [ 1.5938, 0.0791, -0.3066, ..., 0.7734, -0.1934, -1.9062],\n", " [ 0.3145, 0.2734, 0.3203, ..., -1.1250, -1.8359, -2.7188],\n", " ...,\n", " [-1.2422, 0.5156, 1.6094, ..., -0.2021, -1.0312, 0.4609],\n", " [ 1.2656, 1.6250, 0.1108, ..., -0.9453, 0.8789, -1.4688],\n", " [ 0.7930, -0.8594, 2.6719, ..., -0.3105, -1.8359, -0.5469]]],\n", " dtype=torch.bfloat16),\n", " 't_14': tensor([0.8125], dtype=torch.bfloat16),\n", " 'latents_14_start': tensor([[[ 1.6094, -0.6406, 0.5195, ..., -1.4141, 0.8320, 0.3125],\n", " [ 1.3516, 0.1050, -0.1025, ..., 0.6055, -0.1455, -1.3438],\n", " [ 0.3438, 0.3379, 0.1484, ..., -0.7305, -1.0312, -1.8594],\n", " ...,\n", " [-1.2578, 0.1592, 1.0391, ..., 0.2178, -0.4277, 0.6953],\n", " [ 0.8008, 1.0859, -0.1699, ..., 0.2539, 1.3047, -0.0498],\n", " [ 0.3633, -0.9297, 1.8750, ..., 0.8477, -0.4961, 0.5039]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_14': tensor([[[ 1.9609, -0.8438, 0.6562, ..., -1.7500, 0.9727, 0.3828],\n", " [ 1.5938, 0.0840, -0.3203, ..., 0.7695, -0.2061, -1.8906],\n", " [ 0.3125, 0.2754, 0.3262, ..., -1.1328, -1.8359, -2.7031],\n", " ...,\n", " [-1.2422, 0.5117, 1.6094, ..., -0.2002, -1.0156, 0.4746],\n", " [ 1.2656, 1.6172, 0.1108, ..., -0.9219, 0.8945, -1.4609],\n", " [ 0.7969, -0.8672, 2.6406, ..., -0.3047, -1.7969, -0.5430]]],\n", " dtype=torch.bfloat16),\n", " 't_15': tensor([0.7969], dtype=torch.bfloat16),\n", " 'latents_15_start': tensor([[[ 1.5781, -0.6289, 0.5078, ..., -1.3906, 0.8164, 0.3066],\n", " [ 1.3281, 0.1035, -0.0977, ..., 0.5938, -0.1426, -1.3125],\n", " [ 0.3398, 0.3340, 0.1436, ..., -0.7148, -1.0000, -1.8203],\n", " ...,\n", " [-1.2422, 0.1514, 1.0156, ..., 0.2207, -0.4121, 0.6875],\n", " [ 0.7812, 1.0625, -0.1719, ..., 0.2676, 1.2891, -0.0275],\n", " [ 0.3516, -0.9180, 1.8359, ..., 0.8516, -0.4688, 0.5117]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_15': tensor([[[ 1.9531, -0.8320, 0.6641, ..., -1.7578, 0.9570, 0.3789],\n", " [ 1.5938, 0.0806, -0.3184, ..., 0.7500, -0.2031, -1.8750],\n", " [ 0.3242, 0.2656, 0.3301, ..., -1.1406, -1.8359, -2.7031],\n", " ...,\n", " [-1.2578, 0.5195, 1.6328, ..., -0.1875, -0.9883, 0.5117],\n", " [ 1.2734, 1.6094, 0.1230, ..., -0.9297, 0.9102, -1.4531],\n", " [ 0.7930, -0.8633, 2.6406, ..., -0.3027, -1.8203, -0.5312]]],\n", " dtype=torch.bfloat16),\n", " 't_16': tensor([0.7812], dtype=torch.bfloat16),\n", " 'latents_16_start': tensor([[[ 1.5469, -0.6172, 0.4980, ..., -1.3594, 0.8008, 0.3008],\n", " [ 1.3047, 0.1021, -0.0928, ..., 0.5820, -0.1396, -1.2812],\n", " [ 0.3340, 0.3301, 0.1387, ..., -0.6953, -0.9727, -1.7812],\n", " ...,\n", " [-1.2188, 0.1436, 0.9883, ..., 0.2236, -0.3965, 0.6797],\n", " [ 0.7617, 1.0391, -0.1738, ..., 0.2812, 1.2734, -0.0048],\n", " [ 0.3398, -0.9062, 1.7969, ..., 0.8555, -0.4395, 0.5195]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_16': tensor([[[ 1.9766, -0.8320, 0.6758, ..., -1.7734, 0.9844, 0.3867],\n", " [ 1.6094, 0.0923, -0.3164, ..., 0.7617, -0.2148, -1.8984],\n", " [ 0.3281, 0.2695, 0.3398, ..., -1.1484, -1.8516, -2.7500],\n", " ...,\n", " [-1.2500, 0.5156, 1.6406, ..., -0.1953, -0.9648, 0.5273],\n", " [ 1.3125, 1.6250, 0.1113, ..., -0.9102, 0.9414, -1.4609],\n", " [ 0.7969, -0.8750, 2.6719, ..., -0.2988, -1.7891, -0.5469]]],\n", " dtype=torch.bfloat16),\n", " 't_17': tensor([0.7656], dtype=torch.bfloat16),\n", " 'latents_17_start': tensor([[[ 1.5156, -0.6055, 0.4883, ..., -1.3281, 0.7852, 0.2949],\n", " [ 1.2812, 0.1006, -0.0879, ..., 0.5703, -0.1367, -1.2500],\n", " [ 0.3281, 0.3262, 0.1328, ..., -0.6758, -0.9414, -1.7344],\n", " ...,\n", " [-1.1953, 0.1357, 0.9609, ..., 0.2266, -0.3809, 0.6719],\n", " [ 0.7422, 1.0156, -0.1758, ..., 0.2949, 1.2578, 0.0184],\n", " [ 0.3281, -0.8906, 1.7578, ..., 0.8594, -0.4102, 0.5273]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_17': tensor([[[ 1.9688, -0.8242, 0.6719, ..., -1.7578, 0.9688, 0.3691],\n", " [ 1.6094, 0.0869, -0.3145, ..., 0.7500, -0.2217, -1.8828],\n", " [ 0.3203, 0.2754, 0.3457, ..., -1.1406, -1.8516, -2.7500],\n", " ...,\n", " [-1.2266, 0.5156, 1.6250, ..., -0.1904, -0.9492, 0.5273],\n", " [ 1.3047, 1.6172, 0.1040, ..., -0.9141, 0.9570, -1.4531],\n", " [ 0.7852, -0.8633, 2.6562, ..., -0.2949, -1.7969, -0.5430]]],\n", " dtype=torch.bfloat16),\n", " 't_18': tensor([0.7461], dtype=torch.bfloat16),\n", " 'latents_18_start': tensor([[[ 1.4844, -0.5938, 0.4766, ..., -1.2969, 0.7695, 0.2891],\n", " [ 1.2578, 0.0991, -0.0830, ..., 0.5586, -0.1328, -1.2188],\n", " [ 0.3223, 0.3223, 0.1270, ..., -0.6562, -0.9102, -1.6875],\n", " ...,\n", " [-1.1719, 0.1270, 0.9336, ..., 0.2295, -0.3652, 0.6641],\n", " [ 0.7227, 0.9883, -0.1777, ..., 0.3105, 1.2422, 0.0420],\n", " [ 0.3145, -0.8750, 1.7109, ..., 0.8633, -0.3809, 0.5352]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_18': tensor([[[ 1.9844, -0.8398, 0.6680, ..., -1.7578, 0.9727, 0.3730],\n", " [ 1.6172, 0.0752, -0.3184, ..., 0.7578, -0.2148, -1.8828],\n", " [ 0.3066, 0.2715, 0.3398, ..., -1.1328, -1.8516, -2.7500],\n", " ...,\n", " [-1.2422, 0.5156, 1.6484, ..., -0.1777, -0.9336, 0.5625],\n", " [ 1.3047, 1.6328, 0.1147, ..., -0.8906, 0.9883, -1.4688],\n", " [ 0.7734, -0.8672, 2.6406, ..., -0.2910, -1.7891, -0.5312]]],\n", " dtype=torch.bfloat16),\n", " 't_19': tensor([0.7305], dtype=torch.bfloat16),\n", " 'latents_19_start': tensor([[[ 1.4531, -0.5781, 0.4648, ..., -1.2656, 0.7539, 0.2832],\n", " [ 1.2344, 0.0977, -0.0776, ..., 0.5469, -0.1289, -1.1875],\n", " [ 0.3164, 0.3184, 0.1211, ..., -0.6367, -0.8789, -1.6406],\n", " ...,\n", " [-1.1484, 0.1182, 0.9062, ..., 0.2324, -0.3496, 0.6562],\n", " [ 0.6992, 0.9609, -0.1797, ..., 0.3262, 1.2266, 0.0664],\n", " [ 0.3008, -0.8594, 1.6641, ..., 0.8672, -0.3516, 0.5430]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_19': tensor([[[ 1.9844, -0.8516, 0.6641, ..., -1.7656, 0.9688, 0.3770],\n", " [ 1.6094, 0.0684, -0.3281, ..., 0.7734, -0.2119, -1.8672],\n", " [ 0.3086, 0.2695, 0.3301, ..., -1.1484, -1.8438, -2.7188],\n", " ...,\n", " [-1.2422, 0.5117, 1.6484, ..., -0.1738, -0.8984, 0.5781],\n", " [ 1.3203, 1.6328, 0.1157, ..., -0.8828, 1.0078, -1.4844],\n", " [ 0.7617, -0.8672, 2.6719, ..., -0.2480, -1.8125, -0.5273]]],\n", " dtype=torch.bfloat16),\n", " 't_20': tensor([0.7148], dtype=torch.bfloat16),\n", " 'latents_20_start': tensor([[[ 1.4219, -0.5625, 0.4531, ..., -1.2344, 0.7383, 0.2773],\n", " [ 1.2109, 0.0967, -0.0723, ..., 0.5352, -0.1250, -1.1562],\n", " [ 0.3105, 0.3145, 0.1157, ..., -0.6172, -0.8477, -1.5938],\n", " ...,\n", " [-1.1250, 0.1094, 0.8789, ..., 0.2354, -0.3340, 0.6484],\n", " [ 0.6758, 0.9336, -0.1816, ..., 0.3418, 1.2109, 0.0913],\n", " [ 0.2871, -0.8438, 1.6172, ..., 0.8711, -0.3203, 0.5508]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_20': tensor([[[ 1.9766, -0.8438, 0.6562, ..., -1.7656, 0.9570, 0.3809],\n", " [ 1.6094, 0.0713, -0.3340, ..., 0.7734, -0.2246, -1.8594],\n", " [ 0.2910, 0.2598, 0.3281, ..., -1.1250, -1.8359, -2.7500],\n", " ...,\n", " [-1.2422, 0.5039, 1.6406, ..., -0.1738, -0.8867, 0.6016],\n", " [ 1.3125, 1.6172, 0.1187, ..., -0.8672, 1.0156, -1.4922],\n", " [ 0.7578, -0.8711, 2.6719, ..., -0.2559, -1.7891, -0.5547]]],\n", " dtype=torch.bfloat16),\n", " 't_21': tensor([0.6992], dtype=torch.bfloat16),\n", " 'latents_21_start': tensor([[[ 1.3906, -0.5469, 0.4414, ..., -1.2031, 0.7227, 0.2715],\n", " [ 1.1797, 0.0952, -0.0664, ..., 0.5234, -0.1211, -1.1250],\n", " [ 0.3047, 0.3105, 0.1099, ..., -0.5977, -0.8164, -1.5469],\n", " ...,\n", " [-1.1016, 0.1006, 0.8516, ..., 0.2383, -0.3184, 0.6367],\n", " [ 0.6523, 0.9062, -0.1836, ..., 0.3574, 1.1953, 0.1172],\n", " [ 0.2734, -0.8281, 1.5703, ..., 0.8750, -0.2891, 0.5586]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_21': tensor([[[ 1.9844, -0.8281, 0.6680, ..., -1.7578, 0.9375, 0.3457],\n", " [ 1.5938, 0.0811, -0.3203, ..., 0.7656, -0.2207, -1.8594],\n", " [ 0.2773, 0.2559, 0.3340, ..., -1.1328, -1.8438, -2.7344],\n", " ...,\n", " [-1.2266, 0.4961, 1.6406, ..., -0.1953, -0.8750, 0.5859],\n", " [ 1.2969, 1.6250, 0.1147, ..., -0.8594, 1.0156, -1.4922],\n", " [ 0.7578, -0.8711, 2.6719, ..., -0.2617, -1.7891, -0.5508]]],\n", " dtype=torch.bfloat16),\n", " 't_22': tensor([0.6797], dtype=torch.bfloat16),\n", " 'latents_22_start': tensor([[[ 1.3594, -0.5312, 0.4297, ..., -1.1719, 0.7070, 0.2656],\n", " [ 1.1484, 0.0938, -0.0608, ..., 0.5117, -0.1172, -1.0938],\n", " [ 0.3008, 0.3066, 0.1040, ..., -0.5781, -0.7852, -1.5000],\n", " ...,\n", " [-1.0781, 0.0918, 0.8242, ..., 0.2422, -0.3027, 0.6250],\n", " [ 0.6289, 0.8789, -0.1855, ..., 0.3730, 1.1797, 0.1436],\n", " [ 0.2598, -0.8125, 1.5234, ..., 0.8789, -0.2578, 0.5664]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_22': tensor([[[ 1.9922, -0.8242, 0.6523, ..., -1.7500, 0.9375, 0.3477],\n", " [ 1.5859, 0.0757, -0.3379, ..., 0.7578, -0.2178, -1.8438],\n", " [ 0.2930, 0.2520, 0.3320, ..., -1.1250, -1.8516, -2.7500],\n", " ...,\n", " [-1.2031, 0.5000, 1.6406, ..., -0.2012, -0.8750, 0.5820],\n", " [ 1.3047, 1.6094, 0.1309, ..., -0.8555, 1.0234, -1.5078],\n", " [ 0.7617, -0.8711, 2.6562, ..., -0.2793, -1.7969, -0.5742]]],\n", " dtype=torch.bfloat16),\n", " 't_23': tensor([0.6641], dtype=torch.bfloat16),\n", " 'latents_23_start': tensor([[[ 1.3203, -0.5156, 0.4180, ..., -1.1406, 0.6914, 0.2598],\n", " [ 1.1172, 0.0923, -0.0547, ..., 0.4980, -0.1133, -1.0625],\n", " [ 0.2949, 0.3027, 0.0981, ..., -0.5586, -0.7500, -1.4531],\n", " ...,\n", " [-1.0547, 0.0830, 0.7930, ..., 0.2461, -0.2871, 0.6133],\n", " [ 0.6055, 0.8516, -0.1875, ..., 0.3887, 1.1641, 0.1709],\n", " [ 0.2461, -0.7969, 1.4766, ..., 0.8828, -0.2256, 0.5781]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_23': tensor([[[ 1.9688, -0.8203, 0.6562, ..., -1.7422, 0.9336, 0.3359],\n", " [ 1.5781, 0.0923, -0.3164, ..., 0.7617, -0.2188, -1.8438],\n", " [ 0.2969, 0.2617, 0.3203, ..., -1.1328, -1.8594, -2.7500],\n", " ...,\n", " [-1.2109, 0.5117, 1.6406, ..., -0.1914, -0.8711, 0.5938],\n", " [ 1.2891, 1.6094, 0.1182, ..., -0.8477, 1.0469, -1.5000],\n", " [ 0.7461, -0.8945, 2.6562, ..., -0.2852, -1.8047, -0.5586]]],\n", " dtype=torch.bfloat16),\n", " 't_24': tensor([0.6445], dtype=torch.bfloat16),\n", " 'latents_24_start': tensor([[[ 1.2812, -0.5000, 0.4062, ..., -1.1094, 0.6758, 0.2539],\n", " [ 1.0859, 0.0908, -0.0488, ..., 0.4844, -0.1094, -1.0312],\n", " [ 0.2891, 0.2988, 0.0923, ..., -0.5391, -0.7148, -1.4062],\n", " ...,\n", " [-1.0312, 0.0737, 0.7617, ..., 0.2500, -0.2715, 0.6016],\n", " [ 0.5820, 0.8203, -0.1895, ..., 0.4043, 1.1484, 0.1982],\n", " [ 0.2324, -0.7812, 1.4297, ..., 0.8867, -0.1924, 0.5898]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_24': tensor([[[ 1.9688, -0.8164, 0.6484, ..., -1.7422, 0.9492, 0.3574],\n", " [ 1.5703, 0.0918, -0.3262, ..., 0.7734, -0.2207, -1.8438],\n", " [ 0.2871, 0.2637, 0.3340, ..., -1.1250, -1.8516, -2.7656],\n", " ...,\n", " [-1.2031, 0.4961, 1.6328, ..., -0.1768, -0.8555, 0.6055],\n", " [ 1.2891, 1.6172, 0.1211, ..., -0.8516, 1.0625, -1.5000],\n", " [ 0.7461, -0.8867, 2.6562, ..., -0.2891, -1.8047, -0.5391]]],\n", " dtype=torch.bfloat16),\n", " 't_25': tensor([0.6289], dtype=torch.bfloat16),\n", " 'latents_25_start': tensor([[[ 1.2422, -0.4844, 0.3945, ..., -1.0781, 0.6562, 0.2471],\n", " [ 1.0547, 0.0889, -0.0427, ..., 0.4707, -0.1055, -0.9961],\n", " [ 0.2832, 0.2930, 0.0859, ..., -0.5195, -0.6797, -1.3516],\n", " ...,\n", " [-1.0078, 0.0645, 0.7305, ..., 0.2539, -0.2559, 0.5898],\n", " [ 0.5586, 0.7891, -0.1914, ..., 0.4199, 1.1250, 0.2266],\n", " [ 0.2188, -0.7656, 1.3828, ..., 0.8906, -0.1582, 0.6016]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_25': tensor([[[ 1.9609, -0.8242, 0.6523, ..., -1.7422, 0.9258, 0.3418],\n", " [ 1.5625, 0.0850, -0.3359, ..., 0.7812, -0.2295, -1.8516],\n", " [ 0.2871, 0.2520, 0.3184, ..., -1.1250, -1.8359, -2.7500],\n", " ...,\n", " [-1.1875, 0.4863, 1.6250, ..., -0.1924, -0.8633, 0.6055],\n", " [ 1.2969, 1.6172, 0.1240, ..., -0.8359, 1.0625, -1.5078],\n", " [ 0.7305, -0.8789, 2.6562, ..., -0.2969, -1.8203, -0.5469]]],\n", " dtype=torch.bfloat16),\n", " 't_26': tensor([0.6094], dtype=torch.bfloat16),\n", " 'latents_26_start': tensor([[[ 1.2031, -0.4688, 0.3828, ..., -1.0469, 0.6406, 0.2402],\n", " [ 1.0234, 0.0874, -0.0364, ..., 0.4551, -0.1011, -0.9609],\n", " [ 0.2773, 0.2891, 0.0801, ..., -0.4980, -0.6445, -1.2969],\n", " ...,\n", " [-0.9844, 0.0552, 0.6992, ..., 0.2578, -0.2393, 0.5781],\n", " [ 0.5352, 0.7578, -0.1934, ..., 0.4355, 1.1016, 0.2559],\n", " [ 0.2051, -0.7500, 1.3359, ..., 0.8945, -0.1235, 0.6133]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_26': tensor([[[ 1.9609, -0.8320, 0.6602, ..., -1.7578, 0.9219, 0.3496],\n", " [ 1.5703, 0.0801, -0.3359, ..., 0.7812, -0.2266, -1.8281],\n", " [ 0.2793, 0.2471, 0.3223, ..., -1.1172, -1.8281, -2.7188],\n", " ...,\n", " [-1.1797, 0.4863, 1.6172, ..., -0.2021, -0.8516, 0.6133],\n", " [ 1.2969, 1.6016, 0.1226, ..., -0.8359, 1.0625, -1.5000],\n", " [ 0.7305, -0.8906, 2.6562, ..., -0.2988, -1.8047, -0.5469]]],\n", " dtype=torch.bfloat16),\n", " 't_27': tensor([0.5898], dtype=torch.bfloat16),\n", " 'latents_27_start': tensor([[[ 1.1641, -0.4531, 0.3691, ..., -1.0156, 0.6211, 0.2334],\n", " [ 0.9922, 0.0859, -0.0298, ..., 0.4395, -0.0967, -0.9258],\n", " [ 0.2715, 0.2852, 0.0737, ..., -0.4766, -0.6094, -1.2422],\n", " ...,\n", " [-0.9609, 0.0457, 0.6680, ..., 0.2617, -0.2227, 0.5664],\n", " [ 0.5078, 0.7266, -0.1953, ..., 0.4512, 1.0781, 0.2852],\n", " [ 0.1904, -0.7344, 1.2812, ..., 0.8984, -0.0884, 0.6250]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_27': tensor([[[ 1.9453, -0.8398, 0.6562, ..., -1.7578, 0.9141, 0.3398],\n", " [ 1.5469, 0.0654, -0.3379, ..., 0.7734, -0.2236, -1.8359],\n", " [ 0.2773, 0.2354, 0.3223, ..., -1.1094, -1.8359, -2.7188],\n", " ...,\n", " [-1.1797, 0.4844, 1.6328, ..., -0.1992, -0.8359, 0.6172],\n", " [ 1.2891, 1.5859, 0.1133, ..., -0.8242, 1.0469, -1.4922],\n", " [ 0.7148, -0.9180, 2.6406, ..., -0.3047, -1.8203, -0.5508]]],\n", " dtype=torch.bfloat16),\n", " 't_28': tensor([0.5664], dtype=torch.bfloat16),\n", " 'latents_28_start': tensor([[[ 1.1250, -0.4355, 0.3555, ..., -0.9805, 0.6016, 0.2266],\n", " [ 0.9609, 0.0845, -0.0231, ..., 0.4238, -0.0923, -0.8906],\n", " [ 0.2656, 0.2812, 0.0674, ..., -0.4551, -0.5742, -1.1875],\n", " ...,\n", " [-0.9375, 0.0361, 0.6367, ..., 0.2656, -0.2061, 0.5547],\n", " [ 0.4824, 0.6953, -0.1973, ..., 0.4668, 1.0547, 0.3145],\n", " [ 0.1758, -0.7148, 1.2266, ..., 0.9062, -0.0522, 0.6367]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_28': tensor([[[ 1.9453, -0.8281, 0.6406, ..., -1.7344, 0.9219, 0.3672],\n", " [ 1.5547, 0.0684, -0.3496, ..., 0.8047, -0.1953, -1.8281],\n", " [ 0.2812, 0.2207, 0.3281, ..., -1.1016, -1.8359, -2.7188],\n", " ...,\n", " [-1.1953, 0.4668, 1.6328, ..., -0.1758, -0.8203, 0.6445],\n", " [ 1.2734, 1.5781, 0.1045, ..., -0.7969, 1.0547, -1.4922],\n", " [ 0.6953, -0.9258, 2.6562, ..., -0.3086, -1.7969, -0.5508]]],\n", " dtype=torch.bfloat16),\n", " 't_29': tensor([0.5469], dtype=torch.bfloat16),\n", " 'latents_29_start': tensor([[[ 1.0859, -0.4180, 0.3418, ..., -0.9453, 0.5820, 0.2188],\n", " [ 0.9297, 0.0830, -0.0159, ..., 0.4082, -0.0884, -0.8516],\n", " [ 0.2598, 0.2773, 0.0608, ..., -0.4336, -0.5352, -1.1328],\n", " ...,\n", " [-0.9141, 0.0266, 0.6016, ..., 0.2695, -0.1895, 0.5430],\n", " [ 0.4570, 0.6641, -0.1992, ..., 0.4824, 1.0312, 0.3457],\n", " [ 0.1621, -0.6953, 1.1719, ..., 0.9141, -0.0156, 0.6484]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_29': tensor([[[ 1.9688, -0.8320, 0.6445, ..., -1.7734, 0.9219, 0.3672],\n", " [ 1.5469, 0.0732, -0.3477, ..., 0.7930, -0.2100, -1.8281],\n", " [ 0.2793, 0.2354, 0.3262, ..., -1.1250, -1.8359, -2.7188],\n", " ...,\n", " [-1.1953, 0.4746, 1.6172, ..., -0.1738, -0.8086, 0.6484],\n", " [ 1.2969, 1.5781, 0.0952, ..., -0.8164, 1.0859, -1.4766],\n", " [ 0.6797, -0.9180, 2.6562, ..., -0.3145, -1.7812, -0.5508]]],\n", " dtype=torch.bfloat16),\n", " 't_30': tensor([0.5273], dtype=torch.bfloat16),\n", " 'latents_30_start': tensor([[[ 1.0469, -0.4004, 0.3281, ..., -0.9102, 0.5625, 0.2109],\n", " [ 0.8984, 0.0815, -0.0087, ..., 0.3926, -0.0840, -0.8125],\n", " [ 0.2539, 0.2734, 0.0540, ..., -0.4102, -0.4961, -1.0781],\n", " ...,\n", " [-0.8906, 0.0168, 0.5664, ..., 0.2734, -0.1729, 0.5312],\n", " [ 0.4297, 0.6328, -0.2012, ..., 0.5000, 1.0078, 0.3770],\n", " [ 0.1484, -0.6758, 1.1172, ..., 0.9219, 0.0212, 0.6602]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_30': tensor([[[ 1.9766, -0.8359, 0.6484, ..., -1.7812, 0.9219, 0.3691],\n", " [ 1.5469, 0.0693, -0.3359, ..., 0.7969, -0.2090, -1.8203],\n", " [ 0.2852, 0.2363, 0.3320, ..., -1.1172, -1.8359, -2.7344],\n", " ...,\n", " [-1.1953, 0.4766, 1.6406, ..., -0.1855, -0.8203, 0.6484],\n", " [ 1.2891, 1.5781, 0.1064, ..., -0.8203, 1.0859, -1.4922],\n", " [ 0.6953, -0.9219, 2.6719, ..., -0.3145, -1.7656, -0.5312]]],\n", " dtype=torch.bfloat16),\n", " 't_31': tensor([0.5078], dtype=torch.bfloat16),\n", " 'latents_31_start': tensor([[[ 1.0078, -0.3828, 0.3145, ..., -0.8711, 0.5430, 0.2031],\n", " [ 0.8672, 0.0801, -0.0015, ..., 0.3750, -0.0796, -0.7734],\n", " [ 0.2480, 0.2676, 0.0469, ..., -0.3867, -0.4570, -1.0234],\n", " ...,\n", " [-0.8672, 0.0067, 0.5312, ..., 0.2773, -0.1553, 0.5156],\n", " [ 0.4023, 0.5977, -0.2031, ..., 0.5156, 0.9844, 0.4082],\n", " [ 0.1338, -0.6562, 1.0625, ..., 0.9297, 0.0588, 0.6719]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_31': tensor([[[ 1.9688, -0.8242, 0.6523, ..., -1.7656, 0.9297, 0.3555],\n", " [ 1.5469, 0.0796, -0.3516, ..., 0.7969, -0.2051, -1.8281],\n", " [ 0.2754, 0.2344, 0.3262, ..., -1.1172, -1.8359, -2.7344],\n", " ...,\n", " [-1.1953, 0.4766, 1.6328, ..., -0.1895, -0.8125, 0.6289],\n", " [ 1.2969, 1.5859, 0.0938, ..., -0.8125, 1.0781, -1.4766],\n", " [ 0.6836, -0.9375, 2.6562, ..., -0.2949, -1.7734, -0.5234]]],\n", " dtype=torch.bfloat16),\n", " 't_32': tensor([0.4844], dtype=torch.bfloat16),\n", " 'latents_32_start': tensor([[[ 0.9648, -0.3652, 0.3008, ..., -0.8320, 0.5234, 0.1953],\n", " [ 0.8320, 0.0781, 0.0061, ..., 0.3574, -0.0752, -0.7344],\n", " [ 0.2422, 0.2617, 0.0398, ..., -0.3633, -0.4180, -0.9648],\n", " ...,\n", " [-0.8398, -0.0037, 0.4961, ..., 0.2812, -0.1377, 0.5000],\n", " [ 0.3750, 0.5625, -0.2051, ..., 0.5352, 0.9609, 0.4395],\n", " [ 0.1191, -0.6367, 1.0078, ..., 0.9375, 0.0977, 0.6836]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_32': tensor([[[ 1.9688, -0.8242, 0.6523, ..., -1.7656, 0.9258, 0.3535],\n", " [ 1.5391, 0.0728, -0.3457, ..., 0.7891, -0.2021, -1.8203],\n", " [ 0.2754, 0.2344, 0.3223, ..., -1.1172, -1.8281, -2.7344],\n", " ...,\n", " [-1.1875, 0.4688, 1.6250, ..., -0.1768, -0.8086, 0.6133],\n", " [ 1.2812, 1.5703, 0.1079, ..., -0.8320, 1.0703, -1.4844],\n", " [ 0.6914, -0.9297, 2.6562, ..., -0.3027, -1.7656, -0.5156]]],\n", " dtype=torch.bfloat16),\n", " 't_33': tensor([0.4629], dtype=torch.bfloat16),\n", " 'latents_33_start': tensor([[[ 0.9219, -0.3477, 0.2871, ..., -0.7930, 0.5039, 0.1875],\n", " [ 0.7969, 0.0767, 0.0138, ..., 0.3398, -0.0708, -0.6953],\n", " [ 0.2363, 0.2559, 0.0327, ..., -0.3379, -0.3770, -0.9023],\n", " ...,\n", " [-0.8125, -0.0141, 0.4609, ..., 0.2852, -0.1196, 0.4863],\n", " [ 0.3457, 0.5273, -0.2070, ..., 0.5547, 0.9375, 0.4727],\n", " [ 0.1035, -0.6172, 0.9492, ..., 0.9453, 0.1367, 0.6953]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_33': tensor([[[ 1.9609, -0.8242, 0.6484, ..., -1.7578, 0.9258, 0.3477],\n", " [ 1.5312, 0.0684, -0.3379, ..., 0.7891, -0.2012, -1.8125],\n", " [ 0.2812, 0.2158, 0.3164, ..., -1.1172, -1.8125, -2.7188],\n", " ...,\n", " [-1.1797, 0.4570, 1.6250, ..., -0.1826, -0.8086, 0.6055],\n", " [ 1.2812, 1.5625, 0.1025, ..., -0.8125, 1.0625, -1.4922],\n", " [ 0.6797, -0.9297, 2.6250, ..., -0.2949, -1.7656, -0.5117]]],\n", " dtype=torch.bfloat16),\n", " 't_34': tensor([0.4375], dtype=torch.bfloat16),\n", " 'latents_34_start': tensor([[[ 0.8789, -0.3281, 0.2715, ..., -0.7539, 0.4824, 0.1797],\n", " [ 0.7617, 0.0752, 0.0215, ..., 0.3223, -0.0664, -0.6523],\n", " [ 0.2295, 0.2500, 0.0255, ..., -0.3125, -0.3359, -0.8398],\n", " ...,\n", " [-0.7852, -0.0245, 0.4238, ..., 0.2891, -0.1011, 0.4727],\n", " [ 0.3164, 0.4922, -0.2090, ..., 0.5742, 0.9141, 0.5078],\n", " [ 0.0879, -0.5977, 0.8906, ..., 0.9531, 0.1768, 0.7070]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_34': tensor([[[ 1.9766, -0.8242, 0.6523, ..., -1.7812, 0.9258, 0.3535],\n", " [ 1.5312, 0.0640, -0.3457, ..., 0.8086, -0.1934, -1.8047],\n", " [ 0.2715, 0.1992, 0.3105, ..., -1.0938, -1.8047, -2.7344],\n", " ...,\n", " [-1.1875, 0.4609, 1.6172, ..., -0.1953, -0.8164, 0.6172],\n", " [ 1.2812, 1.5625, 0.0942, ..., -0.8242, 1.0781, -1.4922],\n", " [ 0.6562, -0.9531, 2.6406, ..., -0.3008, -1.7734, -0.4980]]],\n", " dtype=torch.bfloat16),\n", " 't_35': tensor([0.4160], dtype=torch.bfloat16),\n", " 'latents_35_start': tensor([[[ 0.8320, -0.3086, 0.2559, ..., -0.7109, 0.4609, 0.1719],\n", " [ 0.7266, 0.0737, 0.0295, ..., 0.3027, -0.0620, -0.6094],\n", " [ 0.2236, 0.2451, 0.0183, ..., -0.2871, -0.2930, -0.7773],\n", " ...,\n", " [-0.7578, -0.0352, 0.3867, ..., 0.2930, -0.0820, 0.4590],\n", " [ 0.2871, 0.4551, -0.2109, ..., 0.5938, 0.8906, 0.5430],\n", " [ 0.0728, -0.5742, 0.8281, ..., 0.9609, 0.2178, 0.7188]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_35': tensor([[[ 1.9766, -0.8242, 0.6484, ..., -1.7734, 0.9258, 0.3496],\n", " [ 1.5312, 0.0708, -0.3418, ..., 0.8047, -0.1953, -1.7969],\n", " [ 0.2832, 0.1875, 0.3086, ..., -1.0938, -1.7891, -2.7344],\n", " ...,\n", " [-1.1719, 0.4551, 1.6172, ..., -0.1953, -0.8047, 0.6055],\n", " [ 1.2578, 1.5625, 0.0898, ..., -0.8242, 1.0859, -1.4766],\n", " [ 0.6602, -0.9414, 2.6406, ..., -0.3164, -1.7656, -0.5078]]],\n", " dtype=torch.bfloat16),\n", " 't_36': tensor([0.3926], dtype=torch.bfloat16),\n", " 'latents_36_start': tensor([[[ 0.7852, -0.2891, 0.2402, ..., -0.6680, 0.4395, 0.1641],\n", " [ 0.6914, 0.0723, 0.0376, ..., 0.2832, -0.0574, -0.5664],\n", " [ 0.2168, 0.2402, 0.0110, ..., -0.2617, -0.2500, -0.7109],\n", " ...,\n", " [-0.7305, -0.0459, 0.3477, ..., 0.2969, -0.0630, 0.4453],\n", " [ 0.2578, 0.4180, -0.2129, ..., 0.6133, 0.8633, 0.5781],\n", " [ 0.0571, -0.5508, 0.7656, ..., 0.9688, 0.2598, 0.7305]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_36': tensor([[[ 1.9609, -0.8164, 0.6445, ..., -1.7656, 0.9102, 0.3477],\n", " [ 1.5234, 0.0654, -0.3320, ..., 0.8164, -0.2041, -1.7812],\n", " [ 0.2734, 0.1836, 0.3066, ..., -1.0938, -1.7891, -2.7031],\n", " ...,\n", " [-1.1875, 0.4688, 1.6016, ..., -0.2051, -0.8008, 0.5977],\n", " [ 1.2500, 1.5234, 0.0747, ..., -0.8438, 1.0781, -1.4922],\n", " [ 0.6484, -0.9219, 2.6250, ..., -0.3027, -1.7812, -0.5078]]],\n", " dtype=torch.bfloat16),\n", " 't_37': tensor([0.3652], dtype=torch.bfloat16),\n", " 'latents_37_start': tensor([[[ 0.7383, -0.2695, 0.2246, ..., -0.6250, 0.4180, 0.1553],\n", " [ 0.6562, 0.0708, 0.0457, ..., 0.2637, -0.0525, -0.5234],\n", " [ 0.2100, 0.2354, 0.0035, ..., -0.2354, -0.2061, -0.6445],\n", " ...,\n", " [-0.7031, -0.0574, 0.3086, ..., 0.3027, -0.0435, 0.4316],\n", " [ 0.2275, 0.3809, -0.2148, ..., 0.6328, 0.8359, 0.6133],\n", " [ 0.0413, -0.5273, 0.7031, ..., 0.9766, 0.3027, 0.7422]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_37': tensor([[[ 1.9922, -0.8320, 0.6523, ..., -1.7812, 0.9336, 0.3613],\n", " [ 1.5312, 0.0713, -0.3477, ..., 0.8281, -0.1963, -1.7891],\n", " [ 0.2754, 0.1777, 0.2930, ..., -1.0859, -1.7734, -2.7031],\n", " ...,\n", " [-1.1875, 0.4590, 1.6172, ..., -0.1855, -0.8086, 0.5820],\n", " [ 1.2422, 1.5391, 0.0552, ..., -0.8633, 1.0781, -1.5156],\n", " [ 0.6602, -0.9219, 2.6250, ..., -0.3047, -1.7734, -0.5000]]],\n", " dtype=torch.bfloat16),\n", " 't_38': tensor([0.3418], dtype=torch.bfloat16),\n", " 'latents_38_start': tensor([[[ 0.6875, -0.2490, 0.2080, ..., -0.5820, 0.3945, 0.1465],\n", " [ 0.6172, 0.0688, 0.0544, ..., 0.2432, -0.0476, -0.4785],\n", " [ 0.2031, 0.2305, -0.0038, ..., -0.2080, -0.1621, -0.5781],\n", " ...,\n", " [-0.6719, -0.0688, 0.2676, ..., 0.3066, -0.0232, 0.4180],\n", " [ 0.1963, 0.3418, -0.2158, ..., 0.6562, 0.8086, 0.6523],\n", " [ 0.0248, -0.5039, 0.6367, ..., 0.9844, 0.3477, 0.7539]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_38': tensor([[[ 1.9766, -0.8477, 0.6484, ..., -1.7812, 0.9531, 0.3789],\n", " [ 1.5234, 0.0781, -0.3496, ..., 0.8281, -0.1973, -1.7734],\n", " [ 0.2617, 0.1660, 0.2949, ..., -1.0547, -1.7422, -2.6875],\n", " ...,\n", " [-1.1641, 0.4668, 1.6328, ..., -0.1836, -0.8125, 0.5547],\n", " [ 1.2344, 1.5312, 0.0752, ..., -0.8867, 1.0234, -1.5156],\n", " [ 0.6406, -0.9219, 2.6250, ..., -0.2988, -1.7812, -0.5117]]],\n", " dtype=torch.bfloat16),\n", " 't_39': tensor([0.3164], dtype=torch.bfloat16),\n", " 'latents_39_start': tensor([[[ 0.6367, -0.2275, 0.1914, ..., -0.5352, 0.3711, 0.1367],\n", " [ 0.5781, 0.0669, 0.0635, ..., 0.2217, -0.0425, -0.4336],\n", " [ 0.1963, 0.2266, -0.0114, ..., -0.1807, -0.1172, -0.5078],\n", " ...,\n", " [-0.6406, -0.0811, 0.2256, ..., 0.3105, -0.0023, 0.4043],\n", " [ 0.1641, 0.3027, -0.2178, ..., 0.6797, 0.7812, 0.6914],\n", " [ 0.0083, -0.4805, 0.5703, ..., 0.9922, 0.3926, 0.7656]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_39': tensor([[[ 1.9531, -0.8320, 0.6523, ..., -1.7500, 0.9336, 0.3965],\n", " [ 1.5156, 0.0767, -0.3691, ..., 0.8281, -0.1777, -1.7500],\n", " [ 0.2637, 0.1729, 0.2734, ..., -1.0234, -1.6797, -2.6250],\n", " ...,\n", " [-1.1406, 0.4531, 1.6016, ..., -0.1973, -0.8164, 0.5234],\n", " [ 1.2344, 1.5078, 0.0593, ..., -0.8906, 1.0234, -1.4844],\n", " [ 0.6367, -0.9023, 2.5938, ..., -0.2910, -1.7500, -0.5078]]],\n", " dtype=torch.bfloat16),\n", " 't_40': tensor([0.2891], dtype=torch.bfloat16),\n", " 'latents_40_start': tensor([[[ 0.5859, -0.2061, 0.1738, ..., -0.4883, 0.3457, 0.1260],\n", " [ 0.5391, 0.0649, 0.0732, ..., 0.2002, -0.0378, -0.3867],\n", " [ 0.1895, 0.2217, -0.0186, ..., -0.1543, -0.0732, -0.4395],\n", " ...,\n", " [-0.6094, -0.0928, 0.1836, ..., 0.3164, 0.0192, 0.3906],\n", " [ 0.1318, 0.2637, -0.2197, ..., 0.7031, 0.7539, 0.7305],\n", " [-0.0084, -0.4570, 0.5039, ..., 1.0000, 0.4375, 0.7773]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_40': tensor([[[ 1.9766, -0.8555, 0.6445, ..., -1.7500, 0.9414, 0.4004],\n", " [ 1.5234, 0.0693, -0.3613, ..., 0.8672, -0.1650, -1.7266],\n", " [ 0.2598, 0.1660, 0.2734, ..., -1.0156, -1.7031, -2.6250],\n", " ...,\n", " [-1.1641, 0.4531, 1.6016, ..., -0.1611, -0.8125, 0.5039],\n", " [ 1.2109, 1.5156, 0.0391, ..., -0.8906, 0.9883, -1.4688],\n", " [ 0.6250, -0.9102, 2.6094, ..., -0.2930, -1.7578, -0.4922]]],\n", " dtype=torch.bfloat16),\n", " 't_41': tensor([0.2617], dtype=torch.bfloat16),\n", " 'latents_41_start': tensor([[[ 0.5312, -0.1826, 0.1562, ..., -0.4414, 0.3203, 0.1152],\n", " [ 0.4980, 0.0630, 0.0830, ..., 0.1768, -0.0334, -0.3398],\n", " [ 0.1826, 0.2168, -0.0259, ..., -0.1270, -0.0273, -0.3691],\n", " ...,\n", " [-0.5781, -0.1050, 0.1406, ..., 0.3203, 0.0410, 0.3770],\n", " [ 0.0991, 0.2227, -0.2207, ..., 0.7266, 0.7266, 0.7695],\n", " [-0.0253, -0.4316, 0.4336, ..., 1.0078, 0.4844, 0.7891]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_41': tensor([[[ 1.9375, -0.8555, 0.6367, ..., -1.7266, 0.9531, 0.4297],\n", " [ 1.5312, 0.0659, -0.3633, ..., 0.8711, -0.1660, -1.7266],\n", " [ 0.2520, 0.1826, 0.2676, ..., -0.9844, -1.6328, -2.5781],\n", " ...,\n", " [-1.1484, 0.4453, 1.5859, ..., -0.1562, -0.8281, 0.4922],\n", " [ 1.1797, 1.5078, 0.0229, ..., -0.9102, 0.9531, -1.4609],\n", " [ 0.6211, -0.9102, 2.5938, ..., -0.2832, -1.7266, -0.4727]]],\n", " dtype=torch.bfloat16),\n", " 't_42': tensor([0.2354], dtype=torch.bfloat16),\n", " 'latents_42_start': tensor([[[ 0.4785, -0.1592, 0.1387, ..., -0.3945, 0.2949, 0.1035],\n", " [ 0.4551, 0.0613, 0.0928, ..., 0.1523, -0.0288, -0.2930],\n", " [ 0.1758, 0.2119, -0.0332, ..., -0.0996, 0.0178, -0.2969],\n", " ...,\n", " [-0.5469, -0.1172, 0.0967, ..., 0.3242, 0.0640, 0.3633],\n", " [ 0.0664, 0.1816, -0.2217, ..., 0.7500, 0.6992, 0.8086],\n", " [-0.0425, -0.4062, 0.3613, ..., 1.0156, 0.5312, 0.8008]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_42': tensor([[[ 1.9297, -0.8594, 0.6289, ..., -1.7031, 0.9609, 0.4297],\n", " [ 1.5000, 0.0811, -0.3652, ..., 0.8711, -0.1494, -1.7031],\n", " [ 0.2246, 0.1543, 0.2695, ..., -0.9492, -1.6328, -2.5312],\n", " ...,\n", " [-1.1328, 0.4395, 1.5781, ..., -0.1680, -0.8398, 0.4453],\n", " [ 1.1484, 1.4766, 0.0073, ..., -0.9648, 0.8984, -1.4219],\n", " [ 0.5938, -0.8672, 2.5312, ..., -0.2949, -1.7031, -0.4766]]],\n", " dtype=torch.bfloat16),\n", " 't_43': tensor([0.2070], dtype=torch.bfloat16),\n", " 'latents_43_start': tensor([[[ 0.4238, -0.1348, 0.1211, ..., -0.3457, 0.2676, 0.0913],\n", " [ 0.4121, 0.0591, 0.1030, ..., 0.1279, -0.0245, -0.2441],\n", " [ 0.1699, 0.2080, -0.0408, ..., -0.0728, 0.0640, -0.2246],\n", " ...,\n", " [-0.5156, -0.1299, 0.0520, ..., 0.3281, 0.0879, 0.3516],\n", " [ 0.0339, 0.1396, -0.2217, ..., 0.7773, 0.6719, 0.8477],\n", " [-0.0593, -0.3809, 0.2891, ..., 1.0234, 0.5781, 0.8125]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_43': tensor([[[ 1.9141, -0.8711, 0.6328, ..., -1.6484, 0.9883, 0.4453],\n", " [ 1.4844, 0.0693, -0.4004, ..., 0.8867, -0.1152, -1.6797],\n", " [ 0.2432, 0.1484, 0.2090, ..., -0.8867, -1.5938, -2.4531],\n", " ...,\n", " [-1.1094, 0.4453, 1.5703, ..., -0.1865, -0.8594, 0.3906],\n", " [ 1.0938, 1.4375, -0.0374, ..., -0.9648, 0.8359, -1.4531],\n", " [ 0.5898, -0.8516, 2.4688, ..., -0.2773, -1.6484, -0.4746]]],\n", " dtype=torch.bfloat16),\n", " 't_44': tensor([0.1777], dtype=torch.bfloat16),\n", " 'latents_44_start': tensor([[[ 0.3672, -0.1094, 0.1025, ..., -0.2969, 0.2393, 0.0781],\n", " [ 0.3691, 0.0571, 0.1147, ..., 0.1021, -0.0212, -0.1953],\n", " [ 0.1631, 0.2041, -0.0469, ..., -0.0469, 0.1104, -0.1533],\n", " ...,\n", " [-0.4844, -0.1426, 0.0063, ..., 0.3340, 0.1128, 0.3398],\n", " [ 0.0022, 0.0977, -0.2207, ..., 0.8047, 0.6484, 0.8906],\n", " [-0.0762, -0.3555, 0.2168, ..., 1.0312, 0.6250, 0.8281]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_44': tensor([[[ 1.8984, -0.8984, 0.5977, ..., -1.6094, 0.9805, 0.4551],\n", " [ 1.4609, 0.0474, -0.4160, ..., 0.9102, -0.1060, -1.6484],\n", " [ 0.2119, 0.0791, 0.2021, ..., -0.8359, -1.5391, -2.4219],\n", " ...,\n", " [-1.0859, 0.4395, 1.5391, ..., -0.2002, -0.8516, 0.3359],\n", " [ 1.0625, 1.4219, -0.0486, ..., -0.9609, 0.7930, -1.4453],\n", " [ 0.5898, -0.8008, 2.4219, ..., -0.3223, -1.5703, -0.4609]]],\n", " dtype=torch.bfloat16),\n", " 't_45': tensor([0.1484], dtype=torch.bfloat16),\n", " 'latents_45_start': tensor([[[ 0.3105, -0.0825, 0.0850, ..., -0.2490, 0.2100, 0.0645],\n", " [ 0.3262, 0.0557, 0.1270, ..., 0.0747, -0.0181, -0.1465],\n", " [ 0.1562, 0.2021, -0.0530, ..., -0.0219, 0.1562, -0.0811],\n", " ...,\n", " [-0.4512, -0.1553, -0.0398, ..., 0.3398, 0.1387, 0.3301],\n", " [-0.0295, 0.0552, -0.2197, ..., 0.8320, 0.6250, 0.9336],\n", " [-0.0938, -0.3320, 0.1445, ..., 1.0391, 0.6719, 0.8438]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_45': tensor([[[ 1.8516, -0.8984, 0.5547, ..., -1.5547, 0.9688, 0.4727],\n", " [ 1.4219, 0.0500, -0.4258, ..., 0.9141, -0.0588, -1.5703],\n", " [ 0.1592, 0.0850, 0.1924, ..., -0.7500, -1.4766, -2.3125],\n", " ...,\n", " [-1.0469, 0.4414, 1.4766, ..., -0.1494, -0.8438, 0.3047],\n", " [ 0.9883, 1.4062, -0.0874, ..., -0.9844, 0.7422, -1.4297],\n", " [ 0.5742, -0.7578, 2.3125, ..., -0.3301, -1.3672, -0.4238]]],\n", " dtype=torch.bfloat16),\n", " 't_46': tensor([0.1172], dtype=torch.bfloat16),\n", " 'latents_46_start': tensor([[[ 0.2539, -0.0549, 0.0679, ..., -0.2012, 0.1807, 0.0500],\n", " [ 0.2832, 0.0542, 0.1396, ..., 0.0469, -0.0162, -0.0986],\n", " [ 0.1514, 0.1992, -0.0588, ..., 0.0011, 0.2012, -0.0103],\n", " ...,\n", " [-0.4199, -0.1689, -0.0850, ..., 0.3438, 0.1641, 0.3203],\n", " [-0.0598, 0.0122, -0.2168, ..., 0.8633, 0.6016, 0.9766],\n", " [-0.1113, -0.3086, 0.0737, ..., 1.0469, 0.7148, 0.8555]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_46': tensor([[[ 1.7734, -0.9492, 0.5430, ..., -1.5312, 0.9727, 0.5273],\n", " [ 1.4219, 0.0054, -0.4844, ..., 0.9297, 0.0198, -1.4531],\n", " [ 0.1377, -0.0330, 0.1299, ..., -0.6914, -1.3906, -2.2188],\n", " ...,\n", " [-0.9727, 0.4297, 1.3906, ..., -0.1172, -0.8164, 0.2295],\n", " [ 0.8945, 1.3516, -0.1758, ..., -0.9961, 0.6445, -1.5078],\n", " [ 0.5234, -0.7422, 2.2031, ..., -0.3848, -1.1641, -0.4883]]],\n", " dtype=torch.bfloat16),\n", " 't_47': tensor([0.0854], dtype=torch.bfloat16),\n", " 'latents_47_start': tensor([[[ 0.1982, -0.0250, 0.0508, ..., -0.1523, 0.1504, 0.0334],\n", " [ 0.2383, 0.0540, 0.1553, ..., 0.0176, -0.0168, -0.0530],\n", " [ 0.1475, 0.2002, -0.0630, ..., 0.0228, 0.2451, 0.0596],\n", " ...,\n", " [-0.3887, -0.1826, -0.1289, ..., 0.3477, 0.1895, 0.3125],\n", " [-0.0879, -0.0303, -0.2109, ..., 0.8945, 0.5820, 1.0234],\n", " [-0.1279, -0.2852, 0.0044, ..., 1.0625, 0.7500, 0.8711]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_47': tensor([[[ 1.6250, -0.9688, 0.4160, ..., -1.4609, 0.9961, 0.5742],\n", " [ 1.3594, 0.0728, -0.5430, ..., 0.9062, 0.0530, -1.3438],\n", " [ 0.1553, -0.1787, 0.0908, ..., -0.5820, -1.1875, -1.9688],\n", " ...,\n", " [-0.8281, 0.4160, 1.2422, ..., -0.0122, -0.7500, 0.1396],\n", " [ 0.7734, 1.2812, -0.2295, ..., -0.9883, 0.5039, -1.4844],\n", " [ 0.4453, -0.6719, 1.9688, ..., -0.4180, -0.9141, -0.6211]]],\n", " dtype=torch.bfloat16),\n", " 't_48': tensor([0.0532], dtype=torch.bfloat16),\n", " 'latents_48_start': tensor([[[ 0.1455, 0.0065, 0.0374, ..., -0.1050, 0.1182, 0.0148],\n", " [ 0.1943, 0.0515, 0.1729, ..., -0.0118, -0.0186, -0.0093],\n", " [ 0.1426, 0.2061, -0.0659, ..., 0.0417, 0.2832, 0.1235],\n", " ...,\n", " [-0.3613, -0.1963, -0.1689, ..., 0.3477, 0.2139, 0.3086],\n", " [-0.1133, -0.0718, -0.2031, ..., 0.9258, 0.5664, 1.0703],\n", " [-0.1426, -0.2637, -0.0596, ..., 1.0781, 0.7812, 0.8906]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred_48': tensor([[[ 1.2031, -0.9297, 0.2559, ..., -1.4141, 0.8906, 0.4258],\n", " [ 1.1016, 0.0625, -0.3242, ..., 0.8047, 0.1318, -1.1094],\n", " [ 0.1094, -0.2695, 0.2334, ..., -0.5820, -1.1016, -1.5234],\n", " ...,\n", " [-0.5938, 0.4551, 1.0938, ..., 0.0281, -0.6289, 0.1357],\n", " [ 0.6523, 1.0625, -0.2275, ..., -1.0938, 0.4297, -1.3984],\n", " [ 0.3906, -0.4180, 1.5391, ..., -0.5742, -0.6250, -0.7852]]],\n", " dtype=torch.bfloat16),\n", " 't_49': tensor([0.0200], dtype=torch.bfloat16),\n", " 'latents_49_start': tensor([[[ 1.0547e-01, 3.7354e-02, 2.8809e-02, ..., -5.8105e-02,\n", " 8.8867e-02, 6.1035e-04],\n", " [ 1.5820e-01, 4.9316e-02, 1.8359e-01, ..., -3.8574e-02,\n", " -2.2949e-02, 2.7588e-02],\n", " [ 1.3867e-01, 2.1484e-01, -7.3730e-02, ..., 6.1035e-02,\n", " 3.2031e-01, 1.7383e-01],\n", " ...,\n", " [-3.4180e-01, -2.1094e-01, -2.0508e-01, ..., 3.4766e-01,\n", " 2.3438e-01, 3.0469e-01],\n", " [-1.3477e-01, -1.0693e-01, -1.9531e-01, ..., 9.6094e-01,\n", " 5.5078e-01, 1.1172e+00],\n", " [-1.5527e-01, -2.5000e-01, -1.1035e-01, ..., 1.0938e+00,\n", " 8.0078e-01, 9.1797e-01]]], dtype=torch.bfloat16),\n", " 'noise_pred_49': tensor([[[ 0.7461, -0.5586, 0.2197, ..., -1.0469, 0.7109, 0.4902],\n", " [ 0.6094, 0.0464, -0.1650, ..., 0.4980, 0.2314, -0.9414],\n", " [ 0.1064, -0.2109, 0.1846, ..., -0.3633, -0.8086, -1.0234],\n", " ...,\n", " [-0.2559, 0.3711, 0.7461, ..., -0.2217, -0.2988, 0.0339],\n", " [ 0.4980, 0.5156, -0.0260, ..., -1.1250, 0.1064, -1.1250],\n", " [ 0.2471, 0.0179, 0.6875, ..., -0.7188, -0.5898, -0.8672]]],\n", " dtype=torch.bfloat16),\n", " 'output': tensor([[[ 0.0903, 0.0486, 0.0244, ..., -0.0371, 0.0747, -0.0092],\n", " [ 0.1465, 0.0483, 0.1865, ..., -0.0486, -0.0276, 0.0464],\n", " [ 0.1367, 0.2188, -0.0776, ..., 0.0684, 0.3359, 0.1943],\n", " ...,\n", " [-0.3359, -0.2188, -0.2197, ..., 0.3516, 0.2402, 0.3047],\n", " [-0.1445, -0.1172, -0.1943, ..., 0.9844, 0.5469, 1.1406],\n", " [-0.1602, -0.2500, -0.1240, ..., 1.1094, 0.8125, 0.9336]]],\n", " dtype=torch.bfloat16),\n", " 'height': 576,\n", " 'width': 448,\n", " 't': tensor([1.], dtype=torch.bfloat16),\n", " 'latents_start': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n", " [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n", " [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n", " ...,\n", " [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n", " [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n", " [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n", " dtype=torch.bfloat16),\n", " 'noise_pred': tensor([[[ 1.8906, -0.8945, 0.5938, ..., -1.7578, 1.0078, 0.2539],\n", " [ 1.5781, 0.0278, -0.2793, ..., 0.7305, -0.1553, -1.7969],\n", " [ 0.3027, 0.2949, 0.1621, ..., -1.0625, -1.5938, -2.6406],\n", " ...,\n", " [-1.2578, 0.5352, 1.5859, ..., -0.2773, -1.0312, 0.3203],\n", " [ 1.2734, 1.5312, 0.0728, ..., -0.6211, 0.8984, -1.1562],\n", " [ 0.6172, -0.9336, 2.6719, ..., -0.1050, -1.8672, -0.3691]]],\n", " dtype=torch.bfloat16)}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "src[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "22f19ae9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }