
NaN values in the embeddings

#2 opened by samu

Occasionally, I get NaN values in both the image and text embeddings. For text, using longer sentences avoids it, but for images I'm not sure what the cause is.

Vidore org

Hey @samu, could you provide a minimal reproducible example that highlights this behavior, so that we can investigate?

It's not limited to the task I'm working on; I also get NaN values in both image and text queries on the sample example. As I said, it's a very occasional issue: some days it works great, other days I get the NaN values.
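For the text side, this is roughly how the query embeddings are produced (a sketch following the model-card usage, with the model and processor loaded as in the code below; the query strings here are just placeholders):

queries = ["example query one", "example query two"]  # placeholder queries
batch_queries = processor(text=queries).to(model.device)
with torch.no_grad():
    query_embeddings = model(**batch_queries).embeddings
# these occasionally come back all-NaN as well
print(torch.isnan(query_embeddings).any())

The full image pipeline: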

import os
import gc
import glob
import sqlite3
from pathlib import Path
from tqdm import tqdm
import pickle
import io
import random
import math
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import PIL
from PIL import Image
import base64
import matplotlib.pyplot as plt
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available

os.environ['TOKENIZERS_PARALLELISM'] = "false"

def seeding(SEED):
    np.random.seed(SEED)
    random.seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print('seeding done!!!')

def clear_cache():
    """Clear GPU memory cache for different platforms."""
    gc.collect()
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()
    except Exception as e:
        print(f"Warning: Could not clear cache: {str(e)}")

def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

device_map = get_device()
print(f"Device Type is : {device_map}")

class CONFIG:
    model_name = "colqwen2-v1.0-hf"  # local snapshot; hub id is "vidore/colqwen2-v1.0-hf"
    device_map = get_device()
    attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None
    local_files_only = True
    trust_remote_code = True
    batch_size = 4
    seed = 0

seeding(CONFIG.seed)

# Load the model and the processor
model = ColQwen2ForRetrieval.from_pretrained(
    CONFIG.model_name,
    torch_dtype=torch.bfloat16,
    device_map=CONFIG.device_map,
    attn_implementation=CONFIG.attn_implementation,
    local_files_only=CONFIG.local_files_only, 
    trust_remote_code=CONFIG.trust_remote_code,
).eval()
print(f"Model dtype after loading: {model.dtype}")

processor = ColQwen2Processor.from_pretrained(CONFIG.model_name)

def pillow_to_base64(image: Image.Image, quality=75):
    image_data = io.BytesIO()
    # NOTE: 'quality' is ignored by the PNG encoder; it only applies to formats like JPEG/WebP
    image.save(image_data, format='PNG', optimize=True)
    image_data.seek(0)
    base64_encoded = base64.b64encode(image_data.getvalue()).decode('utf-8')
    return base64_encoded

def base64_to_pillow(base64_string):
    if "base64," in base64_string:
        base64_string = base64_string.split("base64,")[1]
    try:
        image_bytes = base64.b64decode(base64_string)
        image_stream = io.BytesIO(image_bytes)
        image = PIL.Image.open(image_stream)
        return image
    except Exception as e:
        print(f"Error converting Base64 to Pillow Image: {e}")
        return None

class ChunkDataset(Dataset):
    def __init__(self, img_paths):
        self.paths = img_paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        base64_str = pillow_to_base64(img)
        return img, base64_str

def collate_fn(batch):
    images, base64s = zip(*batch)
    processed = processor(images=list(images)).to(model.device)
    processed['base64'] = list(base64s)
    return processed

img_paths = list(Path("chunks_output/").glob("*.png"))
print(f"{len(img_paths)} images found")

ds = ChunkDataset(img_paths=img_paths)
dls = DataLoader(ds, batch_size=CONFIG.batch_size, shuffle=False, collate_fn=collate_fn)

b = next(iter(dls))
bs64 = b.pop('base64')
clear_cache()
# model.eval()
b = {k:v.to(model.device) if isinstance(v, torch.Tensor) else v for k,v in b.items()}
with torch.no_grad():
    outputs = model(**b).embeddings

The outputs:

tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[ 0.0199, -0.0850, -0.0713,  ...,  0.0535, -0.0033,  0.1099],
         [ 0.0194, -0.1084, -0.0272,  ...,  0.0762, -0.0449, -0.1396],
         [-0.0012, -0.0996, -0.0198,  ...,  0.0762, -0.0334, -0.1299],
         ...,
         [ 0.0703, -0.1162, -0.1001,  ...,  0.0659,  0.0347,  0.0601],
         [ 0.0464, -0.1211, -0.0294,  ...,  0.0376, -0.0413, -0.1445],
         [ 0.2734, -0.1348,  0.0835,  ..., -0.0659, -0.1133, -0.0277]],

        [[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]]],
       device='mps:0', dtype=torch.bfloat16)
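
A quick way to see which items in the batch are affected, and to test whether the bfloat16-on-MPS combination is the cause (a sketch reusing the variables above; the float32 retry is a diagnostic, not a confirmed fix):

# Per-image NaN flags: True for batch items whose embeddings contain any NaN
nan_mask = torch.isnan(outputs).flatten(1).any(dim=1)
print(nan_mask)  # for the output above: tensor([True, True, False, True], ...)

# Retry the same batch in float32; bfloat16 support on MPS is limited,
# so a clean result here points at the dtype/backend combination
model = model.float()
with torch.no_grad():
    outputs_fp32 = model(**b).embeddings
print(torch.isnan(outputs_fp32).any())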
