| 2989 | |
| 2990 | |
| 2991 | class ResnetUpsampleBlock2D(nn.Module): |
| 2992 | def __init__( |
| 2993 | self, |
| 2994 | in_channels: int, |
| 2995 | prev_output_channel: int, |
| 2996 | out_channels: int, |
| 2997 | temb_channels: int, |
| 2998 | resolution_idx: int | None = None, |
| 2999 | dropout: float = 0.0, |
| 3000 | num_layers: int = 1, |
| 3001 | resnet_eps: float = 1e-6, |
| 3002 | resnet_time_scale_shift: str = "default", |
| 3003 | resnet_act_fn: str = "swish", |
| 3004 | resnet_groups: int = 32, |
| 3005 | resnet_pre_norm: bool = True, |
| 3006 | output_scale_factor: float = 1.0, |
| 3007 | add_upsample: bool = True, |
| 3008 | skip_time_act: bool = False, |
| 3009 | ): |
| 3010 | super().__init__() |
| 3011 | resnets = [] |
| 3012 | |
| 3013 | for i in range(num_layers): |
| 3014 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels |
| 3015 | resnet_in_channels = prev_output_channel if i == 0 else out_channels |
| 3016 | |
| 3017 | resnets.append( |
| 3018 | ResnetBlock2D( |
| 3019 | in_channels=resnet_in_channels + res_skip_channels, |
| 3020 | out_channels=out_channels, |
| 3021 | temb_channels=temb_channels, |
| 3022 | eps=resnet_eps, |
| 3023 | groups=resnet_groups, |
| 3024 | dropout=dropout, |
| 3025 | time_embedding_norm=resnet_time_scale_shift, |
| 3026 | non_linearity=resnet_act_fn, |
| 3027 | output_scale_factor=output_scale_factor, |
| 3028 | pre_norm=resnet_pre_norm, |
| 3029 | skip_time_act=skip_time_act, |
| 3030 | ) |
| 3031 | ) |
| 3032 | |
| 3033 | self.resnets = nn.ModuleList(resnets) |
| 3034 | |
| 3035 | if add_upsample: |
| 3036 | self.upsamplers = nn.ModuleList( |
| 3037 | [ |
| 3038 | ResnetBlock2D( |
| 3039 | in_channels=out_channels, |
| 3040 | out_channels=out_channels, |
| 3041 | temb_channels=temb_channels, |
| 3042 | eps=resnet_eps, |
| 3043 | groups=resnet_groups, |
| 3044 | dropout=dropout, |
| 3045 | time_embedding_norm=resnet_time_scale_shift, |
| 3046 | non_linearity=resnet_act_fn, |
| 3047 | output_scale_factor=output_scale_factor, |
| 3048 | pre_norm=resnet_pre_norm, |
no outgoing calls
no test coverage detected
searching dependent graphs…