Spaces:
Sleeping
Sleeping
# app.py — AI-text detector with image-app style output (MIN_WORDS minimum, default 40; pulse loader; HTML result card)
| import os, sys, traceback | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from huggingface_hub import HfApi | |
| from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError | |
# -------- CONFIG --------
# Hub repo id of the sequence-classification model to load (overridable via env).
MODEL_ID = os.environ.get("MODEL_ID", "AICodexLab/answerdotai-ModernBERT-base-ai-detector")
# Hugging Face access token; add it as a Space secret if the model repo is private.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Minimum number of whitespace-separated words required before analysis runs.
MIN_WORDS = int(os.environ.get("MIN_WORDS", 40))
device = "cuda" if torch.cuda.is_available() else "cpu"
# Extra keyword arguments for `from_pretrained`. `token` is the current name of
# the authentication argument; `use_auth_token` is deprecated in transformers.
auth_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
# Filled in by try_load_model(); load_error holds a message when loading fails.
model = None
tokenizer = None
load_error = None
# -------- MODEL LOADING (safe) --------
def try_load_model():
    """Load the tokenizer and classifier for MODEL_ID into module globals.

    On success, ``tokenizer`` and ``model`` are set (the model is moved to
    ``device`` and put in eval mode) and ``load_error`` is cleared. On any
    exception, a readable message is stored in ``load_error`` and the
    traceback is printed instead of raising, so the UI can still start and
    report the problem.
    """
    global model, tokenizer, load_error
    try:
        hub = HfApi(token=HF_TOKEN) if HF_TOKEN else HfApi()
        # Query the Hub first so a missing/private repo fails with a clear error.
        hub.model_info(MODEL_ID)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **auth_kwargs)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, **auth_kwargs)
        model.to(device)
        model.eval()
        load_error = None
        print(f"[INFO] Model loaded: {MODEL_ID}", flush=True)
    except Exception as exc:
        # The most specific hub error is tested first; RepositoryNotFoundError is
        # presumably a subclass of HfHubHTTPError — confirm if reordering.
        if isinstance(exc, RepositoryNotFoundError):
            load_error = f"Repository not found or access denied: {MODEL_ID}. {exc}"
        elif isinstance(exc, HfHubHTTPError):
            load_error = f"Hugging Face Hub HTTP error: {exc}"
        else:
            load_error = f"Failed to load model {MODEL_ID}: {exc!r}"
        traceback.print_exc()


# Load once at import time so the first request does not pay the load cost.
try_load_model()
# -------- HELPERS --------
def count_words(text: str) -> int:
    """Return the number of whitespace-separated words in *text*.

    ``None`` and empty or whitespace-only strings count as zero words.
    """
    if not text:
        return 0
    # str.split() with no separator ignores leading/trailing whitespace and
    # never yields empty strings, so no extra strip()/filter is needed.
    return len(text.split())
def make_result_html(verdict_text: str, color: str):
    """Build the styled HTML card that displays a verdict.

    ``color`` is used for the border, gradient and glow, with two hex digits
    appended as an alpha channel (``33`` / ``55``). It is therefore presumably
    expected to be a ``#RRGGBB`` string. ``verdict_text`` is inserted verbatim
    (not HTML-escaped). The ``fadeIn`` animation name refers to keyframes
    defined in the app's custom CSS.
    """
    declarations = [
        f"background: linear-gradient(135deg, {color}33, #1a1a1a);",
        f"border: 2px solid {color};",
        "border-radius: 15px;",
        "padding: 22px;",
        "text-align: center;",
        "color: white;",
        "font-size: 20px;",
        "font-weight: 700;",
        f"box-shadow: 0 0 20px {color}55;",
        "animation: fadeIn 0.6s ease-in-out;",
    ]
    style_block = "\n".join(declarations)
    opening = "\n<div class='result-box' style=\"\n" + style_block + "\n\">\n"
    return opening + f"{verdict_text}\n</div>\n"
# -------- PREDICTION (generator to show loader then result) --------
def analyze_text(text: str):
    """Classify *text* as AI-generated or human-written, streaming UI updates.

    Generator bound to the ``(loader, result_html)`` outputs. It first yields
    the pulse loader with an empty result. It then yields one final
    ``("", html)`` pair that clears the loader and shows either an
    error/warning message or the verdict card from ``make_result_html``.

    Args:
        text: Raw user input from the textbox.

    Yields:
        ``(loader_html, result_html)`` tuples of HTML strings.
    """
    # Error text can contain '<' / '&' (e.g. object reprs) and must not break the markup.
    from html import escape

    # Show the loader immediately while checks and inference run.
    yield ("<div id='pulse-loader'></div>", "")

    if load_error:
        yield ("", f"<div style='color:#ff4d4f;font-weight:700;'>Model load error: {escape(load_error)}</div>")
        return
    if not text or not text.strip():
        yield ("", "<div style='color:red;'>Please enter some text first.</div>")
        return
    wc = count_words(text)
    if wc < MIN_WORDS:
        yield ("", f"<div style='color:#ff9900;font-weight:600;'>⚠️ Please enter at least {MIN_WORDS} words (currently {wc}).</div>")
        return

    try:
        # Inputs longer than 512 tokens are truncated; only the prefix is scored.
        inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt", max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        if logits.shape[-1] == 1:
            # Single-logit head: softmax over one class is always exactly 1.0,
            # so score it with a sigmoid instead. The logit is treated as the
            # AI score, matching the original single-output mapping.
            ai_prob = float(torch.sigmoid(logits[0, 0]))
            human_prob = 1.0 - ai_prob
        else:
            probs = F.softmax(logits, dim=-1)[0]
            # NOTE(review): assumes label 0 = human and label 1 = AI;
            # confirm against model.config.id2label for the chosen MODEL_ID.
            human_prob = float(probs[0])
            ai_prob = float(probs[1])

        # Report the confidence of whichever label wins (ties go to AI).
        if ai_prob >= human_prob:
            label, conf, color = "AI-generated", ai_prob * 100.0, "#007BFF"  # blue
        else:
            label, conf, color = "Human", human_prob * 100.0, "#4CAF50"  # green
        yield ("", make_result_html(f"{label} ({conf:.1f}% confidence)", color))
    except Exception as e:
        traceback.print_exc()
        yield ("", f"<div style='color:red;'>Error analyzing text: {escape(str(e))}</div>")
# -------- CSS (pulse loader + theme) --------
# Custom stylesheet passed to gr.Blocks(css=...).
# - #pulse-loader: the bar that analyze_text emits while it runs; the `pulse`
#   keyframes animate its horizontal scale.
# - fadeIn: keyframes referenced by the inline style built in make_result_html.
# NOTE(review): `.gr-button-primary` / `.gr-button-secondary` are assumed to be
# the classes Gradio puts on primary/secondary buttons; newer Gradio releases
# may use different class names — confirm these rules still match.
# NOTE(review): 'Poppins' is not loaded here; it only applies if the font is
# already available, otherwise the sans-serif fallback is used.
css = """
body, .gradio-container {
font-family: 'Poppins', sans-serif !important;
background: transparent !important;
}
h1 {
text-align: center;
font-weight: 700;
color: #007BFF;
margin-bottom: 10px;
}
.gr-button-primary {
background-color: #007BFF !important;
color: white !important;
font-weight: 600;
border-radius: 10px;
height: 48px;
}
.gr-button-secondary {
background-color: #dc3545 !important;
color: white !important;
border-radius: 10px;
height: 48px;
}
#pulse-loader {
width: 100%;
height: 6px;
background: linear-gradient(90deg, #007BFF, #00C3FF);
animation: pulse 1.2s infinite ease-in-out;
border-radius: 4px;
box-shadow: 0 0 10px #007BFF;
margin-top: 6px;
margin-bottom: 6px;
}
@keyframes pulse {
0% { transform: scaleX(0.05); opacity: 0.6; }
50% { transform: scaleX(1); opacity: 1; }
100% { transform: scaleX(0.05); opacity: 0.6; }
}
@keyframes fadeIn {
from { opacity: 0; transform: scale(0.95); }
to { opacity: 1; transform: scale(1); }
}
.result-box { /* fallback style for non-inline versions */ }
"""
# -------- GRADIO APP --------
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>🔎 AI Text Detector</h1>")
    gr.Markdown("Detect whether a given text is AI-generated or human-written using the ModernBERT model.")
    with gr.Row():
        with gr.Column(scale=1):
            # Label follows the configurable MIN_WORDS instead of a hard-coded 40,
            # so it always agrees with the check performed in analyze_text.
            txt = gr.Textbox(
                label=f"Enter your text (at least {MIN_WORDS} words)",
                lines=10,
                placeholder="Paste your paragraph here...",
            )
            analyze_btn = gr.Button("Analyze", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")
            loader = gr.HTML("")  # receives the pulse loader while analysis runs
        with gr.Column(scale=1):
            result_html = gr.HTML(label="Result")

    # --- Analyze button ---
    # analyze_text is a generator: its first yield fills the loader, and its last
    # yield clears the loader and fills the result panel.
    analyze_btn.click(analyze_text, inputs=txt, outputs=[loader, result_html])

    # --- Clear button: resets input, loader and result ---
    def clear_all():
        """Return empty values for the textbox, loader and result panel."""
        return "", "", ""

    clear_btn.click(clear_all, outputs=[txt, loader, result_html])
if __name__ == "__main__":
    # analyze_text is a generator (streaming) handler, which needs the request
    # queue. The queue is on by default in Gradio 4+, but in 3.x it must be
    # enabled explicitly. queue() returns the Blocks, so launch() chains onto it.
    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))