import gradio as gr # --- Import necessary classes --- from transformers import ( TrOCRProcessor, VisionEncoderDecoderModel, pipeline, AutoTokenizer, RobertaForSequenceClassification, AutoConfig # <--- Import AutoConfig ) # --- from PIL import Image import traceback import warnings import json import os import shutil # Not used directly, but keep for potential manual use # --- Model IDs --- TROCR_MODELS = { "Printed Text": "microsoft/trocr-large-printed", "Handwritten": "microsoft/trocr-large-handwritten", } DETECTOR_MODEL_ID = "SuperAnnotate/roberta-large-llm-content-detector" print(f"Using AI Detector Model: {DETECTOR_MODEL_ID}") # --- Load OCR Models (no changes here) --- print("Loading OCR models...") OCR_PIPELINES = {} for name, model_id in TROCR_MODELS.items(): try: proc = TrOCRProcessor.from_pretrained(model_id) mdl = VisionEncoderDecoderModel.from_pretrained(model_id) OCR_PIPELINES[name] = (proc, mdl) print(f"Loaded {name} OCR model.") except Exception as e: print(f"Error loading OCR model {name} ({model_id}): {e}") # --- Explicitly load config, tokenizer, and model --- print(f"Loading AI detector components ({DETECTOR_MODEL_ID})...") DETECTOR_PIPELINE = None detector_tokenizer = None detector_model = None try: # 1. Load Configuration FIRST print("Loading detector config...") detector_config = AutoConfig.from_pretrained(DETECTOR_MODEL_ID) print(f"Loaded config. Expected hidden size: {detector_config.hidden_size}") # Should be 1024 # Add an assertion to halt if config is wrong (optional but helpful) if detector_config.hidden_size != 1024: raise ValueError(f"Loaded config specifies hidden size {detector_config.hidden_size}, but expected 1024 for roberta-large. Check cache for {DETECTOR_MODEL_ID}.") # 2. Load Tokenizer print("Loading detector tokenizer...") detector_tokenizer = AutoTokenizer.from_pretrained(DETECTOR_MODEL_ID) # 3. Load Model using the specific class AND the loaded config print("Loading detector model with loaded config...") detector_model = RobertaForSequenceClassification.from_pretrained( DETECTOR_MODEL_ID, config=detector_config # <--- Pass the loaded config ) print("AI detector model and tokenizer loaded successfully.") # 4. Create Pipeline print("Creating AI detector pipeline...") DETECTOR_PIPELINE = pipeline( "text-classification", model=detector_model, tokenizer=detector_tokenizer, top_k=None ) print("Created AI detector pipeline.") # --- Optional: Label Test (keep from previous version) --- if DETECTOR_PIPELINE: try: print("Testing detector pipeline labels...") sample_output = DETECTOR_PIPELINE("This is a reasonably long test sentence to check the model labels.", truncation=True) print(f"Sample detector output structure: {sample_output}") # ... (rest of label testing code) ... if sample_output and isinstance(sample_output, list) and len(sample_output) > 0: if isinstance(sample_output[0], list) and len(sample_output[0]) > 0: labels = [item.get('label', 'N/A') for item in sample_output[0] if isinstance(item, dict)] print(f"Detected labels from sample run: {labels}") elif isinstance(sample_output[0], dict): labels = [item.get('label', 'N/A') for item in sample_output if isinstance(item, dict)] print(f"Detected labels from sample run (non-nested): {labels}") if detector_model and detector_model.config and detector_model.config.id2label: print(f"Labels from model config: {detector_model.config.id2label}") # Should show {0: 'Human', 1: 'AI'} except Exception as test_e: print(f"Could not perform detector label test: {test_e}") traceback.print_exc() except Exception as e: print(f"CRITICAL Error loading AI detector components ({DETECTOR_MODEL_ID}): {e}") traceback.print_exc() # --- Simplified Cache Clearing Suggestion --- # Get cache path using environment variable or default hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")) hub_cache_path = os.path.join(hf_home, "hub") # Models are usually in the 'hub' subfolder print("\n--- TROUBLESHOOTING SUGGESTION ---") print(f"The model loading failed: {e}") print("\nThis *strongly* indicates a problem with the cached files for this model.") print("The most likely solution is to MANUALLY clear the cache for this model.") print(f"\n1. Stop this application.") print(f"2. Go to your Hugging Face hub cache directory (usually found under '{hub_cache_path}').") print(f" (If you've set HF_HOME environment variable, check there instead: '{hf_home}')") # Construct the model-specific cache folder name model_cache_folder_name = f"models--{DETECTOR_MODEL_ID.replace('/', '--')}" print(f"3. Delete the specific folder for this model: '{model_cache_folder_name}'") print(f" Full path example: {os.path.join(hub_cache_path, model_cache_folder_name)}") print(f"4. Restart the application. This will force a fresh download.") print("\nMake sure no other applications are using the cache while deleting.") print("--- END TROUBLESHOOTING ---") # --- # DETECTOR_PIPELINE remains None # --- Functions get_ai_and_human_scores, analyze_image, classify_text remain the same --- # (Ensure get_ai_and_human_scores correctly handles "AI" and "Human" based on config) def get_ai_and_human_scores(results): """ Processes detector results to get likelihood scores for both AI and Human classes. Handles various label formats including 'AI'/'Human', 'LABEL_0'/'LABEL_1', etc. Returns: tuple: (ai_display_string, human_display_string) """ ai_prob = 0.0 human_prob = 0.0 status_message = "Status: Initializing..." # Default status if not results: print("Warning: Received empty results for AI detection.") status_message = "Error: No results received" return status_message, "N/A" # Handle potential nested list structure score_list = [] if isinstance(results, list) and len(results) > 0: if isinstance(results[0], list) and len(results[0]) > 0: score_list = results[0] elif isinstance(results[0], dict): score_list = results else: status_message = f"Error: Unexpected detector output format (inner list type: {type(results[0])})" print(f"Warning: {status_message}. Results[0]: {results[0]}") return status_message, "N/A" else: status_message = f"Error: Unexpected detector output format (outer type: {type(results)})" print(f"Warning: {status_message}. Results: {results}") return status_message, "N/A" # Build label→score map (uppercase labels for robust matching) lbl2score = {} parse_errors = [] for entry in score_list: if isinstance(entry, dict) and "label" in entry and "score" in entry: try: score = float(entry["score"]) lbl2score[entry["label"].upper()] = score except (ValueError, TypeError): parse_errors.append(f"Invalid score format: {entry}") else: parse_errors.append(f"Invalid entry format: {entry}") if parse_errors: print(f"Warning: Encountered parsing errors in score list: {parse_errors}") if not lbl2score: status_message = "Error: Could not parse any valid scores from detector output" print(f"Warning: {status_message}. Score list was: {score_list}") return status_message, "N/A" label_keys_found = ", ".join(lbl2score.keys()) found_pair = False inferred = False # --- Determine AI and Human probabilities based on labels --- upper_keys = lbl2score.keys() # Prioritize AI/HUMAN as per model config if "AI" in upper_keys and "HUMAN" in upper_keys: ai_prob = lbl2score["AI"] human_prob = lbl2score["HUMAN"] found_pair = True status_message = "OK (Used AI/HUMAN labels)" # Fallbacks elif "LABEL_1" in upper_keys and "LABEL_0" in upper_keys: ai_prob = lbl2score["LABEL_1"] human_prob = lbl2score["LABEL_0"] found_pair = True status_message = "OK (Used LABEL_1/LABEL_0 - Check Mapping)" print("Warning: Used fallback LABEL_1/LABEL_0. Config expects AI/HUMAN.") # Add other fallbacks if necessary (FAKE/REAL, MACHINE/HUMAN) # Inference logic if not found_pair: if "AI" in upper_keys: ai_prob = lbl2score["AI"] human_prob = max(0.0, 1.0 - ai_prob) inferred = True status_message = "OK (Inferred from AI label)" elif "HUMAN" in upper_keys: human_prob = lbl2score["HUMAN"] ai_prob = max(0.0, 1.0 - human_prob) inferred = True status_message = "OK (Inferred from HUMAN label)" # Add fallback inference if needed if not inferred: status_message = f"Error: Could not determine AI/Human pair from labels [{label_keys_found}]" print(f"Warning: {status_message}") # --- Format output strings --- ai_display_str = f"{ai_prob*100:.2f}%" human_display_str = f"{human_prob*100:.2f}%" if "Error:" in status_message: ai_display_str = status_message human_display_str = "N/A" print(f"Score Status: {status_message}. AI={ai_display_str}, Human={human_display_str}") return ai_display_str, human_display_str # --- analyze_image function (no changes needed) --- def analyze_image(image: Image.Image, ocr_choice: str): """Performs OCR and AI Content Detection, returns both AI and Human %.""" extracted = "" ai_result_str = "N/A" human_result_str = "N/A" status_update = "Awaiting input..." if image is None: status_update = "Please upload an image first." return extracted, ai_result_str, human_result_str, status_update if not ocr_choice or ocr_choice not in TROCR_MODELS: status_update = "Please select a valid OCR model." return extracted, ai_result_str, human_result_str, status_update if OCR_PIPELINES.get(ocr_choice) is None: return "", "N/A", "N/A", f"Error: OCR model '{ocr_choice}' failed to load or is unavailable." if DETECTOR_PIPELINE is None: return "", "N/A", "N/A", f"Critical Error: AI Detector model ({DETECTOR_MODEL_ID}) failed during startup. Check logs for details (possible cache issue?)." try: status_update = f"Processing with {ocr_choice} OCR..." print(status_update) proc, mdl = OCR_PIPELINES[ocr_choice] if image.mode != "RGB": image = image.convert("RGB") pix = proc(images=image, return_tensors="pt").pixel_values tokens = mdl.generate(pix, max_length=1024) extracted = proc.batch_decode(tokens, skip_special_tokens=True)[0] extracted = extracted.strip() if not extracted: status_update = "OCR completed, but no text was extracted." print(status_update) return extracted, "N/A", "N/A", status_update status_update = f"Detecting AI/Human content in {len(extracted)} characters..." print(status_update) results = DETECTOR_PIPELINE(extracted) ai_result_str, human_result_str = get_ai_and_human_scores(results) if "Error:" in ai_result_str: status_update = ai_result_str else: status_update = "Analysis complete." print(f"Final Status: {status_update}") return extracted, ai_result_str, human_result_str, status_update except Exception as e: error_msg = f"Error during image analysis: {e}" print(error_msg) traceback.print_exc() status_update = error_msg return extracted, "Error", "Error", status_update # --- classify_text function (no changes needed) --- def classify_text(text: str): """Classifies provided text, returning both AI and Human %.""" ai_result_str = "N/A" human_result_str = "N/A" if DETECTOR_PIPELINE is None: return f"Critical Error: AI Detector model ({DETECTOR_MODEL_ID}) failed during startup. Check logs for details (possible cache issue?).", "N/A" if not text or text.isspace(): return "Please enter some text.", "N/A" print("Classifying text...") try: results = DETECTOR_PIPELINE(text) ai_result_str, human_result_str = get_ai_and_human_scores(results) if "Error:" not in ai_result_str: print("Classification complete.") else: print(f"Classification completed with issues: {ai_result_str}") return ai_result_str, human_result_str except Exception as e: error_msg = f"Error during text classification: {e}" print(error_msg) traceback.print_exc() return error_msg, "Error" # --- Gradio Interface (no changes needed) --- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( f""" ## OCR + AI/Human Content Detection Upload an image or paste text. The tool extracts text via OCR (if image) and analyzes it using an AI content detector (`{DETECTOR_MODEL_ID}`) to estimate the likelihood of it being AI-generated vs. Human-written. **Disclaimer:** AI content detection is challenging and not 100% accurate. These likelihoods are estimates based on the model's training data and may not be definitive. Performance varies with text type, length, and AI generation methods. **Label Assumption:** Uses the model's configured labels (`AI`/`Human`). Fallbacks for other label formats are included but may be less reliable if the model deviates from its configuration. """ ) with gr.Tab("Analyze Image"): with gr.Row(): with gr.Column(scale=2): img_in = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"]) with gr.Column(scale=1): ocr_dd = gr.Dropdown( list(TROCR_MODELS.keys()), label="1. Select OCR Model", info="Choose based on text type in image." ) run_btn = gr.Button("2. Analyze Image", variant="primary") status_img = gr.Label(value="Awaiting image analysis...", label="Status") with gr.Row(): text_out_img = gr.Textbox(label="Extracted Text", lines=10, interactive=False) with gr.Column(scale=1): ai_out_img = gr.Textbox(label="AI Likelihood %", interactive=False) with gr.Column(scale=1): human_out_img = gr.Textbox(label="Human Likelihood %", interactive=False) run_btn.click( fn=analyze_image, inputs=[img_in, ocr_dd], outputs=[text_out_img, ai_out_img, human_out_img, status_img], queue=True ) with gr.Tab("Classify Text"): with gr.Column(): text_in_classify = gr.Textbox(label="Paste or type text here", lines=10) classify_btn = gr.Button("Classify Text", variant="primary") with gr.Row(): with gr.Column(scale=1): ai_out_classify = gr.Textbox(label="AI Likelihood %", interactive=False) with gr.Column(scale=1): human_out_classify = gr.Textbox(label="Human Likelihood %", interactive=False) classify_btn.click( fn=classify_text, inputs=[text_in_classify], outputs=[ai_out_classify, human_out_classify], queue=True ) gr.HTML(f"") if __name__ == "__main__": print("Starting Gradio demo...") demo.launch(share=False, server_name="0.0.0.0")