# yazoniak's picture
# Repo initialized
# 7336cba verified
import argparse
import torch
import numpy as np
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Default probability cutoff above which a label is assigned; used as the
# default value of the --threshold command-line option.
DEFAULT_THRESHOLD = 0.5
def preprocess_text(text, anonymize_mentions=True):
    """Return *text* with @mentions optionally masked.

    When ``anonymize_mentions`` is true, every ``@`` followed by one or more
    word characters is replaced with the literal ``@anonymized_account``.
    Otherwise the text is returned unchanged.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r'@\w+', '@anonymized_account', text)
def main():
    """Classify one text from the command line and print per-label results.

    Parses the CLI arguments, loads the tokenizer and sequence-classification
    model from ``--model-path``, optionally anonymizes @mentions in the input,
    and runs a single forward pass. An element-wise sigmoid is applied to the
    logits, so every label is scored independently; labels whose probability
    is strictly greater than ``--threshold`` are reported as assigned. The
    full table of labels and probabilities is printed afterwards.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument(
        "--model-path",
        type=str,
        default="yazoniak/twitter-emotion-pl-classifier",
        help="Path to model or HF model ID",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        # %(default)s keeps the help text in sync with DEFAULT_THRESHOLD.
        help="Classification threshold (default: %(default)s)",
    )
    parser.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable mention anonymization (not recommended)",
    )
    args = parser.parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()

    # Label names ordered by class index, so they line up with the logits.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)
    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")
    print(f"\nInput text: {args.text}\n")

    # NOTE(review): max_length=8192 assumes the loaded model accepts sequences
    # that long -- confirm against the model's configuration.
    inputs = tokenizer(
        processed_text, return_tensors="pt", truncation=True, max_length=8192
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the leading batch axis (a single text is passed in). A bare
    # squeeze() would also collapse the label axis for a single-label model,
    # leaving a 0-d array that cannot be indexed per label below.
    logits = outputs.logits.squeeze(0).numpy()
    probabilities = 1 / (1 + np.exp(-logits))  # element-wise sigmoid
    predictions = probabilities > args.threshold

    assigned_labels = [
        label for label, is_assigned in zip(labels, predictions) if is_assigned
    ]
    if assigned_labels:
        print("Assigned Labels:")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below threshold)\n")

    print("All Labels (with probabilities):")
    print("-" * 40)
    for label, is_assigned, probability in zip(labels, predictions, probabilities):
        status = "✓" if is_assigned else " "
        print(f"{status} {label:15s}: {probability:.4f}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()