| | import numpy as np |
| | import json |
| | from huggingface_hub import hf_hub_download |
| | import re |
| | import emoji |
| | from transformers import BertTokenizer |
| | import onnxruntime as ort |
| |
|
def preprocess_text(text):
    """Normalize raw text the same way the training data was normalized.

    Steps, applied in this order:
      1. Reddit user mentions (``u/name``) -> ``[USER]``
      2. Subreddit mentions (``r/name``) -> ``[SUBREDDIT]``
      3. ``http://`` / ``https://`` links -> ``[URL]``
      4. Emoji -> their textual names, space-delimited (``emoji.demojize``)
      5. Lowercase the whole string. This also lowercases the placeholders,
         so the output contains ``[user]``, ``[subreddit]`` and ``[url]``.

    Args:
        text (str): Raw input text.

    Returns:
        str: The normalized text.
    """
    # Order matters: the mention patterns run before the URL pattern, exactly
    # as in the original pipeline.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    for pattern, placeholder in substitutions:
        text = re.sub(pattern, placeholder, text)
    # NOTE(review): the mention patterns have no word boundary, so they also
    # match mid-word (e.g. "menu/x" -> "men[USER]"). They are kept as-is,
    # presumably because training used the same patterns; confirm before changing.
    return emoji.demojize(text, delimiters=(" ", " ")).lower()
| |
|
def load_model_and_resources(repo_id="logasanjeev/emotions-analyzer-bert"):
    """Load the ONNX model, tokenizer, emotion labels, and thresholds from Hugging Face.

    Args:
        repo_id (str): Hugging Face Hub repository that holds the tokenizer,
            ``model.onnx`` and ``optimized_thresholds.json``. Defaults to the
            repository this script was written for.

    Returns:
        tuple: (session, tokenizer, emotion_labels, thresholds)
            - session: ``onnxruntime.InferenceSession`` built from ``model.onnx``.
            - tokenizer: ``BertTokenizer`` loaded from ``repo_id``.
            - emotion_labels: the ``"emotion_labels"`` entry of the thresholds
              JSON (presumably in model-output order; verify against the repo).
            - thresholds: the ``"thresholds"`` entry of the same JSON, the
              same length as ``emotion_labels``.

    Raises:
        RuntimeError: If any artifact fails to download or load, or if
            ``optimized_thresholds.json`` is malformed. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        raise RuntimeError(f"Error loading tokenizer: {str(e)}") from e

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename="model.onnx")
        session = ort.InferenceSession(model_path)
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model: {str(e)}") from e

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(thresholds_file, "r", encoding="utf-8") as f:
            thresholds_data = json.load(f)
        if not (
            isinstance(thresholds_data, dict)
            and "emotion_labels" in thresholds_data
            and "thresholds" in thresholds_data
        ):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
        # Labels and thresholds are paired positionally downstream; a length
        # mismatch would otherwise be silently truncated by zip().
        if len(emotion_labels) != len(thresholds):
            raise ValueError(
                f"optimized_thresholds.json has {len(emotion_labels)} emotion labels "
                f"but {len(thresholds)} thresholds; they must match."
            )
    except Exception as e:
        # Any failure above, including the format checks, is reported uniformly.
        raise RuntimeError(f"Error loading thresholds: {str(e)}") from e

    return session, tokenizer, emotion_labels, thresholds
| |
|
# Lazily-populated model resources. They stay None until predict_emotions()
# first calls load_model_and_resources().
SESSION = TOKENIZER = EMOTION_LABELS = THRESHOLDS = None
| |
|
def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT ONNX model.

    On the first call, the model, tokenizer, labels and thresholds are loaded
    and cached in module globals.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): One ``"<emotion>: <score>"`` line per emotion
              whose sigmoid score meets its threshold, sorted by score
              (highest first). If none qualify, the string
              ``"No emotions predicted."``.
            - processed_text (str): The preprocessed input text.

    Raises:
        RuntimeError: Propagated from ``load_model_and_resources`` if loading fails.
        ValueError: If the model's output width, the threshold count and the
            label count disagree.
    """
    global SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS

    # Load lazily so that importing this module stays cheap.
    # NOTE: there is no lock here, so concurrent first calls may each load the model.
    if SESSION is None:
        SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = load_model_and_resources()

    processed_text = preprocess_text(text)

    encodings = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='np'
    )

    # Cast explicitly: the exported graph presumably declares int64 inputs (confirm).
    inputs = {
        'input_ids': encodings['input_ids'].astype(np.int64),
        'attention_mask': encodings['attention_mask'].astype(np.int64)
    }

    # First model output, first (and only) row of the batch of one.
    logits = SESSION.run(None, inputs)[0][0]
    # Independent sigmoid per label. For very negative logits np.exp overflows
    # to inf, but 1 / (1 + inf) is still the correct limit 0.0, so the
    # overflow warning is just noise.
    with np.errstate(over="ignore"):
        probs = 1 / (1 + np.exp(-logits))

    if not (len(probs) == len(THRESHOLDS) == len(EMOTION_LABELS)):
        raise ValueError(
            f"Model produced {len(probs)} scores but {len(THRESHOLDS)} thresholds "
            f"and {len(EMOTION_LABELS)} labels are configured."
        )

    predictions = [
        (label, round(prob, 4))
        for label, prob, thresh in zip(EMOTION_LABELS, probs, THRESHOLDS)
        if prob >= thresh
    ]
    # Stable sort on the rounded score, so ties keep label order.
    predictions.sort(key=lambda item: item[1], reverse=True)

    result = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions) or "No emotions predicted."
    return result, processed_text
| |
|
if __name__ == "__main__":
    # Command-line entry point: analyze one piece of text passed as an argument.
    import argparse

    cli = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT ONNX model.")
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    raw_text = cli.parse_args().text

    emotions, cleaned = predict_emotions(raw_text)
    # Print the raw input, the normalized input, a header, then the predictions.
    report = [
        f"Input: {raw_text}",
        f"Processed: {cleaned}",
        "Predicted Emotions:",
        emotions,
    ]
    print("\n".join(report))