| 1259 | |
| 1260 | |
| 1261 | class TimestepEmbedding(nn.Module): |
| 1262 | def __init__( |
| 1263 | self, |
| 1264 | in_channels: int, |
| 1265 | time_embed_dim: int, |
| 1266 | act_fn: str = "silu", |
| 1267 | out_dim: int = None, |
| 1268 | post_act_fn: str | None = None, |
| 1269 | cond_proj_dim=None, |
| 1270 | sample_proj_bias=True, |
| 1271 | ): |
| 1272 | super().__init__() |
| 1273 | |
| 1274 | self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) |
| 1275 | |
| 1276 | if cond_proj_dim is not None: |
| 1277 | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) |
| 1278 | else: |
| 1279 | self.cond_proj = None |
| 1280 | |
| 1281 | self.act = get_activation(act_fn) |
| 1282 | |
| 1283 | if out_dim is not None: |
| 1284 | time_embed_dim_out = out_dim |
| 1285 | else: |
| 1286 | time_embed_dim_out = time_embed_dim |
| 1287 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) |
| 1288 | |
| 1289 | if post_act_fn is None: |
| 1290 | self.post_act = None |
| 1291 | else: |
| 1292 | self.post_act = get_activation(post_act_fn) |
| 1293 | |
| 1294 | def forward(self, sample, condition=None): |
| 1295 | if condition is not None: |
| 1296 | sample = sample + self.cond_proj(condition) |
| 1297 | sample = self.linear_1(sample) |
| 1298 | |
| 1299 | if self.act is not None: |
| 1300 | sample = self.act(sample) |
| 1301 | |
| 1302 | sample = self.linear_2(sample) |
| 1303 | |
| 1304 | if self.post_act is not None: |
| 1305 | sample = self.post_act(sample) |
| 1306 | return sample |
| 1307 | |
| 1308 | |
| 1309 | class Timesteps(nn.Module): |
no outgoing calls
no test coverage detected
searching dependent graphs…