"""Guard utilities: toxicity checking and PII redaction for generated answers."""

import re
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import List

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, pipeline

from guards.svnr import is_valid_svnr


@dataclass
class Answer:
    """A generated answer together with its sources and timing."""

    answer: str              # the generated answer text
    sources: List[str]       # provenance of the answer
    processing_time: float   # seconds taken to produce the answer


MODEL_NAME = 'google/gemma-2-2b-it'
ROLE_ASSISTANT = "assistant"
# Deliberately permissive email matcher used by sanitize().
EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")


class AUTO_ANSWERS(Enum):
    """Canned user-facing responses emitted when a guard check fails."""

    COULD_NOT_GENERATE = "Could not generate an answer."
    REQUEST_TIMED_OUT = "The request timed out."
    UNEXPECTED_ERROR = "An unexpected error occurred."
    ANSWER_NOT_RELEVANT_TO_QUERY = "Answer is not relevant to the query."
    ANSWER_NOT_RELEVANT_TO_CONTEXT = "Answer not relevant to given context."
    RELEVANCE_CHECK_FAILED = "Relevance check failed."
    HALLUCINATION_CHECK_FAILED = "Hallucination check failed."
    NOT_SUPPORTED_BY_CONTEXT = "Many facts not supported by context."
    INCONSISTENT_WITH_CONTEXT = "Inconsistent with context."
    CORRECTNESS_CHECK_FAILED = "Correctness check failed."
    LANGUAGE_INAPPROPRIATE = "Inappropriate language detected!"
    REPHRASE_SENTENCE = "Try to rephrase your request."
    INVALID_INPUT = "Invalid input detected."


@lru_cache(maxsize=1)
def get_similarity_model():
    """Lazily load (once per process) the sentence-similarity model on CPU."""
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cpu")


@lru_cache(maxsize=1)
def get_toxicity_model():
    """Lazily load (once per process) the toxicity classification pipeline."""
    return pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        top_k=None
    )


@lru_cache(maxsize=1)
def get_hallucination_model():
    """Lazily load (once per process) the hallucination-evaluation pipeline."""
    return pipeline(
        "text-classification",
        model='vectara/hallucination_evaluation_model',
        tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-base'),
        trust_remote_code=True
    )


def check_toxicity(text: str):
    """Classify *text* for toxicity with the toxicity pipeline.

    Args:
        text: Candidate response generated by the model.

    Returns:
        Tuple of (passed, score, auto_answer):
        passed (bool): True when the toxicity score is below 0.6.
        score (float): Toxicity score in [0, 1]; 1.0 on any error.
        auto_answer (str): Canned response when the check fails, else "".
    """
    try:
        # Token sequences can be longer than naive character-length estimates,
        # so crop well below the model's 512-unit input limit.
        if len(text) > 512:
            text = text[:500]
        result = get_toxicity_model()(text)[0]
        if result:
            # Take the first label scoring above 0.85; otherwise fall back to
            # the mean score across all labels.
            toxicity_score = next(
                (s['score'] for s in result if s['score'] > 0.85),
                sum(s['score'] for s in result) / len(result)
            )
            toxicity_passed = toxicity_score < 0.6
            text = AUTO_ANSWERS.LANGUAGE_INAPPROPRIATE.value if not toxicity_passed else ""
            return toxicity_passed, toxicity_score, text
        else:
            raise ValueError("Could not yield result.")
    except Exception as e:
        print(f"Error while checking language: {e}")
        # Fail closed: any error is reported as maximally toxic.
        return False, 1.0, AUTO_ANSWERS.UNEXPECTED_ERROR.value


def sanitize(text: str) -> str:
    """
    Sanitizes a string by redacting emails and valid Austrian social security numbers (SVNRs).
    """
    # Redact SVNRs: 10-digit runs that pass the checksum validation.
    potential_svnrs = re.findall(r'\b\d{10}\b', text)
    for svnr in potential_svnrs:
        if is_valid_svnr(svnr):
            text = text.replace(svnr, "[REDACTED SVNR]")
    # Redact emails
    return EMAIL_PATTERN.sub('[REDACTED_EMAIL]', text)