cavargas10 committed on
Commit 590f44b · verified · 1 Parent(s): 16b89c2

Delete trellis/models

trellis/models/__init__.py DELETED
@@ -1,96 +0,0 @@
- import importlib
-
- __attributes = {
-     'SparseStructureEncoder': 'sparse_structure_vae',
-     'SparseStructureDecoder': 'sparse_structure_vae',
-
-     'SparseStructureFlowModel': 'sparse_structure_flow',
-
-     'SLatEncoder': 'structured_latent_vae',
-     'SLatGaussianDecoder': 'structured_latent_vae',
-     'SLatRadianceFieldDecoder': 'structured_latent_vae',
-     'SLatMeshDecoder': 'structured_latent_vae',
-     'ElasticSLatEncoder': 'structured_latent_vae',
-     'ElasticSLatGaussianDecoder': 'structured_latent_vae',
-     'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
-     'ElasticSLatMeshDecoder': 'structured_latent_vae',
-
-     'SLatFlowModel': 'structured_latent_flow',
-     'ElasticSLatFlowModel': 'structured_latent_flow',
- }
-
- __submodules = []
-
- __all__ = list(__attributes.keys()) + __submodules
-
- def __getattr__(name):
-     if name not in globals():
-         if name in __attributes:
-             module_name = __attributes[name]
-             module = importlib.import_module(f".{module_name}", __name__)
-             globals()[name] = getattr(module, name)
-         elif name in __submodules:
-             module = importlib.import_module(f".{name}", __name__)
-             globals()[name] = module
-         else:
-             raise AttributeError(f"module {__name__} has no attribute {name}")
-     return globals()[name]
-
-
- def from_pretrained(path: str, **kwargs):
-     """
-     Load a model from a pretrained checkpoint.
-
-     Args:
-         path: The path to the checkpoint. Can be either a local path or a Hugging Face model name.
-             NOTE: the config file and model file should be named f'{path}.json' and f'{path}.safetensors' respectively.
-         **kwargs: Additional arguments for the model constructor.
-     """
-     import os
-     import json
-     from safetensors.torch import load_file
-     is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
-
-     if is_local:
-         config_file = f"{path}.json"
-         model_file = f"{path}.safetensors"
-     else:
-         from huggingface_hub import hf_hub_download
-         path_parts = path.split('/')
-         repo_id = f'{path_parts[0]}/{path_parts[1]}'
-         model_name = '/'.join(path_parts[2:])
-         config_file = hf_hub_download(repo_id, f"{model_name}.json")
-         model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
-
-     with open(config_file, 'r') as f:
-         config = json.load(f)
-     model = __getattr__(config['name'])(**config['args'], **kwargs)
-     model.load_state_dict(load_file(model_file))
-
-     return model
-
-
- # For Pylance
- if __name__ == '__main__':
-     from .sparse_structure_vae import (
-         SparseStructureEncoder,
-         SparseStructureDecoder,
-     )
-
-     from .sparse_structure_flow import SparseStructureFlowModel
-
-     from .structured_latent_vae import (
-         SLatEncoder,
-         SLatGaussianDecoder,
-         SLatRadianceFieldDecoder,
-         SLatMeshDecoder,
-         ElasticSLatEncoder,
-         ElasticSLatGaussianDecoder,
-         ElasticSLatRadianceFieldDecoder,
-         ElasticSLatMeshDecoder,
-     )
-
-     from .structured_latent_flow import (
-         SLatFlowModel,
-         ElasticSLatFlowModel,
-     )
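
For reference, this __init__.py combined PEP 562 lazy attribute loading with a checkpoint loader that pairs a {path}.json config with a {path}.safetensors weight file. A minimal usage sketch; the repository and checkpoint names below are illustrative placeholders, not taken from this commit:

    from trellis import models

    # Attribute access goes through __getattr__, so the submodule is imported
    # lazily on first use rather than at package import time.
    FlowModel = models.SparseStructureFlowModel

    # A path with more than two segments is split into a Hub repo_id plus a
    # file prefix; a local prefix with matching .json/.safetensors also works.
    model = models.from_pretrained("some-user/some-repo/ckpts/some_model")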
 
trellis/models/sparse_elastic_mixin.py DELETED
@@ -1,24 +0,0 @@
- from contextlib import contextmanager
- from typing import *
- import math
- from ..modules import sparse as sp
- from ..utils.elastic_utils import ElasticModuleMixin
-
-
- class SparseTransformerElasticMixin(ElasticModuleMixin):
-     def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
-         return x.feats.shape[0]
-
-     @contextmanager
-     def with_mem_ratio(self, mem_ratio=1.0):
-         if mem_ratio == 1.0:
-             yield 1.0
-             return
-         num_blocks = len(self.blocks)
-         num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
-         exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
-         for i in range(num_blocks):
-             self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
-         yield exact_mem_ratio
-         for i in range(num_blocks):
-             self.blocks[i].use_checkpoint = False
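
Note on the deleted mixin: with_mem_ratio enables gradient checkpointing on the first num_checkpoint_blocks transformer blocks, trading compute for activation memory, and yields the ratio it can actually achieve given block-count quantization. A standalone check of that arithmetic, with assumed block counts and no TRELLIS imports:

    import math

    def checkpoint_plan(num_blocks: int, mem_ratio: float):
        # Mirrors with_mem_ratio: at least one block is checkpointed whenever
        # mem_ratio < 1, and the achievable ratio is quantized by num_blocks.
        num_ckpt = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact = 1 - (num_ckpt - 1) / num_blocks
        return num_ckpt, exact

    print(checkpoint_plan(24, 0.5))   # (13, 0.5): 13 of 24 blocks checkpointed
    print(checkpoint_plan(24, 0.99))  # (2, 0.9583...): ratio rounds to the grid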
 
trellis/models/sparse_structure_flow.py DELETED
@@ -1,200 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- from ..modules.utils import convert_module_to_f16, convert_module_to_f32
- from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
- from ..modules.spatial import patchify, unpatchify
-
-
- class TimestepEmbedder(nn.Module):
-     """
-     Embeds scalar timesteps into vector representations.
-     """
-     def __init__(self, hidden_size, frequency_embedding_size=256):
-         super().__init__()
-         self.mlp = nn.Sequential(
-             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-             nn.SiLU(),
-             nn.Linear(hidden_size, hidden_size, bias=True),
-         )
-         self.frequency_embedding_size = frequency_embedding_size
-
-     @staticmethod
-     def timestep_embedding(t, dim, max_period=10000):
-         """
-         Create sinusoidal timestep embeddings.
-
-         Args:
-             t: a 1-D Tensor of N indices, one per batch element.
-                 These may be fractional.
-             dim: the dimension of the output.
-             max_period: controls the minimum frequency of the embeddings.
-
-         Returns:
-             an (N, D) Tensor of positional embeddings.
-         """
-         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-         half = dim // 2
-         freqs = torch.exp(
-             -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-         ).to(device=t.device)
-         args = t[:, None].float() * freqs[None]
-         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-         if dim % 2:
-             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-         return embedding
-
-     def forward(self, t):
-         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-         t_emb = self.mlp(t_freq)
-         return t_emb
-
-
- class SparseStructureFlowModel(nn.Module):
-     def __init__(
-         self,
-         resolution: int,
-         in_channels: int,
-         model_channels: int,
-         cond_channels: int,
-         out_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4,
-         patch_size: int = 2,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         share_mod: bool = False,
-         qk_rms_norm: bool = False,
-         qk_rms_norm_cross: bool = False,
-     ):
-         super().__init__()
-         self.resolution = resolution
-         self.in_channels = in_channels
-         self.model_channels = model_channels
-         self.cond_channels = cond_channels
-         self.out_channels = out_channels
-         self.num_blocks = num_blocks
-         self.num_heads = num_heads or model_channels // num_head_channels
-         self.mlp_ratio = mlp_ratio
-         self.patch_size = patch_size
-         self.pe_mode = pe_mode
-         self.use_fp16 = use_fp16
-         self.use_checkpoint = use_checkpoint
-         self.share_mod = share_mod
-         self.qk_rms_norm = qk_rms_norm
-         self.qk_rms_norm_cross = qk_rms_norm_cross
-         self.dtype = torch.float16 if use_fp16 else torch.float32
-
-         self.t_embedder = TimestepEmbedder(model_channels)
-         if share_mod:
-             self.adaLN_modulation = nn.Sequential(
-                 nn.SiLU(),
-                 nn.Linear(model_channels, 6 * model_channels, bias=True)
-             )
-
-         if pe_mode == "ape":
-             pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
-             coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
-             coords = torch.stack(coords, dim=-1).reshape(-1, 3)
-             pos_emb = pos_embedder(coords)
-             self.register_buffer("pos_emb", pos_emb)
-
-         self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
-
-         self.blocks = nn.ModuleList([
-             ModulatedTransformerCrossBlock(
-                 model_channels,
-                 cond_channels,
-                 num_heads=self.num_heads,
-                 mlp_ratio=self.mlp_ratio,
-                 attn_mode='full',
-                 use_checkpoint=self.use_checkpoint,
-                 use_rope=(pe_mode == "rope"),
-                 share_mod=share_mod,
-                 qk_rms_norm=self.qk_rms_norm,
-                 qk_rms_norm_cross=self.qk_rms_norm_cross,
-             )
-             for _ in range(num_blocks)
-         ])
-
-         self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
-
-         self.initialize_weights()
-         if use_fp16:
-             self.convert_to_fp16()
-
-     @property
-     def device(self) -> torch.device:
-         """
-         Return the device of the model.
-         """
-         return next(self.parameters()).device
-
-     def convert_to_fp16(self) -> None:
-         """
-         Convert the torso of the model to float16.
-         """
-         self.blocks.apply(convert_module_to_f16)
-
-     def convert_to_fp32(self) -> None:
-         """
-         Convert the torso of the model to float32.
-         """
-         self.blocks.apply(convert_module_to_f32)
-
-     def initialize_weights(self) -> None:
-         # Initialize transformer layers:
-         def _basic_init(module):
-             if isinstance(module, nn.Linear):
-                 torch.nn.init.xavier_uniform_(module.weight)
-                 if module.bias is not None:
-                     nn.init.constant_(module.bias, 0)
-         self.apply(_basic_init)
-
-         # Initialize timestep embedding MLP:
-         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-         # Zero-out adaLN modulation layers in DiT blocks:
-         if self.share_mod:
-             nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
-             nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
-         else:
-             for block in self.blocks:
-                 nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
-                 nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
-         # Zero-out output layers:
-         nn.init.constant_(self.out_layer.weight, 0)
-         nn.init.constant_(self.out_layer.bias, 0)
-
-     def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
-         assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
-             f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
-
-         h = patchify(x, self.patch_size)
-         h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
-
-         h = self.input_layer(h)
-         h = h + self.pos_emb[None]
-         t_emb = self.t_embedder(t)
-         if self.share_mod:
-             t_emb = self.adaLN_modulation(t_emb)
-         t_emb = t_emb.type(self.dtype)
-         h = h.type(self.dtype)
-         cond = cond.type(self.dtype)
-         for block in self.blocks:
-             h = block(h, t_emb, cond)
-         h = h.type(x.dtype)
-         h = F.layer_norm(h, h.shape[-1:])
-         h = self.out_layer(h)
-
-         h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
-         h = unpatchify(h, self.patch_size).contiguous()
-
-         return h
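
Shape bookkeeping for SparseStructureFlowModel.forward, for reference: patchify folds each patch_size**3 voxel patch into the channel dimension, the result is flattened into a token sequence for the transformer, and unpatchify inverts the operation at the end. A quick check with assumed sizes; the numbers are illustrative, not from this commit:

    # resolution R = 16, patch p = 2, in_channels C = 8
    R, p, C = 16, 2, 8
    num_tokens = (R // p) ** 3    # 512 tokens after patchify + flatten
    token_dim = C * p ** 3        # 64 features per token into input_layer
    print(num_tokens, token_dim)  # input_layer: Linear(64 -> model_channels)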
 
trellis/models/sparse_structure_vae.py DELETED
@@ -1,306 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ..modules.norm import GroupNorm32, ChannelLayerNorm32
- from ..modules.spatial import pixel_shuffle_3d
- from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
-
-
- def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
-     """
-     Return a normalization layer.
-     """
-     if norm_type == "group":
-         return GroupNorm32(32, *args, **kwargs)
-     elif norm_type == "layer":
-         return ChannelLayerNorm32(*args, **kwargs)
-     else:
-         raise ValueError(f"Invalid norm type {norm_type}")
-
-
- class ResBlock3d(nn.Module):
-     def __init__(
-         self,
-         channels: int,
-         out_channels: Optional[int] = None,
-         norm_type: Literal["group", "layer"] = "layer",
-     ):
-         super().__init__()
-         self.channels = channels
-         self.out_channels = out_channels or channels
-
-         self.norm1 = norm_layer(norm_type, channels)
-         self.norm2 = norm_layer(norm_type, self.out_channels)
-         self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
-         self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
-         self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         h = self.norm1(x)
-         h = F.silu(h)
-         h = self.conv1(h)
-         h = self.norm2(h)
-         h = F.silu(h)
-         h = self.conv2(h)
-         h = h + self.skip_connection(x)
-         return h
-
-
- class DownsampleBlock3d(nn.Module):
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         mode: Literal["conv", "avgpool"] = "conv",
-     ):
-         assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
-
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-
-         if mode == "conv":
-             self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
-         elif mode == "avgpool":
-             assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         if hasattr(self, "conv"):
-             return self.conv(x)
-         else:
-             return F.avg_pool3d(x, 2)
-
-
- class UpsampleBlock3d(nn.Module):
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         mode: Literal["conv", "nearest"] = "conv",
-     ):
-         assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
-
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-
-         if mode == "conv":
-             self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
-         elif mode == "nearest":
-             assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         if hasattr(self, "conv"):
-             x = self.conv(x)
-             return pixel_shuffle_3d(x, 2)
-         else:
-             return F.interpolate(x, scale_factor=2, mode="nearest")
-
-
- class SparseStructureEncoder(nn.Module):
-     """
-     Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
-
-     Args:
-         in_channels (int): Channels of the input.
-         latent_channels (int): Channels of the latent representation.
-         num_res_blocks (int): Number of residual blocks at each resolution.
-         channels (List[int]): Channels of the encoder blocks.
-         num_res_blocks_middle (int): Number of residual blocks in the middle.
-         norm_type (Literal["group", "layer"]): Type of normalization layer.
-         use_fp16 (bool): Whether to use FP16.
-     """
-     def __init__(
-         self,
-         in_channels: int,
-         latent_channels: int,
-         num_res_blocks: int,
-         channels: List[int],
-         num_res_blocks_middle: int = 2,
-         norm_type: Literal["group", "layer"] = "layer",
-         use_fp16: bool = False,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.latent_channels = latent_channels
-         self.num_res_blocks = num_res_blocks
-         self.channels = channels
-         self.num_res_blocks_middle = num_res_blocks_middle
-         self.norm_type = norm_type
-         self.use_fp16 = use_fp16
-         self.dtype = torch.float16 if use_fp16 else torch.float32
-
-         self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
-
-         self.blocks = nn.ModuleList([])
-         for i, ch in enumerate(channels):
-             self.blocks.extend([
-                 ResBlock3d(ch, ch)
-                 for _ in range(num_res_blocks)
-             ])
-             if i < len(channels) - 1:
-                 self.blocks.append(
-                     DownsampleBlock3d(ch, channels[i+1])
-                 )
-
-         self.middle_block = nn.Sequential(*[
-             ResBlock3d(channels[-1], channels[-1])
-             for _ in range(num_res_blocks_middle)
-         ])
-
-         self.out_layer = nn.Sequential(
-             norm_layer(norm_type, channels[-1]),
-             nn.SiLU(),
-             nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
-         )
-
-         if use_fp16:
-             self.convert_to_fp16()
-
-     @property
-     def device(self) -> torch.device:
-         """
-         Return the device of the model.
-         """
-         return next(self.parameters()).device
-
-     def convert_to_fp16(self) -> None:
-         """
-         Convert the torso of the model to float16.
-         """
-         self.use_fp16 = True
-         self.dtype = torch.float16
-         self.blocks.apply(convert_module_to_f16)
-         self.middle_block.apply(convert_module_to_f16)
-
-     def convert_to_fp32(self) -> None:
-         """
-         Convert the torso of the model to float32.
-         """
-         self.use_fp16 = False
-         self.dtype = torch.float32
-         self.blocks.apply(convert_module_to_f32)
-         self.middle_block.apply(convert_module_to_f32)
-
-     def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
-         h = self.input_layer(x)
-         h = h.type(self.dtype)
-
-         for block in self.blocks:
-             h = block(h)
-         h = self.middle_block(h)
-
-         h = h.type(x.dtype)
-         h = self.out_layer(h)
-
-         mean, logvar = h.chunk(2, dim=1)
-
-         if sample_posterior:
-             std = torch.exp(0.5 * logvar)
-             z = mean + std * torch.randn_like(std)
-         else:
-             z = mean
-
-         if return_raw:
-             return z, mean, logvar
-         return z
-
-
- class SparseStructureDecoder(nn.Module):
-     """
-     Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
-
-     Args:
-         out_channels (int): Channels of the output.
-         latent_channels (int): Channels of the latent representation.
-         num_res_blocks (int): Number of residual blocks at each resolution.
-         channels (List[int]): Channels of the decoder blocks.
-         num_res_blocks_middle (int): Number of residual blocks in the middle.
-         norm_type (Literal["group", "layer"]): Type of normalization layer.
-         use_fp16 (bool): Whether to use FP16.
-     """
-     def __init__(
-         self,
-         out_channels: int,
-         latent_channels: int,
-         num_res_blocks: int,
-         channels: List[int],
-         num_res_blocks_middle: int = 2,
-         norm_type: Literal["group", "layer"] = "layer",
-         use_fp16: bool = False,
-     ):
-         super().__init__()
-         self.out_channels = out_channels
-         self.latent_channels = latent_channels
-         self.num_res_blocks = num_res_blocks
-         self.channels = channels
-         self.num_res_blocks_middle = num_res_blocks_middle
-         self.norm_type = norm_type
-         self.use_fp16 = use_fp16
-         self.dtype = torch.float16 if use_fp16 else torch.float32
-
-         self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
-
-         self.middle_block = nn.Sequential(*[
-             ResBlock3d(channels[0], channels[0])
-             for _ in range(num_res_blocks_middle)
-         ])
-
-         self.blocks = nn.ModuleList([])
-         for i, ch in enumerate(channels):
-             self.blocks.extend([
-                 ResBlock3d(ch, ch)
-                 for _ in range(num_res_blocks)
-             ])
-             if i < len(channels) - 1:
-                 self.blocks.append(
-                     UpsampleBlock3d(ch, channels[i+1])
-                 )
-
-         self.out_layer = nn.Sequential(
-             norm_layer(norm_type, channels[-1]),
-             nn.SiLU(),
-             nn.Conv3d(channels[-1], out_channels, 3, padding=1)
-         )
-
-         if use_fp16:
-             self.convert_to_fp16()
-
-     @property
-     def device(self) -> torch.device:
-         """
-         Return the device of the model.
-         """
-         return next(self.parameters()).device
-
-     def convert_to_fp16(self) -> None:
-         """
-         Convert the torso of the model to float16.
-         """
-         self.use_fp16 = True
-         self.dtype = torch.float16
-         self.blocks.apply(convert_module_to_f16)
-         self.middle_block.apply(convert_module_to_f16)
-
-     def convert_to_fp32(self) -> None:
-         """
-         Convert the torso of the model to float32.
-         """
-         self.use_fp16 = False
-         self.dtype = torch.float32
-         self.blocks.apply(convert_module_to_f32)
-         self.middle_block.apply(convert_module_to_f32)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         h = self.input_layer(x)
-
-         h = h.type(self.dtype)
-
-         h = self.middle_block(h)
-         for block in self.blocks:
-             h = block(h)
-
-         h = h.type(x.dtype)
-         h = self.out_layer(h)
-         return h
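
The encoder's forward ends with the standard VAE reparameterization: out_layer predicts 2 * latent_channels maps, split into mean and log-variance, and sampling draws z = mean + sigma * eps. A minimal standalone illustration with assumed sizes (8 latent channels on a 4^3 grid):

    import torch

    h = torch.randn(1, 16, 4, 4, 4)   # stand-in for out_layer output (2 * 8 channels)
    mean, logvar = h.chunk(2, dim=1)  # split along the channel dimension
    z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
    print(z.shape)                    # torch.Size([1, 8, 4, 4, 4])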
 
trellis/models/structured_latent_flow.py DELETED
@@ -1,276 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
- from ..modules.transformer import AbsolutePositionEmbedder
- from ..modules.norm import LayerNorm32
- from ..modules import sparse as sp
- from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
- from .sparse_structure_flow import TimestepEmbedder
- from .sparse_elastic_mixin import SparseTransformerElasticMixin
-
-
- class SparseResBlock3d(nn.Module):
-     def __init__(
-         self,
-         channels: int,
-         emb_channels: int,
-         out_channels: Optional[int] = None,
-         downsample: bool = False,
-         upsample: bool = False,
-     ):
-         super().__init__()
-         self.channels = channels
-         self.emb_channels = emb_channels
-         self.out_channels = out_channels or channels
-         self.downsample = downsample
-         self.upsample = upsample
-
-         assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
-
-         self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
-         self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
-         self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
-         self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
-         self.emb_layers = nn.Sequential(
-             nn.SiLU(),
-             nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
-         )
-         self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
-         self.updown = None
-         if self.downsample:
-             self.updown = sp.SparseDownsample(2)
-         elif self.upsample:
-             self.updown = sp.SparseUpsample(2)
-
-     def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
-         if self.updown is not None:
-             x = self.updown(x)
-         return x
-
-     def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
-         emb_out = self.emb_layers(emb).type(x.dtype)
-         scale, shift = torch.chunk(emb_out, 2, dim=1)
-
-         x = self._updown(x)
-         h = x.replace(self.norm1(x.feats))
-         h = h.replace(F.silu(h.feats))
-         h = self.conv1(h)
-         h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
-         h = h.replace(F.silu(h.feats))
-         h = self.conv2(h)
-         h = h + self.skip_connection(x)
-
-         return h
-
-
- class SLatFlowModel(nn.Module):
-     def __init__(
-         self,
-         resolution: int,
-         in_channels: int,
-         model_channels: int,
-         cond_channels: int,
-         out_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4,
-         patch_size: int = 2,
-         num_io_res_blocks: int = 2,
-         io_block_channels: List[int] = None,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         use_skip_connection: bool = True,
-         share_mod: bool = False,
-         qk_rms_norm: bool = False,
-         qk_rms_norm_cross: bool = False,
-     ):
-         super().__init__()
-         self.resolution = resolution
-         self.in_channels = in_channels
-         self.model_channels = model_channels
-         self.cond_channels = cond_channels
-         self.out_channels = out_channels
-         self.num_blocks = num_blocks
-         self.num_heads = num_heads or model_channels // num_head_channels
-         self.mlp_ratio = mlp_ratio
-         self.patch_size = patch_size
-         self.num_io_res_blocks = num_io_res_blocks
-         self.io_block_channels = io_block_channels
-         self.pe_mode = pe_mode
-         self.use_fp16 = use_fp16
-         self.use_checkpoint = use_checkpoint
-         self.use_skip_connection = use_skip_connection
-         self.share_mod = share_mod
-         self.qk_rms_norm = qk_rms_norm
-         self.qk_rms_norm_cross = qk_rms_norm_cross
-         self.dtype = torch.float16 if use_fp16 else torch.float32
-
-         if self.io_block_channels is not None:
-             assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
-             assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
-
-         self.t_embedder = TimestepEmbedder(model_channels)
-         if share_mod:
-             self.adaLN_modulation = nn.Sequential(
-                 nn.SiLU(),
-                 nn.Linear(model_channels, 6 * model_channels, bias=True)
-             )
-
-         if pe_mode == "ape":
-             self.pos_embedder = AbsolutePositionEmbedder(model_channels)
-
-         self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
-
-         self.input_blocks = nn.ModuleList([])
-         if io_block_channels is not None:
-             for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
-                 self.input_blocks.extend([
-                     SparseResBlock3d(
-                         chs,
-                         model_channels,
-                         out_channels=chs,
-                     )
-                     for _ in range(num_io_res_blocks-1)
-                 ])
-                 self.input_blocks.append(
-                     SparseResBlock3d(
-                         chs,
-                         model_channels,
-                         out_channels=next_chs,
-                         downsample=True,
-                     )
-                 )
-
-         self.blocks = nn.ModuleList([
-             ModulatedSparseTransformerCrossBlock(
-                 model_channels,
-                 cond_channels,
-                 num_heads=self.num_heads,
-                 mlp_ratio=self.mlp_ratio,
-                 attn_mode='full',
-                 use_checkpoint=self.use_checkpoint,
-                 use_rope=(pe_mode == "rope"),
-                 share_mod=self.share_mod,
-                 qk_rms_norm=self.qk_rms_norm,
-                 qk_rms_norm_cross=self.qk_rms_norm_cross,
-             )
-             for _ in range(num_blocks)
-         ])
-
-         self.out_blocks = nn.ModuleList([])
-         if io_block_channels is not None:
-             for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
-                 self.out_blocks.append(
-                     SparseResBlock3d(
-                         prev_chs * 2 if self.use_skip_connection else prev_chs,
-                         model_channels,
-                         out_channels=chs,
-                         upsample=True,
-                     )
-                 )
-                 self.out_blocks.extend([
-                     SparseResBlock3d(
-                         chs * 2 if self.use_skip_connection else chs,
-                         model_channels,
-                         out_channels=chs,
-                     )
-                     for _ in range(num_io_res_blocks-1)
-                 ])
-
-         self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)
-
-         self.initialize_weights()
-         if use_fp16:
-             self.convert_to_fp16()
-
-     @property
-     def device(self) -> torch.device:
-         """
-         Return the device of the model.
-         """
-         return next(self.parameters()).device
-
-     def convert_to_fp16(self) -> None:
-         """
-         Convert the torso of the model to float16.
-         """
-         self.input_blocks.apply(convert_module_to_f16)
-         self.blocks.apply(convert_module_to_f16)
-         self.out_blocks.apply(convert_module_to_f16)
-
-     def convert_to_fp32(self) -> None:
-         """
-         Convert the torso of the model to float32.
-         """
-         self.input_blocks.apply(convert_module_to_f32)
-         self.blocks.apply(convert_module_to_f32)
-         self.out_blocks.apply(convert_module_to_f32)
-
-     def initialize_weights(self) -> None:
-         # Initialize transformer layers:
-         def _basic_init(module):
-             if isinstance(module, nn.Linear):
-                 torch.nn.init.xavier_uniform_(module.weight)
-                 if module.bias is not None:
-                     nn.init.constant_(module.bias, 0)
-         self.apply(_basic_init)
-
-         # Initialize timestep embedding MLP:
-         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-         # Zero-out adaLN modulation layers in DiT blocks:
-         if self.share_mod:
-             nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
-             nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
-         else:
-             for block in self.blocks:
-                 nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
-                 nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
-         # Zero-out output layers:
-         nn.init.constant_(self.out_layer.weight, 0)
-         nn.init.constant_(self.out_layer.bias, 0)
-
-     def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
-         h = self.input_layer(x).type(self.dtype)
-         t_emb = self.t_embedder(t)
-         if self.share_mod:
-             t_emb = self.adaLN_modulation(t_emb)
-         t_emb = t_emb.type(self.dtype)
-         cond = cond.type(self.dtype)
-
-         skips = []
-         # pack with input blocks
-         for block in self.input_blocks:
-             h = block(h, t_emb)
-             skips.append(h.feats)
-
-         if self.pe_mode == "ape":
-             h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
-         for block in self.blocks:
-             h = block(h, t_emb, cond)
-
-         # unpack with output blocks
-         for block, skip in zip(self.out_blocks, reversed(skips)):
-             if self.use_skip_connection:
-                 h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
-             else:
-                 h = block(h, t_emb)
-
-         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
-         h = self.out_layer(h.type(x.dtype))
-         return h
-
-
- class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
-     """
-     SLat Flow Model with elastic memory management.
-     Used for training with low VRAM.
-     """
-     pass
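
The asserts in SLatFlowModel.__init__ tie io_block_channels to patch_size: every 2x down/upsampling stage consumes one entry, so a patch size of 2**k needs exactly k channel widths. A quick check with assumed values, illustrative rather than taken from this commit:

    import numpy as np

    patch_size = 4
    io_block_channels = [128, 256]  # one channel width per 2x stage

    assert int(np.log2(patch_size)) == np.log2(patch_size)  # power of two
    assert np.log2(patch_size) == len(io_block_channels)    # one entry per halving
    # Pack path: input_layer -> 128 -> (down) 256 -> (down) model_channels;
    # the unpack path mirrors it, optionally concatenating skip features.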
 
trellis/models/structured_latent_vae/__init__.py DELETED
@@ -1,4 +0,0 @@
- from .encoder import SLatEncoder
- from .decoder_gs import SLatGaussianDecoder
- from .decoder_rf import SLatRadianceFieldDecoder
- from .decoder_mesh import SLatMeshDecoder
 
trellis/models/structured_latent_vae/base.py DELETED
@@ -1,117 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- from ...modules.utils import convert_module_to_f16, convert_module_to_f32
- from ...modules import sparse as sp
- from ...modules.transformer import AbsolutePositionEmbedder
- from ...modules.sparse.transformer import SparseTransformerBlock
-
-
- def block_attn_config(self):
-     """
-     Return the attention configuration of the model.
-     """
-     for i in range(self.num_blocks):
-         if self.attn_mode == "shift_window":
-             yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
-         elif self.attn_mode == "shift_sequence":
-             yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
-         elif self.attn_mode == "shift_order":
-             yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
-         elif self.attn_mode == "full":
-             yield "full", None, None, None, None
-         elif self.attn_mode == "swin":
-             yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
-
-
- class SparseTransformerBase(nn.Module):
-     """
-     Sparse Transformer without output layers.
-     Serves as the base class for encoders and decoders.
-     """
-     def __init__(
-         self,
-         in_channels: int,
-         model_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4.0,
-         attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
-         window_size: Optional[int] = None,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         qk_rms_norm: bool = False,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.model_channels = model_channels
-         self.num_blocks = num_blocks
-         self.window_size = window_size
-         self.num_heads = num_heads or model_channels // num_head_channels
-         self.mlp_ratio = mlp_ratio
-         self.attn_mode = attn_mode
-         self.pe_mode = pe_mode
-         self.use_fp16 = use_fp16
-         self.use_checkpoint = use_checkpoint
-         self.qk_rms_norm = qk_rms_norm
-         self.dtype = torch.float16 if use_fp16 else torch.float32
-
-         if pe_mode == "ape":
-             self.pos_embedder = AbsolutePositionEmbedder(model_channels)
-
-         self.input_layer = sp.SparseLinear(in_channels, model_channels)
-         self.blocks = nn.ModuleList([
-             SparseTransformerBlock(
-                 model_channels,
-                 num_heads=self.num_heads,
-                 mlp_ratio=self.mlp_ratio,
-                 attn_mode=attn_mode,
-                 window_size=window_size,
-                 shift_sequence=shift_sequence,
-                 shift_window=shift_window,
-                 serialize_mode=serialize_mode,
-                 use_checkpoint=self.use_checkpoint,
-                 use_rope=(pe_mode == "rope"),
-                 qk_rms_norm=self.qk_rms_norm,
-             )
-             for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
-         ])
-
-     @property
-     def device(self) -> torch.device:
-         """
-         Return the device of the model.
-         """
-         return next(self.parameters()).device
-
-     def convert_to_fp16(self) -> None:
-         """
-         Convert the torso of the model to float16.
-         """
-         self.blocks.apply(convert_module_to_f16)
-
-     def convert_to_fp32(self) -> None:
-         """
-         Convert the torso of the model to float32.
-         """
-         self.blocks.apply(convert_module_to_f32)
-
-     def initialize_weights(self) -> None:
-         # Initialize transformer layers:
-         def _basic_init(module):
-             if isinstance(module, nn.Linear):
-                 torch.nn.init.xavier_uniform_(module.weight)
-                 if module.bias is not None:
-                     nn.init.constant_(module.bias, 0)
-         self.apply(_basic_init)
-
-     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
-         h = self.input_layer(x)
-         if self.pe_mode == "ape":
-             h = h + self.pos_embedder(x.coords[:, 1:])
-         h = h.type(self.dtype)
-         for block in self.blocks:
-             h = block(h)
-         return h
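
block_attn_config varies the attention settings per block index; in "swin" mode, for example, the window shift alternates between 0 and window_size // 2 so neighboring windows exchange information across consecutive blocks. A small stand-in makes the cycle visible; SimpleNamespace mimics only the attributes the generator reads, and swin_config isolates just that one branch for illustration:

    from types import SimpleNamespace

    cfg = SimpleNamespace(num_blocks=4, attn_mode="swin", window_size=8)

    def swin_config(self):
        # The "swin" branch of block_attn_config, in isolation.
        for i in range(self.num_blocks):
            yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None

    for mode, win, _, shift, _ in swin_config(cfg):
        print(mode, win, shift)  # shift alternates 0, 4, 0, 4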
 
trellis/models/structured_latent_vae/decoder_gs.py DELETED
@@ -1,122 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ...modules import sparse as sp
- from ...utils.random_utils import hammersley_sequence
- from .base import SparseTransformerBase
- from ...representations import Gaussian
-
-
- class SLatGaussianDecoder(SparseTransformerBase):
-     def __init__(
-         self,
-         resolution: int,
-         model_channels: int,
-         latent_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4,
-         attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
-         window_size: int = 8,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         qk_rms_norm: bool = False,
-         representation_config: dict = None,
-     ):
-         super().__init__(
-             in_channels=latent_channels,
-             model_channels=model_channels,
-             num_blocks=num_blocks,
-             num_heads=num_heads,
-             num_head_channels=num_head_channels,
-             mlp_ratio=mlp_ratio,
-             attn_mode=attn_mode,
-             window_size=window_size,
-             pe_mode=pe_mode,
-             use_fp16=use_fp16,
-             use_checkpoint=use_checkpoint,
-             qk_rms_norm=qk_rms_norm,
-         )
-         self.resolution = resolution
-         self.rep_config = representation_config
-         self._calc_layout()
-         self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
-         self._build_perturbation()
-
-         self.initialize_weights()
-         if use_fp16:
-             self.convert_to_fp16()
-
-     def initialize_weights(self) -> None:
-         super().initialize_weights()
-         # Zero-out output layers:
-         nn.init.constant_(self.out_layer.weight, 0)
-         nn.init.constant_(self.out_layer.bias, 0)
-
-     def _build_perturbation(self) -> None:
-         perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
-         perturbation = torch.tensor(perturbation).float() * 2 - 1
-         perturbation = perturbation / self.rep_config['voxel_size']
-         perturbation = torch.atanh(perturbation).to(self.device)
-         self.register_buffer('offset_perturbation', perturbation)
-
-     def _calc_layout(self) -> None:
-         self.layout = {
-             '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
-             '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
-             '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
-             '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
-             '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
-         }
-         start = 0
-         for k, v in self.layout.items():
-             v['range'] = (start, start + v['size'])
-             start += v['size']
-         self.out_channels = start
-
-     def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
-         """
-         Convert a batch of network outputs to 3D representations.
-
-         Args:
-             x: The [N x * x C] sparse tensor output by the network.
-
-         Returns:
-             list of representations
-         """
-         ret = []
-         for i in range(x.shape[0]):
-             representation = Gaussian(
-                 sh_degree=0,
-                 aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
-                 mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
-                 scaling_bias = self.rep_config['scaling_bias'],
-                 opacity_bias = self.rep_config['opacity_bias'],
-                 scaling_activation = self.rep_config['scaling_activation']
-             )
-             xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
-             for k, v in self.layout.items():
-                 if k == '_xyz':
-                     offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
-                     offset = offset * self.rep_config['lr'][k]
-                     if self.rep_config['perturb_offset']:
-                         offset = offset + self.offset_perturbation
-                     offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
-                     _xyz = xyz.unsqueeze(1) + offset
-                     setattr(representation, k, _xyz.flatten(0, 1))
-                 else:
-                     feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
-                     feats = feats * self.rep_config['lr'][k]
-                     setattr(representation, k, feats)
-             ret.append(representation)
-         return ret
-
-     def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
-         h = super().forward(x)
-         h = h.type(x.dtype)
-         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
-         h = self.out_layer(h)
-         return self.to_representation(h)
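
_calc_layout packs every per-voxel Gaussian attribute into a single flat feature vector, and to_representation slices it back out by the recorded ranges. With num_gaussians = 32 (an assumed value, not from this commit), the offsets work out as follows:

    num_gaussians = 32
    sizes = {
        '_xyz':         num_gaussians * 3,
        '_features_dc': num_gaussians * 3,
        '_scaling':     num_gaussians * 3,
        '_rotation':    num_gaussians * 4,
        '_opacity':     num_gaussians * 1,
    }
    start = 0
    for name, size in sizes.items():
        print(f"{name}: columns {start}..{start + size}")
        start += size
    print("out_channels =", start)  # 32 * (3 + 3 + 3 + 4 + 1) = 448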
 
trellis/models/structured_latent_vae/decoder_mesh.py DELETED
@@ -1,167 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
- from ...modules import sparse as sp
- from .base import SparseTransformerBase
- from ...representations import MeshExtractResult
- from ...representations.mesh import SparseFeatures2Mesh
-
-
- class SparseSubdivideBlock3d(nn.Module):
-     """
-     A 3D subdivide block that can subdivide the sparse tensor.
-
-     Args:
-         channels: channels in the inputs and outputs.
-         out_channels: if specified, the number of output channels.
-         num_groups: the number of groups for the group norm.
-     """
-     def __init__(
-         self,
-         channels: int,
-         resolution: int,
-         out_channels: Optional[int] = None,
-         num_groups: int = 32
-     ):
-         super().__init__()
-         self.channels = channels
-         self.resolution = resolution
-         self.out_resolution = resolution * 2
-         self.out_channels = out_channels or channels
-
-         self.act_layers = nn.Sequential(
-             sp.SparseGroupNorm32(num_groups, channels),
-             sp.SparseSiLU()
-         )
-
-         self.sub = sp.SparseSubdivide()
-
-         self.out_layers = nn.Sequential(
-             sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
-             sp.SparseGroupNorm32(num_groups, self.out_channels),
-             sp.SparseSiLU(),
-             zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
-         )
-
-         if self.out_channels == channels:
-             self.skip_connection = nn.Identity()
-         else:
-             self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
-
-     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
-         """
-         Apply the block to a sparse tensor.
-
-         Args:
-             x: an [N x C x ...] Tensor of features.
-         Returns:
-             an [N x C x ...] Tensor of outputs.
-         """
-         h = self.act_layers(x)
-         h = self.sub(h)
-         x = self.sub(x)
-         h = self.out_layers(h)
-         h = h + self.skip_connection(x)
-         return h
-
-
- class SLatMeshDecoder(SparseTransformerBase):
-     def __init__(
-         self,
-         resolution: int,
-         model_channels: int,
-         latent_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4,
-         attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
-         window_size: int = 8,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         qk_rms_norm: bool = False,
-         representation_config: dict = None,
-     ):
-         super().__init__(
-             in_channels=latent_channels,
-             model_channels=model_channels,
-             num_blocks=num_blocks,
-             num_heads=num_heads,
-             num_head_channels=num_head_channels,
-             mlp_ratio=mlp_ratio,
-             attn_mode=attn_mode,
-             window_size=window_size,
-             pe_mode=pe_mode,
-             use_fp16=use_fp16,
-             use_checkpoint=use_checkpoint,
-             qk_rms_norm=qk_rms_norm,
-         )
-         self.resolution = resolution
-         self.rep_config = representation_config
-         self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
-         self.out_channels = self.mesh_extractor.feats_channels
-         self.upsample = nn.ModuleList([
-             SparseSubdivideBlock3d(
-                 channels=model_channels,
-                 resolution=resolution,
-                 out_channels=model_channels // 4
-             ),
-             SparseSubdivideBlock3d(
-                 channels=model_channels // 4,
-                 resolution=resolution * 2,
-                 out_channels=model_channels // 8
-             )
-         ])
-         self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
-
-         self.initialize_weights()
-         if use_fp16:
-             self.convert_to_fp16()
-
-     def initialize_weights(self) -> None:
-         super().initialize_weights()
-         # Zero-out output layers:
-         nn.init.constant_(self.out_layer.weight, 0)
-         nn.init.constant_(self.out_layer.bias, 0)
-
-     def convert_to_fp16(self) -> None:
-         """
-         Convert the torso of the model to float16.
-         """
-         super().convert_to_fp16()
-         self.upsample.apply(convert_module_to_f16)
-
-     def convert_to_fp32(self) -> None:
-         """
-         Convert the torso of the model to float32.
-         """
-         super().convert_to_fp32()
-         self.upsample.apply(convert_module_to_f32)
-
-     def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
-         """
-         Convert a batch of network outputs to 3D representations.
-
-         Args:
-             x: The [N x * x C] sparse tensor output by the network.
-
-         Returns:
-             list of representations
-         """
-         ret = []
-         for i in range(x.shape[0]):
-             mesh = self.mesh_extractor(x[i], training=self.training)
-             ret.append(mesh)
-         return ret
-
-     def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
-         h = super().forward(x)
-         for block in self.upsample:
-             h = block(h)
-         h = h.type(x.dtype)
-         h = self.out_layer(h)
-         return self.to_representation(h)
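
Each SparseSubdivideBlock3d doubles the sparse grid resolution while shrinking the channel width, which is why SLatMeshDecoder builds its mesh extractor at res = resolution * 4. Tracing the two stages with assumed sizes, illustrative rather than from this commit:

    resolution, model_channels = 64, 768

    stages = [
        (resolution,     model_channels,      model_channels // 4),  # 64 -> 128
        (resolution * 2, model_channels // 4, model_channels // 8),  # 128 -> 256
    ]
    for res_in, c_in, c_out in stages:
        print(f"subdivide: res {res_in} -> {res_in * 2}, channels {c_in} -> {c_out}")
    print("mesh extractor grid:", resolution * 4)  # matches the final 256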
 
trellis/models/structured_latent_vae/decoder_rf.py DELETED
@@ -1,104 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- from ...modules import sparse as sp
- from .base import SparseTransformerBase
- from ...representations import Strivec
-
-
- class SLatRadianceFieldDecoder(SparseTransformerBase):
-     def __init__(
-         self,
-         resolution: int,
-         model_channels: int,
-         latent_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4,
-         attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
-         window_size: int = 8,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         qk_rms_norm: bool = False,
-         representation_config: dict = None,
-     ):
-         super().__init__(
-             in_channels=latent_channels,
-             model_channels=model_channels,
-             num_blocks=num_blocks,
-             num_heads=num_heads,
-             num_head_channels=num_head_channels,
-             mlp_ratio=mlp_ratio,
-             attn_mode=attn_mode,
-             window_size=window_size,
-             pe_mode=pe_mode,
-             use_fp16=use_fp16,
-             use_checkpoint=use_checkpoint,
-             qk_rms_norm=qk_rms_norm,
-         )
-         self.resolution = resolution
-         self.rep_config = representation_config
-         self._calc_layout()
-         self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
-
-         self.initialize_weights()
-         if use_fp16:
-             self.convert_to_fp16()
-
-     def initialize_weights(self) -> None:
-         super().initialize_weights()
-         # Zero-out output layers:
-         nn.init.constant_(self.out_layer.weight, 0)
-         nn.init.constant_(self.out_layer.bias, 0)
-
-     def _calc_layout(self) -> None:
-         self.layout = {
-             'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
-             'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
-             'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
-         }
-         start = 0
-         for k, v in self.layout.items():
-             v['range'] = (start, start + v['size'])
-             start += v['size']
-         self.out_channels = start
-
-     def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
-         """
-         Convert a batch of network outputs to 3D representations.
-
-         Args:
-             x: The [N x * x C] sparse tensor output by the network.
-
-         Returns:
-             list of representations
-         """
-         ret = []
-         for i in range(x.shape[0]):
-             representation = Strivec(
-                 sh_degree=0,
-                 resolution=self.resolution,
-                 aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
-                 rank=self.rep_config['rank'],
-                 dim=self.rep_config['dim'],
-                 device='cuda',
-             )
-             representation.density_shift = 0.0
-             representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
-             representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
-             for k, v in self.layout.items():
-                 setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
-             representation.trivec = representation.trivec + 1
-             ret.append(representation)
-         return ret
-
-     def forward(self, x: sp.SparseTensor) -> List[Strivec]:
-         h = super().forward(x)
-         h = h.type(x.dtype)
-         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
-         h = self.out_layer(h)
-         return self.to_representation(h)
 
trellis/models/structured_latent_vae/encoder.py DELETED
@@ -1,72 +0,0 @@
- from typing import *
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ...modules import sparse as sp
- from .base import SparseTransformerBase
-
-
- class SLatEncoder(SparseTransformerBase):
-     def __init__(
-         self,
-         resolution: int,
-         in_channels: int,
-         model_channels: int,
-         latent_channels: int,
-         num_blocks: int,
-         num_heads: Optional[int] = None,
-         num_head_channels: Optional[int] = 64,
-         mlp_ratio: float = 4,
-         attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
-         window_size: int = 8,
-         pe_mode: Literal["ape", "rope"] = "ape",
-         use_fp16: bool = False,
-         use_checkpoint: bool = False,
-         qk_rms_norm: bool = False,
-     ):
-         super().__init__(
-             in_channels=in_channels,
-             model_channels=model_channels,
-             num_blocks=num_blocks,
-             num_heads=num_heads,
-             num_head_channels=num_head_channels,
-             mlp_ratio=mlp_ratio,
-             attn_mode=attn_mode,
-             window_size=window_size,
-             pe_mode=pe_mode,
-             use_fp16=use_fp16,
-             use_checkpoint=use_checkpoint,
-             qk_rms_norm=qk_rms_norm,
-         )
-         self.resolution = resolution
-         self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
-
-         self.initialize_weights()
-         if use_fp16:
-             self.convert_to_fp16()
-
-     def initialize_weights(self) -> None:
-         super().initialize_weights()
-         # Zero-out output layers:
-         nn.init.constant_(self.out_layer.weight, 0)
-         nn.init.constant_(self.out_layer.bias, 0)
-
-     def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
-         h = super().forward(x)
-         h = h.type(x.dtype)
-         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
-         h = self.out_layer(h)
-
-         # Sample from the posterior distribution
-         mean, logvar = h.feats.chunk(2, dim=-1)
-         if sample_posterior:
-             std = torch.exp(0.5 * logvar)
-             z = mean + std * torch.randn_like(std)
-         else:
-             z = mean
-         z = h.replace(z)
-
-         if return_raw:
-             return z, mean, logvar
-         else:
-             return z