| 80 | |
| 81 | |
| 82 | class ConsistencyDecoder: |
def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
    """Fetch the decoder checkpoint and precompute per-timestep schedule tensors.

    Args:
        device: Device that the loaded TorchScript module and every schedule
            tensor are moved to.
        download_root: Directory handed to ``_download`` as the destination
            for the checkpoint. The default path is expanded once, when this
            function is defined, not on every call.
    """
    checkpoint_url = "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt"
    total_steps = 1024
    sigma_data = 0.5

    def cosine_alpha_bar(t):
        # Squared-cosine curve with a 0.008 offset, passed to betas_for_alpha_bar.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    self.n_distilled_steps = 64

    # Resolve the checkpoint file first, then load it as a TorchScript module.
    checkpoint_path = _download(checkpoint_url, download_root)
    self.ckpt = torch.jit.load(checkpoint_path).to(device)
    self.device = device

    betas = betas_for_alpha_bar(total_steps, cosine_alpha_bar).to(device)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    inv_alphas_cumprod = 1.0 / alphas_cumprod

    self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    sqrt_recip_alphas_cumprod = torch.sqrt(inv_alphas_cumprod)
    sigmas = torch.sqrt(inv_alphas_cumprod - 1)

    # Both c_out and c_in divide by the same sqrt(sigma^2 + sigma_data^2) term.
    total_variance = sigmas**2 + sigma_data**2
    total_std = total_variance ** 0.5

    self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / total_variance
    self.c_out = sigmas * sigma_data / total_std
    self.c_in = sqrt_recip_alphas_cumprod / total_std
| 102 | |
@staticmethod
def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
    """Snap each timestep up to the next multiple of the distilled step size.

    The step size is ``total_timesteps // n_distilled_steps``. Every timestep
    is moved to the next strictly greater multiple of that size; values that
    land exactly on ``total_timesteps`` are pulled back by one step.

    Args:
        timesteps: Tensor of timesteps to round. It is not modified.
        total_timesteps: Length of the full schedule.
        n_distilled_steps: Number of steps the full schedule is divided into.
        truncate_start: When false, results equal to 0 are additionally
            pushed up by one step.

    Returns:
        A new tensor holding the rounded timesteps.
    """
    with torch.no_grad():
        stride = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
        bucket_index = torch.div(timesteps, stride, rounding_mode="floor")
        snapped = (bucket_index + 1) * stride

        # The end-of-schedule clamp applies whatever truncate_start is.
        hits_end = snapped == total_timesteps
        snapped[hits_end] -= stride

        if not truncate_start:
            snapped[snapped == 0] += stride
        return snapped
| 114 | |
@staticmethod
def ldm_transform_latent(z, extra_scale_factor=1):
    """Scale a latent tensor and standardise each channel along dim 1.

    The input is multiplied by 0.18215, then each channel ``i`` is shifted by
    a fixed mean and divided by a fixed standard deviation, and finally
    multiplied by ``extra_scale_factor``.

    Args:
        z: 4-D tensor with channels on dim 1. Statistics are defined for four
            channels only.
        extra_scale_factor: Multiplier applied to every standardised channel.

    Returns:
        A new tensor with the same shape as ``z``.

    Raises:
        ValueError: If ``z`` is not 4-D.
        IndexError: If ``z`` has more than four channels, because no
            statistics exist for the extra channels.
    """
    channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
    channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]

    ndim = len(z.shape)
    if ndim != 4:
        raise ValueError(
            f"expected a 4-D tensor with channels on dim 1, "
            f"got {ndim}-D input with shape {tuple(z.shape)}"
        )

    z = z * 0.18215

    # Index the statistics by position so an unsupported extra channel fails
    # loudly instead of being silently dropped.
    channels = [
        extra_scale_factor * (channel - channel_means[i]) / channel_stds[i]
        for i, channel in enumerate(z.unbind(dim=1))
    ]
    return torch.stack(channels, dim=1)
| 128 | |
| 129 | @torch.no_grad() |
| 130 | def __call__( |
| 131 | self, |
| 132 | features: torch.Tensor, |
| 133 | schedule=[1.0, 0.5], |
| 134 | generator=None, |
| 135 | ): |
| 136 | features = self.ldm_transform_latent(features) |
| 137 | ts = self.round_timesteps( |
| 138 | torch.arange(0, 1024), |
| 139 | 1024, |
no outgoing calls
no test coverage detected
searching dependent graphs…