hub / github.com/huggingface/diffusers / compute_confidence_aware_loss

Function compute_confidence_aware_loss

src/diffusers/training_utils.py:113–196

Computes a confidence-aware training loss for token classification-style heads.

This loss combines:
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
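The source excerpt on this page stops before the two terms are combined, so the following is only a minimal sketch of how such a loss could be assembled, not the diffusers implementation. The name `confidence_aware_loss_sketch`, the mean-over-tokens normalization, and the use of `argmax` to decide which tokens count as already predicted correctly are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def confidence_aware_loss_sketch(logits, labels, lambda_conf=0.0, temperature=1.0, ignore_index=-100):
    """Hypothetical re-creation of the idea; normalization choices are assumptions."""
    valid = labels.ne(ignore_index)
    vocab_size = logits.shape[-1]

    # Supervised term: per-token cross-entropy, averaged over non-ignored tokens.
    per_token_nll = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        reduction="none",
        ignore_index=ignore_index,
    ).reshape_as(labels)
    loss_sft = (per_token_nll * valid).sum() / valid.sum().clamp_min(1)

    # Confidence term: entropy of the tempered distribution, counted only where
    # the argmax prediction already matches the label (assumed criterion).
    log_probs = F.log_softmax(logits.float() / temperature, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    correct = valid & logits.argmax(dim=-1).eq(labels)
    loss_conf = (entropy * correct).sum() / correct.sum().clamp_min(1)

    return loss_sft + lambda_conf * loss_conf, loss_sft, loss_conf


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)         # (batch, seq, vocab)
labels = torch.randint(0, 11, (2, 5))
labels[0, 0] = -100                    # ignored position
loss, loss_sft, loss_conf = confidence_aware_loss_sketch(logits, labels, lambda_conf=0.1)
```

With `lambda_conf=0.0` the sketch reduces to plain cross-entropy, which matches the default in the signature.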

(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    lambda_conf: float = 0.0,
    temperature: float = 1.0,
    per_token_weights: torch.Tensor | None = None,
    ignore_index: int = -100,
)
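The docstring states that `temperature` is used for the entropy term only and that lower values sharpen the distribution. A small sketch of that effect follows; `tempered_entropy` is a hypothetical helper written for this example and is not part of diffusers.

```python
import torch
import torch.nn.functional as F


def tempered_entropy(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Shannon entropy (in nats) of softmax(logits / temperature) over the last dim."""
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


logits = torch.tensor([2.0, 1.0, 0.0, -1.0])
entropies = {t: tempered_entropy(logits, t).item() for t in (0.5, 1.0, 2.0)}
# Lower temperature -> sharper distribution -> lower entropy.
```

Because the entropy shrinks as the temperature drops, the same logits produce a different penalty value and different gradients at different temperatures.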

Source

def compute_confidence_aware_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    lambda_conf: float = 0.0,
    temperature: float = 1.0,
    per_token_weights: torch.Tensor | None = None,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes a confidence-aware training loss for token classification-style heads.

    This loss combines:
    - `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
    - `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.

    Args:
        logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
        labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
            are excluded from both losses.
        lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
        temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
            sharpen the distribution and change the strength of the confidence gradients.
        per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
            token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
        ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
    """
    if logits.ndim < 2:
        raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
    if labels.shape != logits.shape[:-1]:
        raise ValueError(
            f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
        )
    if temperature <= 0:
        raise ValueError(f"`temperature` must be > 0, got {temperature}.")

    valid = labels.ne(ignore_index)
    if per_token_weights is None:
        weights = torch.ones_like(labels, dtype=logits.dtype)
    else:
        if per_token_weights.shape != labels.shape:
            raise ValueError(
                f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
            )
        weights = per_token_weights.to(dtype=logits.dtype)

    # Supervised CE (optionally weighted).
    vocab_size = logits.shape[-1]
    per_token_nll = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        reduction="none",
        ignore_index=ignore_index,
    ).reshape_as(labels)
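The last statement shown computes an unreduced, per-token negative log-likelihood. As a small illustration of what that call returns (the tensors below are made up for the example), positions whose label equals `ignore_index` come back as exactly zero, and the flat result can be reshaped to the label shape:

```python
import torch
import torch.nn.functional as F

logits = torch.zeros(1, 3, 4)            # uniform logits over a 4-way vocab
labels = torch.tensor([[2, -100, 0]])    # middle token is ignored

per_token_nll = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),
    labels.reshape(-1),
    reduction="none",
    ignore_index=-100,
).reshape_as(labels)

# Uniform logits give NLL = log(4) at scored positions; the ignored position is 0.
valid = labels.ne(-100)
# Default weights when no per-token weights are passed: ones in the logits dtype.
weights = torch.ones_like(labels, dtype=logits.dtype)
```

Keeping the loss unreduced is what allows a later step to apply `weights` and the `valid` mask token by token before averaging.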

Callers 2

main · Function · 0.90

Calls 2

float · Method · 0.80
to · Method · 0.45

Tested by 1
