|
|
from enum import Enum |
|
|
import re |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import pipeline, AutoTokenizer |
|
|
from functools import lru_cache |
|
|
from guards.svnr import is_valid_svnr |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face model id; presumably the generation model used elsewhere in the
# project (it is not referenced by the functions in this section).
MODEL_NAME = "google/gemma-2-2b-it"

# Role label for assistant turns — assumed to be a chat-message role; confirm
# against callers.
ROLE_ASSISTANT = "assistant"

# Loose e-mail matcher used by sanitize(): a local part, "@", then a domain that
# contains at least one dot.
EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")
|
|
|
|
|
@lru_cache(maxsize=1)
def get_similarity_model():
    """Return the shared all-MiniLM-L6-v2 sentence-embedding model.

    The model is constructed on CPU the first time this is called; the
    single-entry LRU cache hands back that same instance on later calls.
    """
    embedder_id = "sentence-transformers/all-MiniLM-L6-v2"
    return SentenceTransformer(embedder_id, device="cpu")
|
|
|
|
|
@lru_cache(maxsize=1)
def get_toxicity_model():
    """Return the shared ``unitary/toxic-bert`` text-classification pipeline.

    ``top_k=None`` requests a score for every label instead of only the top
    one. The pipeline is built on first use and cached afterwards.
    """
    task = "text-classification"
    return pipeline(task, model="unitary/toxic-bert", top_k=None)
|
|
|
|
|
@lru_cache(maxsize=1)
def get_hallucination_model():
    """Return the shared Vectara hallucination-evaluation pipeline.

    The classifier is paired with the ``google/flan-t5-base`` tokenizer and
    loaded with ``trust_remote_code=True``. It is built on first use and
    cached afterwards.
    """
    # NOTE(review): trust_remote_code runs Python fetched from the model repo;
    # consider pinning a `revision=` so that code cannot change underneath us.
    flan_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
    return pipeline(
        "text-classification",
        model='vectara/hallucination_evaluation_model',
        tokenizer=flan_tokenizer,
        trust_remote_code=True,
    )
|
|
|
|
|
def _window_toxicity_score(label_scores) -> float:
    """Collapse one window's per-label toxicity scores into a single score.

    If any label scores above 0.85, the highest such score is returned;
    otherwise the mean over all labels is returned.

    Raises:
        ValueError: if the pipeline produced no label scores.
    """
    if not label_scores:
        raise ValueError("Could not yield result.")
    scores = [entry['score'] for entry in label_scores]
    confident = [score for score in scores if score > 0.85]
    # Report the strongest confident label, not merely the first one found.
    return max(confident) if confident else sum(scores) / len(scores)


def check_toxicity(text: str, chunk_size: int = 500) -> tuple[bool, float, str]:
    """Toxicity check with the cached toxic-bert pipeline.

    The text is scored in consecutive windows of ``chunk_size`` characters,
    and the worst window decides the result. Truncating to the first window
    alone would let toxic content hide behind benign padding.

    Args:
        text (str): text to check, e.g. a response generated by the model.
        chunk_size (int): maximum number of characters sent to the model per
            call. The default of 500 keeps each window conservatively below
            BERT's 512-token input limit.

    Returns:
        tuple[bool, float, str]: ``(passed, score, message)``.
            - ``passed`` is True when ``score < 0.6``.
            - ``message`` is ``""`` on pass and
              ``AUTO_ANSWERS.LANGUAGE_INAPPROPRIATE.value`` on failure.
            - On any error the check fails closed and returns
              ``(False, 1.0, AUTO_ANSWERS.UNEXPECTED_ERROR.value)``.
    """
    try:
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive.")

        # Fixed-size character windows; an empty string is still scored once.
        windows = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] or [text]

        model = get_toxicity_model()
        # With top_k set at construction, a single-string call returns
        # [[{label, score}, ...]], hence the [0].
        toxicity_score = max(_window_toxicity_score(model(window)[0]) for window in windows)

        toxicity_passed = toxicity_score < 0.6
        message = "" if toxicity_passed else AUTO_ANSWERS.LANGUAGE_INAPPROPRIATE.value
        return toxicity_passed, toxicity_score, message

    except Exception as e:
        # Deliberately fail closed: any model/loading error blocks the text.
        print(f"Error while checking language: {e}")
        return False, 1.0, AUTO_ANSWERS.UNEXPECTED_ERROR.value
|
|
|
|
|
# Candidate Austrian SVNR: ten digits on word boundaries, optionally written
# as "NNNN DDMMYY" with one space. Each candidate is confirmed by is_valid_svnr().
_SVNR_CANDIDATE = re.compile(r"\b\d{4} ?\d{6}\b")


def sanitize(text: str) -> str:
    """
    Sanitizes a string by redacting emails and valid Austrian social security numbers (SVNRs).

    E-mail addresses are redacted first. An address whose local part happens
    to be a valid SVNR is therefore removed as a whole, rather than being
    turned into "[REDACTED SVNR]@domain" with the domain left exposed.

    SVNR candidates are then replaced only at the positions where they matched,
    so identical digits embedded in longer numbers are left untouched. A
    candidate is redacted only when ``is_valid_svnr`` accepts its ten digits.

    Args:
        text (str): input text.

    Returns:
        str: the text with e-mails replaced by "[REDACTED_EMAIL]" and valid
        SVNRs replaced by "[REDACTED SVNR]".
    """
    text = EMAIL_PATTERN.sub('[REDACTED_EMAIL]', text)

    def _redact_svnr(match):
        candidate = match.group(0)
        # is_valid_svnr receives the bare 10-digit form, as before.
        digits = candidate.replace(" ", "")
        return "[REDACTED SVNR]" if is_valid_svnr(digits) else candidate

    return _SVNR_CANDIDATE.sub(_redact_svnr, text)
|
|
|
|
|
class AUTO_ANSWERS(Enum):
    """Canned messages returned in place of a generated answer.

    Callers use the member's ``.value`` string. For example, check_toxicity
    returns ``LANGUAGE_INAPPROPRIATE.value`` or ``UNEXPECTED_ERROR.value``.
    All values are distinct, so lookup by value (``AUTO_ANSWERS(msg)``) works.
    """

    # Generation / request failures.
    COULD_NOT_GENERATE = "Could not generate an answer."
    REQUEST_TIMED_OUT = "The request timed out."
    UNEXPECTED_ERROR = "An unexpected error occurred."

    # Relevance outcomes.
    ANSWER_NOT_RELEVANT_TO_QUERY = "Answer is not relevant to the query."
    ANSWER_NOT_RELEVANT_TO_CONTEXT = "Answer not relevant to given context."
    RELEVANCE_CHECK_FAILED = "Relevance check failed."

    # Hallucination / grounding / correctness outcomes.
    HALLUCINATION_CHECK_FAILED = "Hallucination check failed."
    NOT_SUPPORTED_BY_CONTEXT = "Many facts not supported by context."
    INCONSISTENT_WITH_CONTEXT = "Inconsistent with context."
    CORRECTNESS_CHECK_FAILED = "Correctness check failed."

    # Input / language outcomes.
    LANGUAGE_INAPPROPRIATE = "Inappropriate language detected!"
    REPHRASE_SENTENCE = "Try to rephrase your request."
    INVALID_INPUT = "Invalid input detected."
|
|
|