MCPcopy Index your code
hub / github.com/huggingface/diffusers / ConsistencyDecoder

Class ConsistencyDecoder

scripts/convert_consistency_decoder.py:82–179  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

80
81
82class ConsistencyDecoder:
83 def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
84 self.n_distilled_steps = 64
85 download_target = _download(
86 "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
87 download_root,
88 )
89 self.ckpt = torch.jit.load(download_target).to(device)
90 self.device = device
91 sigma_data = 0.5
92 betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
93 alphas = 1.0 - betas
94 alphas_cumprod = torch.cumprod(alphas, dim=0)
95 self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
96 self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
97 sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
98 sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
99 self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
100 self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
101 self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
102
103 @staticmethod
104 def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
105 with torch.no_grad():
106 space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
107 rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
108 if truncate_start:
109 rounded_timesteps[rounded_timesteps == total_timesteps] -= space
110 else:
111 rounded_timesteps[rounded_timesteps == total_timesteps] -= space
112 rounded_timesteps[rounded_timesteps == 0] += space
113 return rounded_timesteps
114
115 @staticmethod
116 def ldm_transform_latent(z, extra_scale_factor=1):
117 channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
118 channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
119
120 if len(z.shape) != 4:
121 raise ValueError()
122
123 z = z * 0.18215
124 channels = [z[:, i] for i in range(z.shape[1])]
125
126 channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
127 return torch.stack(channels, dim=1)
128
129 @torch.no_grad()
130 def __call__(
131 self,
132 features: torch.Tensor,
133 schedule=[1.0, 0.5],
134 generator=None,
135 ):
136 features = self.ldm_transform_latent(features)
137 ts = self.round_timesteps(
138 torch.arange(0, 1024),
139 1024,

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…