Update pipeline.py
Browse files — pipeline.py (+6, −4)
pipeline.py
CHANGED
|
@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
|
|
| 340 |
# assign weights to the prompts and normalize in the sense of mean
|
| 341 |
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
| 342 |
if (not skip_parsing) and (not skip_weighting):
|
| 343 |
-
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
| 344 |
text_embeddings *= prompt_weights.unsqueeze(-1)
|
| 345 |
-
|
|
|
|
| 346 |
if uncond_prompt is not None:
|
| 347 |
-
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
| 348 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
| 349 |
-
|
|
|
|
| 350 |
|
| 351 |
if uncond_prompt is not None:
|
| 352 |
return text_embeddings, uncond_embeddings
|
|
|
|
| 340 |
# assign weights to the prompts and normalize in the sense of mean
|
| 341 |
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
| 342 |
if (not skip_parsing) and (not skip_weighting):
|
| 343 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 344 |
text_embeddings *= prompt_weights.unsqueeze(-1)
|
| 345 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 346 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 347 |
if uncond_prompt is not None:
|
| 348 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
| 349 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
| 350 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
| 351 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 352 |
|
| 353 |
if uncond_prompt is not None:
|
| 354 |
return text_embeddings, uncond_embeddings
|