import ast
import io
import math
import statistics
import string

import cairosvg
import clip
import cv2
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from more_itertools import chunked
from PIL import Image, ImageFilter
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
)

svg_constraints = kagglehub.package_import('metric/svg-constraints')

class ParticipantVisibleError(Exception):
    pass

def score(
    solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str, random_seed: int = 0
) -> float:
    """Calculates a fidelity score by comparing generated SVG images to target text descriptions.

    Parameters
    ----------
    solution : pd.DataFrame
        A DataFrame containing target questions, choices, and answers about an SVG image.
    submission : pd.DataFrame
        A DataFrame containing generated SVG strings. Must have a column named 'svg'.
    row_id_column_name : str
        The name of the column containing row identifiers. This column is removed before scoring.
    random_seed : int
        A seed to set the random state.

    Returns
    -------
    float
        The mean fidelity score (a value between 0 and 1) representing the average
        similarity between the generated SVGs and their descriptions. A higher score
        indicates better fidelity.

    Raises
    ------
    ParticipantVisibleError
        If the 'svg' column in the submission DataFrame is not of string type or if
        validation of the SVG fails.

    Examples
    --------
    >>> import pandas as pd
    >>> solution = pd.DataFrame({
    ...     'id': ["abcde"],
    ...     'question': ['["Is there a red circle?", "What shape is present?"]'],
    ...     'choices': ['[["yes", "no"], ["square", "circle", "triangle", "hexagon"]]'],
    ...     'answer': ['["yes", "circle"]'],
    ... })
    >>> submission = pd.DataFrame({
    ...     'id': ["abcde"],
    ...     'svg': ['<svg viewBox="0 0 100 100"><circle cx="50" cy="50" r="40" fill="red"/></svg>'],
    ... })
    >>> score(solution, submission, 'row_id', random_seed=42)
    0...
    """
    # Convert solution fields to list dtypes and expand
    for colname in ['question', 'choices', 'answer']:
        solution[colname] = solution[colname].apply(ast.literal_eval)
    solution = solution.explode(['question', 'choices', 'answer'])

    # Validate
    if not pd.api.types.is_string_dtype(submission.loc[:, 'svg']):
        raise ParticipantVisibleError('svg must be a string.')

    # Check that SVG code meets defined constraints
    constraints = svg_constraints.SVGConstraints()
    try:
        for svg in submission.loc[:, 'svg']:
            constraints.validate_svg(svg)
    except Exception as exc:
        raise ParticipantVisibleError('SVG code violates constraints.') from exc

    # Score
    vqa_evaluator = VQAEvaluator()
    aesthetic_evaluator = AestheticEvaluator()

    results = []
    rng = np.random.RandomState(random_seed)
    try:
        df = solution.merge(submission, on='id')
        for _, group in df.loc[
            :, ['id', 'question', 'choices', 'answer', 'svg']
        ].groupby('id'):
            questions, choices, answers, svg = [
                group[col_name].to_list()
                for col_name in group.drop('id', axis=1).columns
            ]
            svg = svg[0]  # unpack singleton from list
            group_seed = rng.randint(0, np.iinfo(np.int32).max)
            image_processor = ImageProcessor(image=svg_to_png(svg), seed=group_seed).apply()
            image = image_processor.image.copy()
            aesthetic_score = aesthetic_evaluator.score(image)
            vqa_score = vqa_evaluator.score(questions, choices, answers, image)
            # OCR runs on a freshly re-perturbed copy of the original image
            image_processor.reset().apply_random_crop_resize().apply_jpeg_compression(quality=90)
            ocr_score = vqa_evaluator.ocr(image_processor.image)
            instance_score = (
                harmonic_mean(vqa_score, aesthetic_score, beta=0.5) * ocr_score
            )
            results.append(instance_score)
    except Exception as exc:
        raise ParticipantVisibleError('SVG failed to score.') from exc

    fidelity = statistics.mean(results)
    return float(fidelity)
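
# A worked example of the per-instance formula above, with illustrative
# (not real) component scores vqa_score=0.9, aesthetic_score=0.5, ocr_score=1.0:
#
#   harmonic_mean(0.9, 0.5, beta=0.5) * 1.0
#   = (1 + 0.25) * (0.9 * 0.5) / (0.25 * 0.9 + 0.5)
#   = 0.5625 / 0.725
#   ≈ 0.776
#
# With beta=0.5 the weighted harmonic mean leans toward its first argument
# (VQA fidelity), while the OCR term multiplicatively penalizes rendered text.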

class VQAEvaluator:
    """Evaluates images based on their similarity to a given text description using multiple choice questions."""

    def __init__(self):
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.letters = string.ascii_uppercase
        self.model_path = kagglehub.model_download(
            'google/paligemma-2/transformers/paligemma2-10b-mix-448'
        )
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            quantization_config=self.quantization_config,
        ).to('cuda')

    def score(self, questions, choices, answers, image, n=4):
        scores = []
        batches = (chunked(qs, n) for qs in [questions, choices, answers])
        for question_batch, choice_batch, answer_batch in zip(*batches, strict=True):
            scores.extend(
                self.score_batch(
                    image,
                    question_batch,
                    choice_batch,
                    answer_batch,
                )
            )
        return statistics.mean(scores)
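
    # Note on the batching above (illustration, not part of the original code):
    # more_itertools.chunked splits a sequence into lists of at most n items,
    # e.g. list(chunked([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]].
    # Questions, choices, and answers are chunked in lockstep, so each call to
    # score_batch receives at most n aligned (question, choices, answer) triples.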

    def score_batch(
        self,
        image: Image.Image,
        questions: list[str],
        choices_list: list[list[str]],
        answers: list[str],
    ) -> list[float]:
        """Evaluates the image based on multiple choice questions and answers.

        Parameters
        ----------
        image : PIL.Image.Image
            The image to evaluate.
        questions : list[str]
            List of questions about the image.
        choices_list : list[list[str]]
            List of lists of possible answer choices, corresponding to each question.
        answers : list[str]
            List of correct answers from the choices, corresponding to each question.

        Returns
        -------
        list[float]
            List of scores (values between 0 and 1) representing the probability
            of the correct answer for each question.
        """
        prompts = [
            self.format_prompt(question, choices)
            for question, choices in zip(questions, choices_list, strict=True)
        ]
        batched_choice_probabilities = self.get_choice_probability(
            image, prompts, choices_list
        )

        scores = []
        for choice_probabilities, answer in zip(
            batched_choice_probabilities, answers, strict=True
        ):
            # Probability mass assigned to the correct choice (0.0 if absent)
            scores.append(choice_probabilities.get(answer, 0.0))
        return scores

    def format_prompt(self, question: str, choices: list[str]) -> str:
        prompt = f'<image>answer en Question: {question}\nChoices:\n'
        for i, choice in enumerate(choices):
            prompt += f'{self.letters[i]}. {choice}\n'
        return prompt
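
    # For instance (derived directly from the code above, shown for clarity),
    # format_prompt('Is there a red circle?', ['yes', 'no']) returns:
    #
    #   <image>answer en Question: Is there a red circle?
    #   Choices:
    #   A. yes
    #   B. no
    #
    # i.e. PaliGemma's 'answer en' task prefix followed by lettered choices.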

    def mask_choices(self, logits, choices_list):
        """Masks logits for the first token of each choice letter for each question in the batch."""
        batch_size = logits.shape[0]
        masked_logits = torch.full_like(logits, float('-inf'))

        for batch_idx in range(batch_size):
            choices = choices_list[batch_idx]
            for i in range(len(choices)):
                letter_token = self.letters[i]

                first_token = self.processor.tokenizer.encode(
                    letter_token, add_special_tokens=False
                )[0]
                first_token_with_space = self.processor.tokenizer.encode(
                    ' ' + letter_token, add_special_tokens=False
                )[0]

                if isinstance(first_token, int):
                    masked_logits[batch_idx, first_token] = logits[
                        batch_idx, first_token
                    ]
                if isinstance(first_token_with_space, int):
                    masked_logits[batch_idx, first_token_with_space] = logits[
                        batch_idx, first_token_with_space
                    ]

        return masked_logits

    def get_choice_probability(self, image, prompts, choices_list) -> list[dict]:
        inputs = self.processor(
            images=[image] * len(prompts),
            text=prompts,
            return_tensors='pt',
            padding='longest',
        ).to('cuda')

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits[:, -1, :]  # Logits for the last (predicted) token
            masked_logits = self.mask_choices(logits, choices_list)
            probabilities = torch.softmax(masked_logits, dim=-1)

        batched_choice_probabilities = []
        for batch_idx in range(len(prompts)):
            choice_probabilities = {}
            choices = choices_list[batch_idx]

            for i, choice in enumerate(choices):
                letter_token = self.letters[i]
                first_token = self.processor.tokenizer.encode(
                    letter_token, add_special_tokens=False
                )[0]
                first_token_with_space = self.processor.tokenizer.encode(
                    ' ' + letter_token, add_special_tokens=False
                )[0]

                prob = 0.0
                if isinstance(first_token, int):
                    prob += probabilities[batch_idx, first_token].item()
                if isinstance(first_token_with_space, int):
                    prob += probabilities[batch_idx, first_token_with_space].item()
                choice_probabilities[choice] = prob

            # Renormalize probabilities for each question
            total_prob = sum(choice_probabilities.values())
            if total_prob > 0:
                renormalized_probabilities = {
                    choice: prob / total_prob
                    for choice, prob in choice_probabilities.items()
                }
            else:
                renormalized_probabilities = (
                    choice_probabilities  # Avoid division by zero if total_prob is 0
                )
            batched_choice_probabilities.append(renormalized_probabilities)

        return batched_choice_probabilities
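
    # Hypothetical numbers for intuition (not from any real run): if the masked
    # softmax gave P('A') = 0.60, P(' A') = 0.15, P('B') = 0.20, P(' B') = 0.05
    # for a two-choice question, the spaced and unspaced variants are summed per
    # letter (A: 0.75, B: 0.25) and renormalization leaves {choice_A: 0.75,
    # choice_B: 0.25}, the probability mass the model puts on each choice.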

    def ocr(self, image, free_chars=4):
        inputs = (
            self.processor(
                text='<image>ocr\n',
                images=image,
                return_tensors='pt',
            )
            .to(torch.float16)
            .to(self.model.device)
        )
        input_len = inputs['input_ids'].shape[-1]

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, max_new_tokens=32, do_sample=False)
            outputs = outputs[0][input_len:]
            decoded = self.processor.decode(outputs, skip_special_tokens=True)

        num_char = len(decoded)

        # Exponentially decreasing towards 0.0 if more than free_chars detected
        return min(1.0, math.exp(-num_char + free_chars))
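
    # The OCR penalty decays exponentially once the decoded text exceeds
    # free_chars characters. With the default free_chars=4:
    #   0-4 chars -> 1.0
    #   5 chars   -> exp(-1) ≈ 0.368
    #   6 chars   -> exp(-2) ≈ 0.135
    #   10 chars  -> exp(-6) ≈ 0.0025
    # so any legible text in the image quickly drives the score toward zero.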

class AestheticPredictor(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.layers(x)

class AestheticEvaluator:
    def __init__(self):
        self.model_path = 'improved-aesthetic-predictor/sac+logos+ava1-l14-linearMSE.pth'
        self.clip_model_path = 'ViT-L/14'
        self.predictor, self.clip_model, self.preprocessor = self.load()

    def load(self):
        """Loads the aesthetic predictor model and CLIP model."""
        state_dict = torch.load(self.model_path, weights_only=True, map_location='cuda')

        # CLIP embedding dim is 768 for CLIP ViT-L/14
        predictor = AestheticPredictor(768)
        predictor.load_state_dict(state_dict)
        predictor.to('cuda')
        predictor.eval()
        clip_model, preprocessor = clip.load(self.clip_model_path, device='cuda')

        return predictor, clip_model, preprocessor

    def score(self, image: Image.Image) -> float:
        """Predicts the CLIP aesthetic score of an image."""
        image = self.preprocessor(image).unsqueeze(0).to('cuda')

        with torch.no_grad():
            image_features = self.clip_model.encode_image(image)
            # L2-normalize the CLIP embedding before feeding the predictor
            image_features /= image_features.norm(dim=-1, keepdim=True)

        image_features = image_features.cpu().detach().numpy()
        score = self.predictor(torch.from_numpy(image_features).to('cuda').float())

        return score.item() / 10.0  # scale the predictor's ~[0, 10] output to [0, 1]

def harmonic_mean(a: float, b: float, beta: float = 1.0) -> float:
    """
    Calculate the harmonic mean of two values, weighted using a beta parameter.

    Args:
        a: First value (e.g., precision)
        b: Second value (e.g., recall)
        beta: Weighting parameter

    Returns:
        Weighted harmonic mean
    """
    # Handle non-positive values to prevent division by zero
    if a <= 0 or b <= 0:
        return 0.0
    return (1 + beta**2) * (a * b) / (beta**2 * a + b)
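
# Quick sanity checks (added for illustration; cheap to evaluate at import
# time). As beta -> 0 the weighted mean approaches a; as beta -> inf it
# approaches b, mirroring the precision/recall weighting of an F-beta score.
assert math.isclose(harmonic_mean(0.9, 0.5, beta=1.0), 2 * 0.9 * 0.5 / (0.9 + 0.5))
assert harmonic_mean(0.9, 0.5, beta=0.5) > harmonic_mean(0.9, 0.5, beta=2.0)
assert harmonic_mean(0.0, 0.5) == 0.0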

def svg_to_png(svg_code: str, size: tuple = (384, 384)) -> Image.Image:
    """
    Converts an SVG string to a PNG image using CairoSVG.

    If the SVG does not define a `viewBox`, it will add one using the provided size.

    Parameters
    ----------
    svg_code : str
        The SVG string to convert.
    size : tuple[int, int], default=(384, 384)
        The desired size of the output PNG image (width, height).

    Returns
    -------
    PIL.Image.Image
        The generated PNG image.
    """
    # Ensure SVG has proper size attributes
    if 'viewBox' not in svg_code:
        svg_code = svg_code.replace('<svg', f'<svg viewBox="0 0 {size[0]} {size[1]}"')

    # Convert SVG to PNG
    png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'))
    return Image.open(io.BytesIO(png_data)).convert('RGB').resize(size)
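
# Minimal usage sketch (CPU-only; assumes cairosvg is installed):
#
#   png = svg_to_png('<svg><circle cx="50" cy="50" r="40" fill="red"/></svg>')
#   png.size  # (384, 384); a viewBox was injected because none was present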

class ImageProcessor:
    def __init__(self, image: Image.Image, seed=None):
        """Initialize with a PIL Image object."""
        self.image = image
        self.original_image = self.image.copy()
        if seed is not None:
            self.rng = np.random.RandomState(seed)
        else:
            self.rng = np.random

    def reset(self):
        self.image = self.original_image.copy()
        return self

    def visualize_comparison(
        self,
        original_name='Original',
        processed_name='Processed',
        figsize=(10, 5),
        show=True,
    ):
        """Display original and processed images side by side."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        ax1.imshow(np.asarray(self.original_image))
        ax1.set_title(original_name)
        ax1.axis('off')
        ax2.imshow(np.asarray(self.image))
        ax2.set_title(processed_name)
        ax2.axis('off')

        title = f'{original_name} vs {processed_name}'
        fig.suptitle(title)
        fig.tight_layout()
        if show:
            plt.show()
        return fig

    def apply_median_filter(self, size=3):
        """Apply median filter to remove outlier pixel values.

        Args:
            size: Size of the median filter window.
        """
        self.image = self.image.filter(ImageFilter.MedianFilter(size=size))
        return self

    def apply_bilateral_filter(self, d=9, sigma_color=75, sigma_space=75):
        """Apply bilateral filter to smooth while preserving edges.

        Args:
            d: Diameter of each pixel neighborhood
            sigma_color: Filter sigma in the color space
            sigma_space: Filter sigma in the coordinate space
        """
        # Convert PIL Image to numpy array for OpenCV
        img_array = np.asarray(self.image)

        # Apply bilateral filter
        filtered = cv2.bilateralFilter(img_array, d, sigma_color, sigma_space)

        # Convert back to PIL Image
        self.image = Image.fromarray(filtered)
        return self

    def apply_fft_low_pass(self, cutoff_frequency=0.5):
        """Apply low-pass filter in the frequency domain using FFT.

        Args:
            cutoff_frequency: Normalized cutoff frequency (0-1).
                Lower values remove more high frequencies.
        """
        # Convert to numpy array, ensuring float32 for FFT
        img_array = np.array(self.image, dtype=np.float32)

        # Process each color channel separately
        result = np.zeros_like(img_array)
        for i in range(3):  # For RGB channels
            # Apply FFT
            f = np.fft.fft2(img_array[:, :, i])
            fshift = np.fft.fftshift(f)

            # Create a circular low-pass filter mask
            rows, cols = img_array[:, :, i].shape
            crow, ccol = rows // 2, cols // 2
            mask = np.zeros((rows, cols), np.float32)
            r = int(min(crow, ccol) * cutoff_frequency)
            center = [crow, ccol]
            x, y = np.ogrid[:rows, :cols]
            mask_area = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= r * r
            mask[mask_area] = 1

            # Apply mask and inverse FFT
            fshift_filtered = fshift * mask
            f_ishift = np.fft.ifftshift(fshift_filtered)
            img_back = np.fft.ifft2(f_ishift)
            img_back = np.real(img_back)

            result[:, :, i] = img_back

        # Clip to 0-255 range and convert to uint8 after processing all channels
        result = np.clip(result, 0, 255).astype(np.uint8)

        # Convert back to PIL Image
        self.image = Image.fromarray(result)
        return self

    def apply_jpeg_compression(self, quality=85):
        """Apply JPEG compression.

        Args:
            quality: JPEG quality (0-95). Lower values increase compression.
        """
        buffer = io.BytesIO()
        self.image.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        self.image = Image.open(buffer)
        return self

    def apply_random_crop_resize(self, crop_percent=0.05):
        """Randomly crop and resize back to original dimensions.

        Args:
            crop_percent: Percentage of image to crop (0-0.4).
        """
        width, height = self.image.size
        crop_pixels_w = int(width * crop_percent)
        crop_pixels_h = int(height * crop_percent)
        left = self.rng.randint(0, crop_pixels_w + 1)
        top = self.rng.randint(0, crop_pixels_h + 1)
        right = width - self.rng.randint(0, crop_pixels_w + 1)
        bottom = height - self.rng.randint(0, crop_pixels_h + 1)

        self.image = self.image.crop((left, top, right, bottom))
        self.image = self.image.resize((width, height), Image.BILINEAR)
        return self

    def apply(self):
        """Apply an ensemble of defenses."""
        return (
            self.apply_random_crop_resize(crop_percent=0.03)
            .apply_jpeg_compression(quality=95)
            .apply_median_filter(size=9)
            .apply_fft_low_pass(cutoff_frequency=0.5)
            .apply_bilateral_filter(d=5, sigma_color=75, sigma_space=75)
            .apply_jpeg_compression(quality=92)
        )
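
# Minimal usage sketch for the defense pipeline (hypothetical inputs; any RGB
# PIL image works in place of the rendered SVG):
#
#   img = svg_to_png('<svg viewBox="0 0 100 100"><rect width="100" height="100" fill="teal"/></svg>')
#   processed = ImageProcessor(image=img, seed=0).apply().image
#
# The chained defenses in apply() (crop/resize, JPEG, median filter, FFT
# low-pass, bilateral filter, JPEG again) appear intended to strip
# high-frequency detail an adversarial SVG might use to game the evaluators;
# reset() restores the untouched original for the separate OCR pass in score().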