"""Contrastive (InfoNCE-style) loss for two-speaker embeddings.

Embeddings are laid out as [B, 2T, D]: the first T positions belong to
speaker 1 and the last T positions to speaker 2. Same-speaker pairs act as
positives and cross-speaker pairs as negatives.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch import Tensor


class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.25, distance_metric='cosine'):
        super().__init__()
        self.temperature = temperature
        self.distance_metric = distance_metric

    def compute_similarity(self, embeddings):
        """Temperature-scaled similarity of an embedding set with itself."""
        if self.distance_metric == 'cosine':
            embeddings = F.normalize(embeddings, p=2, dim=-1)  # [B, 2T, D]
            sim = torch.matmul(embeddings, embeddings.transpose(-1, -2))  # [B, 2T, 2T]
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def compute_cross_similarity(self, embeddings1, embeddings2):
        """Temperature-scaled similarity between two different embedding sets."""
        if self.distance_metric == 'cosine':
            embeddings1 = F.normalize(embeddings1, p=2, dim=-1)  # [B, 2T, D]
            embeddings2 = F.normalize(embeddings2, p=2, dim=-1)  # [B, 2T, D]
            sim = torch.matmul(embeddings1, embeddings2.transpose(-1, -2))  # [B, 2T, 2T]
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def pairwise_and_no_diag(self, m):
        """Pairwise AND of a boolean mask with the diagonal zeroed out."""
        m_i = m.unsqueeze(2)  # [B, T, 1]
        m_j = m.unsqueeze(1)  # [B, 1, T]
        out = m_i & m_j       # [B, T, T]
        diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
        return out & ~diag

    def forward(self, embeddings, anchors,
                enrollment_embeddings: Optional[Tensor] = None,
                enrollment_embeddings_mask: Optional[Tensor] = None):
        """
        Args:
            embeddings: [B, 2T, D] main embeddings
            anchors: [B, 2T] boolean mask indicating anchor positions
            enrollment_embeddings: optional [B, 2T, D] enrollment embeddings for positive pairs
            enrollment_embeddings_mask: optional [B, 2T] boolean mask for valid enrollment positions

        Returns:
            Scalar contrastive loss.
        """
        # Use enrollment embeddings as positives when both tensor and mask are provided
        if enrollment_embeddings is not None and enrollment_embeddings_mask is not None:
            return self._forward_with_enrollment(
                embeddings, anchors, enrollment_embeddings, enrollment_embeddings_mask)
        # Fall back to the original behavior
        return self._forward_original(embeddings, anchors)

    def _forward_with_enrollment(self, embeddings, anchors,
                                 enrollment_embeddings, enrollment_embeddings_mask):
        """Forward pass using enrollment embeddings as positives."""
        B, two_T, D = embeddings.shape
        T = two_T // 2

        # Similarity between main and enrollment embeddings (source of positives)
        cross_sim = self.compute_cross_similarity(embeddings, enrollment_embeddings)  # [B, 2T, 2T]
        # Similarity within the main embeddings (source of negatives)
        self_sim = self.compute_similarity(embeddings)  # [B, 2T, 2T]

        # Split the anchor mask into the two speaker halves
        m1 = anchors[:, :T]  # [B, T]
        m2 = anchors[:, T:]  # [B, T]
        # Split the enrollment mask the same way
        enroll_m1 = enrollment_embeddings_mask[:, :T]  # [B, T]
        enroll_m2 = enrollment_embeddings_mask[:, T:]  # [B, T]

        # Positive mask: anchor positions match the corresponding enrollment positions.
        # First speaker (positions 0:T) matches enrollment first speaker (0:T)
        pos_mask_1to1 = m1.unsqueeze(2) & enroll_m1.unsqueeze(1)  # [B, T, T]
        # Second speaker (positions T:2T) matches enrollment second speaker (T:2T)
        pos_mask_2to2 = m2.unsqueeze(2) & enroll_m2.unsqueeze(1)  # [B, T, T]
        # Assemble the full positive mask
        pos_mask = torch.cat([
            torch.cat([pos_mask_1to1, torch.zeros_like(pos_mask_1to1)], dim=2),  # [B, T, 2T]
            torch.cat([torch.zeros_like(pos_mask_2to2), pos_mask_2to2], dim=2),  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]
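        # Block layout of the [2T, 2T] masks (rows = anchors, cols = candidates):
        #   pos_mask (applied to cross_sim)      neg_mask (applied to self_sim, built next)
        #     [ 1to1 |  0   ]                      [   0     | cross ]
        #     [  0   | 2to2 ]                      [ cross.T |   0   ]
        # Same-speaker blocks sit on the diagonal; cross-speaker blocks sit off it.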
        # Negative mask: cross-speaker pairs within the main embeddings
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)  # [B, T, T]
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),                  # [B, T, 2T]
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2),  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Exclude [i, i] self-pairs from the negative mask
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)  # [1, 2T, 2T]
        neg_mask &= ~identity_mask
        # Also drop the diagonal of the positive mask (main[i] vs. enrollment[i])
        pos_mask &= ~identity_mask

        # Contrastive (InfoNCE) loss
        if pos_mask.any():
            # Positive similarities come from the cross-similarity matrix
            pos_sim = cross_sim[pos_mask]  # [num_pos_pairs]
            pos_exp = torch.exp(pos_sim)   # [num_pos_pairs]
            # Negative exponentials come from the self-similarity matrix
            exp_self_sim = torch.exp(self_sim)  # [B, 2T, 2T]
            neg_exp_sum = torch.sum(exp_self_sim * neg_mask.float(), dim=2)  # [B, 2T]
            # Map each positive pair back to its anchor row
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
            batch_idx = pos_indices[:, 0]  # [num_pos_pairs]
            row_idx = pos_indices[:, 1]    # [num_pos_pairs]
            # Negative sum for each positive pair's anchor
            neg_sums_for_pos = neg_exp_sum[batch_idx, row_idx]  # [num_pos_pairs]
            # Denominator: exp(pos) + sum(exp(neg)) for each positive pair
            denominators = pos_exp + neg_sums_for_pos  # [num_pos_pairs]
            # InfoNCE loss: -log(exp(pos) / denominator)
            loss = -torch.log(pos_exp / denominators)
            total_loss = loss.mean()
        else:
            # No positive pairs found, return zero loss
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        return total_loss

    def _forward_original(self, embeddings, pos_indicator_mask):
        """Original forward pass, kept for backward compatibility."""
        B, two_T, D = embeddings.shape
        T = two_T // 2
        sim = self.compute_similarity(embeddings)  # [B, 2T, 2T]

        # Split the input mask into the two speaker halves
        m1 = pos_indicator_mask[:, :T]  # [B, T]
        m2 = pos_indicator_mask[:, T:]  # [B, T]

        # Positive mask (same-speaker pairs, diagonal excluded)
        pos_block1 = self.pairwise_and_no_diag(m1)  # [B, T, T]
        pos_block2 = self.pairwise_and_no_diag(m2)  # [B, T, T]
        pos_mask = torch.cat([
            torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2),  # [B, T, 2T]
            torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2),  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Negative mask (cross-speaker pairs where both positions are active)
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)  # [B, T, T]
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),                  # [B, T, 2T]
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2),  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Identity mask (exclude [i, i] self-pairs)
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)  # [1, 2T, 2T]
        pos_mask &= ~identity_mask
        neg_mask &= ~identity_mask

        # Fully vectorized InfoNCE computation
        if pos_mask.any():
            # Exponentiate the (temperature-scaled) similarities once and reuse
            exp_sim = torch.exp(sim)  # [B, 2T, 2T]
            # Positive similarities
            pos_sim = sim[pos_mask]       # [num_pos_pairs]
            pos_exp = torch.exp(pos_sim)  # [num_pos_pairs]
            # For each position, sum the exponentials of its negatives
            neg_exp_sum = torch.sum(exp_sim * neg_mask.float(), dim=2)  # [B, 2T]
            # Map each positive pair back to its anchor row
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
            batch_idx = pos_indices[:, 0]  # [num_pos_pairs]
            row_idx = pos_indices[:, 1]    # [num_pos_pairs]
            # Negative sum for each positive pair's anchor
            neg_sums_for_pos = neg_exp_sum[batch_idx, row_idx]  # [num_pos_pairs]
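            # Per positive pair (a, p) the quantity assembled below is
            #   -log( exp(s_ap) / (exp(s_ap) + sum_n exp(s_an)) )
            # where s_xy are temperature-scaled similarities and n ranges over
            # anchor a's cross-speaker negatives: the standard InfoNCE form.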
            # Denominator: exp(pos) + sum(exp(neg)) for each positive pair
            denominators = pos_exp + neg_sums_for_pos  # [num_pos_pairs]
            # InfoNCE loss: -log(exp(pos) / denominator)
            loss = -torch.log(pos_exp / denominators)
            total_loss = loss.mean()
        else:
            # No positive pairs found, return zero loss
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        return total_loss
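

# Minimal smoke test: a hedged sketch of how this module might be driven.
# The shapes, batch size, and random masks below are illustrative assumptions,
# not part of any original training setup.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, D = 2, 6, 16  # assumed batch size, frames per speaker, embedding dim
    two_T = 2 * T

    embeddings = torch.randn(B, two_T, D, requires_grad=True)
    anchors = torch.rand(B, two_T) > 0.3  # boolean anchor mask

    loss_fn = ContrastiveLoss(temperature=0.25)

    # Original path: positives are same-speaker pairs within `embeddings`
    loss = loss_fn(embeddings, anchors)
    loss.backward()
    print(f"original path loss: {loss.item():.4f}")

    # Enrollment path: positives pair `embeddings` with a second embedding set
    enrollment = torch.randn(B, two_T, D)
    enrollment_mask = torch.rand(B, two_T) > 0.3  # valid enrollment positions
    loss = loss_fn(embeddings, anchors,
                   enrollment_embeddings=enrollment,
                   enrollment_embeddings_mask=enrollment_mask)
    print(f"enrollment path loss: {loss.item():.4f}")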