cheesecz commited on
Commit
038a896
·
verified ·
1 Parent(s): 42768f4

create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+ import json
6
+
7
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
8
+
9
+ MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model = model.to(device)
15
+
16
+ def predict_formality(text):
17
+ # Tokenize input
18
+ encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
19
+ encoding = {k: v.to(device) for k, v in encoding.items()}
20
+
21
+ # Predict formality score
22
+ with torch.no_grad():
23
+ logits = model(**encoding).logits
24
+ score = logits.softmax(dim=1)[:, 1].item()
25
+
26
+ # Calculate percentages
27
+ formal_percent = round(score * 100)
28
+ informal_percent = 100 - formal_percent
29
+
30
+ # Create response in the new format
31
+ response = {
32
+ "formality_score": round(score, 3),
33
+ "formal_percent": formal_percent,
34
+ "informal_percent": informal_percent,
35
+ "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal."
36
+ }
37
+
38
+ return response
39
+
40
+ demo = gr.Interface(
41
+ fn=predict_formality,
42
+ inputs=gr.Textbox(label="Enter your text", lines=3),
43
+ outputs=gr.JSON(label="Formality Analysis"),
44
+ title="Formality Classifier",
45
+ description="Enter text to analyze its formality level.",
46
+ examples=[
47
+ ["Hello, how are you doing today?"],
48
+ ["Hey, what's up?"],
49
+ ["I would like to request your assistance with this matter."]
50
+ ]
51
+ )
52
+
53
+ # Launch the app
54
+ if __name__ == "__main__":
55
+ demo.launch(server_name="0.0.0.0", server_port=7860)