| import argparse | |
| import torch | |
| import numpy as np | |
| import re | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Default cut-off used for --threshold: a label is assigned when its
# probability is strictly greater than this value.
DEFAULT_THRESHOLD: float = 0.5
def preprocess_text(text: str, anonymize_mentions: bool = True) -> str:
    """Return *text* prepared for the classifier.

    Args:
        text: Raw input text.
        anonymize_mentions: When true, every ``@`` followed by one or more
            word characters is replaced with ``@anonymized_account``. When
            false, *text* is returned unchanged.

    Returns:
        The (possibly rewritten) text.

    Note:
        The pattern is purely lexical, so the ``@domain`` part of an e-mail
        address is rewritten as well.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r'@\w+', '@anonymized_account', text)
def main():
    """Classify one text from the command line and print per-label results.

    Parses the CLI arguments, loads the tokenizer and sequence-classification
    model from ``--model-path``, optionally anonymizes @-mentions in the text,
    runs a single forward pass, turns each logit into a probability with an
    independent sigmoid, and prints the labels whose probability is strictly
    greater than ``--threshold`` followed by the probability of every label.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument(
        "--model-path",
        type=str,
        default="yazoniak/twitter-emotion-pl-classifier",
        help="Path to model or HF model ID",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        # Derived from the constant so the help text cannot drift from it.
        help=f"Classification threshold (default: {DEFAULT_THRESHOLD})",
    )
    parser.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable mention anonymization (not recommended)",
    )
    args = parser.parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()

    # Label names ordered by class index so they line up with the logits.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)
    # Only echo the preprocessed text when anonymization actually changed it.
    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")
    print(f"\nInput text: {args.text}\n")

    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=8192,
    )
    with torch.no_grad():
        outputs = model(**inputs)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the label axis of a single-label model into a 0-d array
        # and break the per-label indexing below.
        # torch.sigmoid avoids the overflow warning that the hand-written
        # 1 / (1 + exp(-x)) emits for large negative logits.
        probabilities = torch.sigmoid(outputs.logits.squeeze(0)).numpy()

    predictions = probabilities > args.threshold
    assigned_labels = [
        label for label, assigned in zip(labels, predictions) if assigned
    ]

    if assigned_labels:
        print("Assigned Labels:")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below threshold)\n")

    print("All Labels (with probabilities):")
    print("-" * 40)
    for label, probability, assigned in zip(labels, probabilities, predictions):
        status = "✓" if assigned else " "
        print(f"{status} {label:15s}: {probability:.4f}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()