cheesecz committed on
Commit
4bf9f99
·
verified ·
1 Parent(s): 7cf74f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -30
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
- import gradio as gr
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
  import warnings
@@ -19,6 +20,11 @@ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
19
  # Move model to appropriate device
20
  model = model.to(device)
21
 
 
 
 
 
 
22
  def calculate_formality_percentages(score):
23
  # Convert score to grayscale percentage (0-100)
24
  grayscale = int(score * 100)
@@ -27,10 +33,15 @@ def calculate_formality_percentages(score):
27
  informal_percent = 100 - grayscale
28
  return formal_percent, informal_percent
29
 
30
- def predict_formality(text):
 
 
 
 
 
31
  try:
32
  # Tokenize input
33
- encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
34
  encoding = {k: v.to(device) for k, v in encoding.items()}
35
 
36
  # Predict formality score
@@ -51,32 +62,8 @@ def predict_formality(text):
51
 
52
  return response
53
  except Exception as e:
54
- return {
55
- "error": str(e),
56
- "formality_score": 0,
57
- "formal_percent": 0,
58
- "informal_percent": 0,
59
- "classification": "Error processing the text."
60
- }
61
-
62
- # Create Gradio interface
63
- demo = gr.Interface(
64
- fn=predict_formality,
65
- inputs=gr.Textbox(label="Enter your text", lines=3),
66
- outputs=gr.JSON(label="Formality Analysis"),
67
- title="Formality Classifier",
68
- description="Enter text to analyze its formality level.",
69
- examples=[
70
- ["Hello, how are you doing today?"],
71
- ["Hey, what's up?"],
72
- ["I would like to request your assistance with this matter."]
73
- ]
74
- )
75
 
76
- # Launch the app
77
  if __name__ == "__main__":
78
- demo.queue() # Enable request queuing
79
- demo.launch(
80
- server_name="0.0.0.0",
81
- server_port=7860
82
- )
 
1
  import os
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
  import warnings
 
20
  # Move model to appropriate device
21
  model = model.to(device)
22
 
23
# FastAPI application object; the route handlers below register themselves
# on this instance via the @app.* decorators.
app = FastAPI(title="Formality Classifier API")
24
+
25
class TextInput(BaseModel):
    """Request body for the /predict endpoint."""

    # Raw text whose formality will be scored.
    text: str
27
+
28
  def calculate_formality_percentages(score):
29
  # Convert score to grayscale percentage (0-100)
30
  grayscale = int(score * 100)
 
33
  informal_percent = 100 - grayscale
34
  return formal_percent, informal_percent
35
 
36
@app.get("/")
async def home():
    """Root endpoint: confirms the service is up and points callers at /predict."""
    greeting = "Formality Classifier API is running! Use /predict to classify text."
    return {"message": greeting}
39
+
40
+ @app.post("/predict")
41
+ async def predict_formality(input_data: TextInput):
42
  try:
43
  # Tokenize input
44
+ encoding = tokenizer(input_data.text, return_tensors="pt", truncation=True, padding=True)
45
  encoding = {k: v.to(device) for k, v in encoding.items()}
46
 
47
  # Predict formality score
 
62
 
63
  return response
64
  except Exception as e:
65
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
67
if __name__ == "__main__":
    # Deferred import: uvicorn is only needed when running the API directly,
    # not when this module is imported by an external ASGI server.
    import uvicorn
    # 0.0.0.0 binds all interfaces; 7860 is the conventional Hugging Face
    # Spaces port — presumably this app is deployed as a Space (TODO confirm).
    uvicorn.run(app, host="0.0.0.0", port=7860)