# RAG/helper.py
from enum import Enum
import re
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer
from functools import lru_cache
from guards.svnr import is_valid_svnr

MODEL_NAME = 'google/gemma-2-2b-it'
ROLE_ASSISTANT = "assistant"
EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")

@lru_cache(maxsize=1)
def get_similarity_model():
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="cpu")

@lru_cache(maxsize=1)
def get_toxicity_model():
    return pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        top_k=None
    )

@lru_cache(maxsize=1)
def get_hallucination_model():
    return pipeline(
        "text-classification",
        model='vectara/hallucination_evaluation_model',
        tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-base'),
        trust_remote_code=True
    )
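
# Usage sketch (hedged: the prompt format follows the HHEM model card and is an
# assumption here, not something defined in this module). The pipeline scores a
# prompted (premise, hypothesis) pair, roughly like:
#   prompt = ("<pad> Determine if the hypothesis is true given the premise?\n\n"
#             "Premise: {p}\n\nHypothesis: {h}")
#   scores = get_hallucination_model()(prompt.format(p=context, h=answer), top_k=None)
# The exact prompt template and output labels depend on the model version.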

def check_toxicity(text: str):
    """Check text for toxic language using the toxicity classification pipeline.

    Args:
        text (str): text to check (typically the model-generated response)

    Returns:
        tuple[bool, float, str]: (passed, toxicity score, auto-answer text if the check failed, else "")
    """
    try:
        # The token count can exceed the character count, so crop the text well
        # below the model's 512-token input limit.
        if len(text) > 512:
            text = text[:500]
        result = get_toxicity_model()(text)[0]
        if result:
            # Take the first label score above 0.85 if any, otherwise the mean of all label scores.
            toxicity_score = next((s['score'] for s in result if s['score'] > 0.85), sum(s['score'] for s in result) / len(result))
            toxicity_passed = toxicity_score < 0.6
            text = AUTO_ANSWERS.LANGUAGE_INAPPROPRIATE.value if not toxicity_passed else ""
            return toxicity_passed, toxicity_score, text
        else:
            raise ValueError("Could not yield result.")
    except Exception as e:
        print(f"Error while checking language: {e}")
        return False, 1.0, AUTO_ANSWERS.UNEXPECTED_ERROR.value
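
# Example (hypothetical caller, not defined here): gate a generated answer before returning it.
#   passed, score, message = check_toxicity(generated_answer)
#   if not passed:
#       generated_answer = message  # AUTO_ANSWERS.LANGUAGE_INAPPROPRIATE.value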

def sanitize(text: str) -> str:
    """
    Sanitizes a string by redacting emails and valid Austrian social security numbers (SVNRs).
    """
    # Redact valid 10-digit SVNRs
    potential_svnrs = re.findall(r'\b\d{10}\b', text)
    for svnr in potential_svnrs:
        if is_valid_svnr(svnr):
            text = text.replace(svnr, "[REDACTED SVNR]")
    # Redact emails
    return EMAIL_PATTERN.sub('[REDACTED_EMAIL]', text)
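
# Example: sanitize("Mail me at jane.doe@example.com")
#   -> "Mail me at [REDACTED_EMAIL]"
# SVNRs are only redacted when the 10-digit candidate passes is_valid_svnr.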

class AUTO_ANSWERS(Enum):
    COULD_NOT_GENERATE = "Could not generate an answer."
    REQUEST_TIMED_OUT = "The request timed out."
    UNEXPECTED_ERROR = "An unexpected error occurred."
    ANSWER_NOT_RELEVANT_TO_QUERY = "Answer is not relevant to the query."
    ANSWER_NOT_RELEVANT_TO_CONTEXT = "Answer not relevant to given context."
    RELEVANCE_CHECK_FAILED = "Relevance check failed."
    HALLUCINATION_CHECK_FAILED = "Hallucination check failed."
    NOT_SUPPORTED_BY_CONTEXT = "Many facts not supported by context."
    INCONSISTENT_WITH_CONTEXT = "Inconsistent with context."
    CORRECTNESS_CHECK_FAILED = "Correctness check failed."
    LANGUAGE_INAPPROPRIATE = "Inappropriate language detected!"
    REPHRASE_SENTENCE = "Try to rephrase your request."
    INVALID_INPUT = "Invalid input detected."
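

# Minimal smoke test: a hedged sketch, assuming the pipelines above can be
# downloaded and guards.svnr is importable; the sample strings are made up.
if __name__ == "__main__":
    sample = "Contact me at jane.doe@example.com for details."
    print(sanitize(sample))  # email should appear as [REDACTED_EMAIL]

    passed, score, message = check_toxicity("Thank you, that answer was very helpful!")
    print(f"toxicity passed={passed}, score={score:.3f}, message={message!r}")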