NaN values in the embeddings
#2 opened by samu
Occasionally, I get NaN values in both the image and text embeddings. For text, using longer sentences solves it, but for images I'm not sure about the cause.
It's not limited to the task I'm working on: I also get NaN values in both image and text queries with the sample example. As I said, it's a very occasional issue; some days it works great, other days I get NaNs.
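For the text side, I embed queries roughly like this (a quick sketch using the model and processor loaded in the script below; the query strings are just placeholders), and the same occasional NaNs show up there:

# Hypothetical text-query repro; the query strings are placeholders and
# `model`/`processor` are the objects loaded in the script below.
queries = ["What does the chart show?", "total revenue 2023"]
batch_queries = processor(text=queries).to(model.device)
with torch.no_grad():
    query_embeddings = model(**batch_queries).embeddings
print(torch.isnan(query_embeddings).any())

The full image pipeline is below.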
import os
import gc
import glob
import sqlite3
from pathlib import Path
from tqdm import tqdm
import pickle
import io
import random
import math
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import PIL
from PIL import Image
import base64
import matplotlib.pyplot as plt
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available
os.environ['TOKENIZERS_PARALLELISM'] = "false"
def seeding(SEED):
    np.random.seed(SEED)
    random.seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print('seeding done!!!')
def clear_cache():
    """Clear GPU memory cache for different platforms."""
    gc.collect()
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()
    except Exception as e:
        print(f"Warning: Could not clear cache: {str(e)}")
def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"
device_map = get_device()
print(f"Device Type is : {device_map}")
class CONFIG:
    model_name = "colqwen2-v1.0-hf"
    device_map = get_device()
    attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None
    local_files_only = True
    trust_remote_code = True
    batch_size = 4
    seed = 0
seeding(CONFIG.seed)
# Load the model and the processor
model = ColQwen2ForRetrieval.from_pretrained(
    CONFIG.model_name,
    torch_dtype=torch.bfloat16,
    device_map=CONFIG.device_map,
    attn_implementation=CONFIG.attn_implementation,
    local_files_only=CONFIG.local_files_only,
    trust_remote_code=CONFIG.trust_remote_code,
).eval()
print(f"Model dtype after loading: {model.dtype}")
processor = ColQwen2Processor.from_pretrained(CONFIG.model_name)
def pillow_to_base64(image: PIL.Image.Image, quality=75):
    image_data = io.BytesIO()
    image.save(image_data, format='PNG', optimize=True, quality=quality)
    image_data.seek(0)
    base64_encoded = base64.b64encode(image_data.getvalue()).decode('utf-8')
    return base64_encoded
def base64_to_pillow(base64_string):
    if "base64," in base64_string:
        base64_string = base64_string.split("base64,")[1]
    try:
        image_bytes = base64.b64decode(base64_string)
        image_stream = io.BytesIO(image_bytes)
        image = PIL.Image.open(image_stream)
        return image
    except Exception as e:
        print(f"Error converting Base64 to Pillow Image: {e}")
        return None
class ChunkDataset(Dataset):
    def __init__(self, img_paths):
        self.paths = img_paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        base64_str = pillow_to_base64(img)
        return img, base64_str
def collate_fn(batch):
    images, base64s = zip(*batch)
    processed = processor(images=list(images)).to(model.device)
    processed['base64'] = list(base64s)
    return processed
img_paths = list(Path("chunks_output/").glob("*.png"))
len(img_paths)
ds = ChunkDataset(img_paths=img_paths)
dls = DataLoader(ds, batch_size=CONFIG.batch_size, shuffle=False, collate_fn=collate_fn)
b = next(iter(dls))
bs64 = b.pop('base64')
clear_cache()
# model.eval()
b = {k:v.to(model.device) if isinstance(v, torch.Tensor) else v for k,v in b.items()}
with torch.no_grad():
    outputs = model(**b).embeddings
The outputs:
tensor([[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
...,
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
...,
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ 0.0199, -0.0850, -0.0713, ..., 0.0535, -0.0033, 0.1099],
[ 0.0194, -0.1084, -0.0272, ..., 0.0762, -0.0449, -0.1396],
[-0.0012, -0.0996, -0.0198, ..., 0.0762, -0.0334, -0.1299],
...,
[ 0.0703, -0.1162, -0.1001, ..., 0.0659, 0.0347, 0.0601],
[ 0.0464, -0.1211, -0.0294, ..., 0.0376, -0.0413, -0.1445],
[ 0.2734, -0.1348, 0.0835, ..., -0.0659, -0.1133, -0.0277]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
...,
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]]],
device='mps:0', dtype=torch.bfloat16)
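In case it helps to narrow things down, this is the diagnostic I've been trying (a sketch, not a confirmed fix): flag the batch items that came back as NaN, then re-encode the same batch in float32 on CPU, on the assumption that bfloat16 on MPS might be the culprit.

# Flag which batch items contain NaNs, then re-run the same batch in float32 on
# CPU. The float32/CPU reload is an assumption for debugging, not a known fix.
bad = torch.isnan(outputs).flatten(start_dim=1).any(dim=1)
print(f"NaN items in this batch: {bad.nonzero(as_tuple=True)[0].tolist()}")

model_fp32 = ColQwen2ForRetrieval.from_pretrained(
    CONFIG.model_name,
    torch_dtype=torch.float32,
    device_map="cpu",
    local_files_only=CONFIG.local_files_only,
    trust_remote_code=CONFIG.trust_remote_code,
).eval()

b_cpu = {k: v.to("cpu") if isinstance(v, torch.Tensor) else v for k, v in b.items()}
with torch.no_grad():
    outputs_fp32 = model_fp32(**b_cpu).embeddings
print(f"Any NaNs in the float32/CPU run: {torch.isnan(outputs_fp32).any().item()}")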