Spaces:
Running
on
Zero
Running
on
Zero
import functools
import json
import os
import subprocess
import tempfile

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from PIL import Image, ImageDraw
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Human-readable model label (as shown in the dropdown) -> Hugging Face Hub
# repository id passed to from_pretrained(). dict preserves insertion order,
# so the dropdown lists the models in exactly the order written below.
MODEL_OPTIONS = dict(
    [
        ("Microsoft Handwritten", "microsoft/trocr-base-handwritten"),
        ("Medieval Base", "medieval-data/trocr-medieval-base"),
        ("Medieval Latin Caroline", "medieval-data/trocr-medieval-latin-caroline"),
        ("Medieval Castilian Hybrida", "medieval-data/trocr-medieval-castilian-hybrida"),
        ("Medieval Humanistica", "medieval-data/trocr-medieval-humanistica"),
        ("Medieval Textualis", "medieval-data/trocr-medieval-textualis"),
        ("Medieval Cursiva", "medieval-data/trocr-medieval-cursiva"),
        ("Medieval Semitextualis", "medieval-data/trocr-medieval-semitextualis"),
        ("Medieval Praegothica", "medieval-data/trocr-medieval-praegothica"),
        ("Medieval Semihybrida", "medieval-data/trocr-medieval-semihybrida"),
        ("Medieval Print", "medieval-data/trocr-medieval-print"),
    ]
)
@functools.lru_cache(maxsize=2)
def load_model(model_name):
    """Load the TrOCR processor and model for *model_name*, with caching.

    Results are memoized per model name, so repeated transcriptions with the
    same model do not re-read the checkpoint. The cache is bounded so that
    switching through every entry of MODEL_OPTIONS does not keep all models
    resident at once. The returned objects are shared between calls and
    must not be mutated by callers.

    Args:
        model_name: a key of MODEL_OPTIONS.

    Returns:
        (processor, model): the TrOCRProcessor and the
        VisionEncoderDecoderModel. The model is moved to CUDA when
        torch reports it available, otherwise to CPU.

    Raises:
        KeyError: if model_name is not a key of MODEL_OPTIONS (exceptions are
        not cached by lru_cache).
    """
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    # Prefer the GPU when present; callers must send inputs to model.device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return processor, model.to(device)
def detect_lines(image_path, timeout=300):
    """Detect text-line regions in an image via the hosted Kraken API.

    The image file is uploaded to the remote ``detect_lines`` endpoint using
    the ``catmus-medieval.mlmodel`` segmentation model.

    Args:
        image_path: path of the image file to upload (sent as ``image/jpeg``).
        timeout: seconds to wait for the HTTP request, forwarded to
            ``requests.post``. Pass None to wait indefinitely.

    Returns:
        The ``result["lines"]`` list from the JSON response. Callers in this
        file read a ``"boundary"`` list of (x, y) points from each entry.

    Raises:
        requests.HTTPError: if the service responds with an error status.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    url = "https://wjbmattingly-kraken-api.hf.space/detect_lines"
    # Segmentation model to run on the server side.
    data = {'model_name': 'catmus-medieval.mlmodel'}
    # Context manager closes the upload handle even if the request fails.
    with open(image_path, 'rb') as image_file:
        files = {'file': ('ms.jpg', image_file, 'image/jpeg')}
        response = requests.post(url, files=files, data=data, timeout=timeout)
    # Surface HTTP errors directly instead of as a JSON/KeyError failure.
    response.raise_for_status()
    return response.json()["result"]["lines"]
def extract_line_images(image, lines):
    """Crop each detected line out of *image*, blanking pixels outside its polygon.

    For every line, the polygon's bounding box is cropped from the image.
    Pixels that fall outside the polygon are replaced with white, so
    neighbouring lines inside the box do not leak into the crop.

    Args:
        image: PIL image that the line coordinates refer to. Any mode is
            accepted; it is converted to RGB before cropping.
        lines: iterable of dicts, each with a ``'boundary'`` sequence of
            (x, y) points.

    Returns:
        list of RGB PIL images, one per entry of *lines*, in the same order.
    """
    # The masking below produces RGB crops; converting once up front also
    # makes grayscale/RGBA/palette inputs work.
    rgb_image = image.convert('RGB')
    line_images = []
    for line in lines:
        polygon = line['boundary']
        x_coords, y_coords = zip(*polygon)
        x1, y1 = int(min(x_coords)), int(min(y_coords))
        # +1: PIL crop boxes are right/bottom-exclusive, so without it the
        # polygon's right-most column and bottom-most row would be cut off.
        x2, y2 = int(max(x_coords)) + 1, int(max(y_coords)) + 1
        box_size = (x2 - x1, y2 - y1)
        line_image = rgb_image.crop((x1, y1, x2, y2))
        # Binary mask: 255 inside the polygon, 0 elsewhere.
        mask = Image.new('L', box_size, 0)
        adjusted_polygon = [(int(x - x1), int(y - y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)
        # Keep the line pixels where mask == 255, white everywhere else.
        background = Image.new('RGB', box_size, (255, 255, 255))
        line_images.append(Image.composite(line_image, background, mask))
    return line_images
def visualize_lines(image, lines):
    """Return a copy of *image* with every line boundary outlined in red.

    The input image itself is left unmodified.

    Args:
        image: PIL image the line coordinates refer to.
        lines: iterable of dicts, each with a ``'boundary'`` sequence of
            (x, y) points.
    """
    annotated = image.copy()
    pen = ImageDraw.Draw(annotated)
    # Drawing needs integer pixel coordinates.
    polygons = (
        [(int(px), int(py)) for px, py in entry['boundary']]
        for entry in lines
    )
    for points in polygons:
        pen.polygon(points, outline="red")
    return annotated
def transcribe_lines(line_images, model_name):
    """Transcribe each cropped line image with the selected TrOCR model.

    Args:
        line_images: iterable of PIL images, one per text line.
        model_name: a key of MODEL_OPTIONS, forwarded to load_model.

    Returns:
        list of decoded strings, in the same order as *line_images*.
    """
    # NOTE(review): `spaces` is imported at file level but nothing is
    # decorated with @spaces.GPU; on ZeroGPU hardware that decorator is
    # presumably required for real GPU access — confirm.
    processor, model = load_model(model_name)
    transcriptions = []
    for line_image in line_images:
        # load_model may place the model on CUDA; the inputs must live on the
        # same device, or generate() fails with a device-mismatch error.
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values.to(model.device)
        # No generation kwargs are passed, so the checkpoint's own generation
        # config decides the decoding strategy.
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    return transcriptions
def process_document(image, model_name):
    """Detect the text lines in *image* and transcribe each one.

    Pipeline: write the image to a temporary JPEG, send it to detect_lines,
    draw the detected boundaries, crop each line, and transcribe the crops.

    Args:
        image: PIL image uploaded by the user.
        model_name: a key of MODEL_OPTIONS selecting the TrOCR checkpoint.

    Returns:
        (output_image, text): an RGB copy of *image* with line boundaries
        outlined, and the per-line transcriptions joined with newlines.
    """
    # JPEG cannot store alpha or palette images; convert so save() succeeds.
    rgb_image = image.convert("RGB")
    # detect_lines uploads a file from disk, so stage the image in a temp file.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        # Close the handle before detect_lines reopens the path.
        with temp_file:
            rgb_image.save(temp_file, format="JPEG")
        # Step 1: detect lines
        lines = detect_lines(temp_file.name)
    finally:
        # The file is only needed for detection; remove it even on failure.
        os.unlink(temp_file.name)
    # Visualize detected lines
    output_image = visualize_lines(rgb_image, lines)
    # Step 2: extract line images
    line_images = extract_line_images(rgb_image, lines)
    # Step 3: transcribe lines
    transcriptions = transcribe_lines(line_images, model_name)
    return output_image, "\n".join(transcriptions)
# Gradio interface
def gradio_process_document(image, model_name):
    """Gradio click handler: validate the upload, then run process_document.

    Returns:
        (annotated image, transcription text) for the two output components.

    Raises:
        gr.Error: if no image was uploaded. Gradio shows this message in the
        UI instead of a generic error from calling methods on None.
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Process.")
    output_image, transcriptions = process_document(image, model_name)
    return output_image, transcriptions
# Build the UI: image upload plus model picker in, annotated image plus text out.
with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")
    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_OPTIONS),
            value="Medieval Base",
            label="Select Model",
        )
        submit_button = gr.Button("Process")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")
    submit_button.click(
        fn=gradio_process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text],
    )

# Launch only when run as a script, so importing this module (for tests or
# reuse) does not start a blocking server.
if __name__ == "__main__":
    iface.launch(debug=True)