Repository: github.com/huggingface/diffusers

Function compute_snr

src/diffusers/training_utils.py:76–110

Computes the signal-to-noise ratio (SNR) for the given timesteps using the provided noise scheduler, following https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849.
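In formula form, writing $\bar{\alpha}_t$ for `alphas_cumprod[t]`: the forward process mixes data and noise as $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$, and the function squares the ratio of the two coefficients (the `alpha` and `sigma` tensors in the code):

$$
\mathrm{SNR}(t) = \left(\frac{\sqrt{\bar{\alpha}_t}}{\sqrt{1-\bar{\alpha}_t}}\right)^2 = \frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}
$$

Because $\bar{\alpha}_t$ decreases as $t$ grows, the SNR is largest at early timesteps and falls toward 0 at the end of the schedule.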

compute_snr(noise_scheduler, timesteps)


```python
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    for the given timesteps using the provided noise scheduler.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
            the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr
```
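To see the values the function produces without PyTorch or diffusers installed, the sketch below repeats its arithmetic in plain Python for a list of integer timesteps. It assumes a linear beta schedule from 1e-4 to 0.02 over 1000 steps, chosen only for illustration; the helpers `linear_alphas_cumprod` and `snr_at` are hypothetical and skip the tensor indexing and broadcasting that `compute_snr` performs.

```python
import math


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: cumulative product of (1 - beta_t) for an assumed linear beta schedule."""
    alphas_cumprod = []
    running = 1.0
    for i in range(num_steps):
        # Evenly spaced betas from beta_start to beta_end (inclusive).
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def snr_at(alphas_cumprod, timesteps):
    """Hypothetical helper: same arithmetic as compute_snr, one scalar per timestep."""
    snrs = []
    for t in timesteps:
        alpha = math.sqrt(alphas_cumprod[t])        # sqrt(alpha_bar_t)
        sigma = math.sqrt(1.0 - alphas_cumprod[t])  # sqrt(1 - alpha_bar_t)
        snrs.append((alpha / sigma) ** 2)
    return snrs


alphas_cumprod = linear_alphas_cumprod()
timesteps = [0, 250, 500, 999]
snrs = snr_at(alphas_cumprod, timesteps)
for t, s in zip(timesteps, snrs):
    print(f"t={t:4d}  SNR={s:.4g}")
```

The printed values fall steeply with `t`, matching the closed form alphas_cumprod[t] / (1 - alphas_cumprod[t]).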
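The docstring points to the Min-SNR-Diffusion-Training repository, whose Min-SNR-γ strategy caps each timestep's loss weight at γ. A minimal sketch of turning SNR values into such weights in plain Python follows; `min_snr_weights`, the default `gamma=5.0`, the `prediction_type` strings and the sample numbers are assumptions for illustration, not part of `compute_snr` or a diffusers API.

```python
def min_snr_weights(snrs, gamma=5.0, prediction_type="epsilon"):
    """Hypothetical helper: Min-SNR-gamma loss weight for each timestep's SNR."""
    weights = []
    for snr in snrs:
        clipped = min(snr, gamma)  # cap the SNR at gamma
        if prediction_type == "epsilon":
            # The epsilon MSE equals SNR times the x0 MSE, so dividing by SNR
            # leaves an x0 loss weighted by min(SNR, gamma).
            weights.append(clipped / snr)
        elif prediction_type == "v_prediction":
            weights.append(clipped / (snr + 1.0))
        else:
            raise ValueError(f"unsupported prediction_type: {prediction_type!r}")
    return weights


# Illustrative SNR values and made-up per-sample MSE losses.
snrs = [100.0, 10.0, 5.0, 1.0, 0.1]
per_sample_mse = [0.02, 0.05, 0.08, 0.10, 0.30]

weights = min_snr_weights(snrs)
loss = sum(w * l for w, l in zip(weights, per_sample_mse)) / len(per_sample_mse)
print(weights)
print(loss)
```

High-SNR (early, nearly noise-free) timesteps are down-weighted, while timesteps with SNR at or below γ keep weight 1 under epsilon prediction. With PyTorch tensors, the same cap can be applied elementwise to the output of `compute_snr`, for example with `torch.clamp(snr, max=gamma)`.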

Callers 15

main (function) · 0.90, listed 12 times

Calls 2

float (method) · 0.80
to (method) · 0.45

Tested by

no test coverage detected
