from huggingface_hub import login, snapshot_download
from transformers import TrOCRProcessor
import gradio as gr
import numpy as np
import onnxruntime
import torch
import time
import json
import os

from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition

LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"

# Authenticate against the Hugging Face Hub before downloading anything.
# login() without a token falls back to an interactive prompt, which cannot
# work on a headless server, so only log in when HF_TOKEN is actually set.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=True)
else:
    print("HF_TOKEN is not set; skipping Hugging Face Hub login")

# Download the ONNX text recognition model repository to the local cache.
TROCR_MODEL_PATH = snapshot_download(
    repo_id="Kansallisarkisto/multicentury-htr-model-small-onnx"
)

# Allowed source prefixes for input images. A path is accepted only if it
# equals one of these or continues it after a "/" separator
# (see is_allowed_source).
ALLOWED_SOURCES = ('https://astia.narc.fi', '/tmp/gradio')

print(f"Is CUDA available: {torch.cuda.is_available()}")
# get_device_name() raises when no CUDA device exists, so only query it
# when CUDA is actually available.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


def get_segmenter():
    """Initialize the text region / text line segmentation class.

    Returns:
        A SegmentImage instance, or None if construction failed (the error
        is printed rather than raised).
    """
    try:
        return SegmentImage(
            line_model_path=LINE_MODEL_PATH,
            device='cuda:0',
            line_iou=0.3,
            region_iou=0.5,
            line_overlap=0.5,
            line_nms_iou=0.7,
            region_nms_iou=0.3,
            line_conf_threshold=0.25,
            region_conf_threshold=0.5,
            region_model_path=REGION_MODEL_PATH,
            order_regions=True,
            region_half_precision=False,
            line_half_precision=False,
        )
    except Exception as e:
        print(f'Failed to initialize SegmentImage class: {e}')
        return None


def get_recognizer():
    """Initialize the ONNX text recognition class.

    Returns:
        A TextRecognition instance, or None if construction failed (the
        error is printed rather than raised).
    """
    try:
        return TextRecognition(
            model_path=TROCR_MODEL_PATH,
            device='cuda:0',
            batch_size=10,
        )
    except Exception as e:
        print(f'Failed to initialize TextRecognition class: {e}')
        return None


segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()

color_codes = """**Text region type:**
Paragraph ![#EE1289](https://placehold.co/15x15/EE1289/EE1289.png) Marginalia ![#00C957](https://placehold.co/15x15/00C957/00C957.png) Page number ![#0000FF](https://placehold.co/15x15/0000FF/0000FF.png)"""


def merge_lines(segment_predictions):
    """Concatenate the 'lines' lists of all regions into a single list.

    Args:
        segment_predictions: iterable of region dicts, each with a 'lines'
            entry (as returned by segmenter.get_segmentation).

    Returns:
        A new list with every region's lines, in region order.
    """
    return [line for region in segment_predictions for line in region['lines']]


def get_text_predictions(image, segment_predictions, recognizer):
    """Run text recognition on every detected text line of an image.

    Args:
        image: the input image (the caller passes a numpy array).
        segment_predictions: region dicts from the segmenter.
        recognizer: a TextRecognition instance.

    Returns:
        Whatever recognizer.process_lines returns; the caller joins it with
        newlines, so presumably one string per line.
    """
    img_lines = merge_lines(segment_predictions)
    # Process all lines of the image in one call.
    return recognizer.process_lines(img_lines, image)


def _matches_source(file_path, source):
    """Return True if file_path is `source` itself or lies beneath it."""
    prefix = source.rstrip('/')
    return file_path == prefix or file_path.startswith(prefix + '/')


def is_allowed_source(file_path):
    """Decide whether an input file path comes from an allowed source.

    The path originates from the client's request body and is therefore
    untrusted. A bare string-prefix test would accept look-alikes such as
    "https://astia.narc.fi.evil.com/..." or "/tmp/gradio_other/...", so a
    "/" boundary is required after the allowed prefix.

    Args:
        file_path: the path/URL extracted from the request, possibly None.

    Returns:
        True if the path is allowed, otherwise False (and the path is printed).
    """
    if isinstance(file_path, str) and file_path:
        # Collapse "." and ".." segments in local paths so that
        # "/tmp/gradio/../../etc/x" cannot pass; URLs are left untouched
        # because normpath would mangle the "//" after the scheme.
        candidate = os.path.normpath(file_path) if file_path.startswith('/') else file_path
        if any(_matches_source(candidate, src) for src in ALLOWED_SOURCES):
            return True
    print(f"File path not allowed: {file_path}")
    return False


async def get_filepath(request):
    """Extract the input file path from the JSON body of a request.

    Returns:
        The 'path' value of the first dict in body["data"] that has one, or
        None if the body is empty, unreadable, not JSON, or has no such item.
        Problems are printed rather than raised.
    """
    try:
        body = await request.body()
    except Exception as e:
        print(f"Error reading request body: {e}")
        return None
    if not body:
        return None

    try:
        body_json = json.loads(body.decode('utf-8'))
    except (json.JSONDecodeError, UnicodeDecodeError):
        print("Body is not valid JSON")
        return None

    data = body_json.get('data') if isinstance(body_json, dict) else None
    if not isinstance(data, list):
        return None

    for item in data:
        if isinstance(item, dict) and 'path' in item:
            file_path = item['path']
            print(f"Found file path: {file_path}")
            return file_path
    return None


# Run demo code
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    gr.Markdown("""The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition. The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents from 16th, 17th, 18th, 19th and 20th centuries.

Input image can be uploaded using the *Input image* window in the *Text content* tab, and the predicted text content will appear in the window on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs. Best results are obtained when using high quality scans of documents with a regular layout. Please note that this is a demo. 24/7 functionality is not guaranteed.

# Monen vuosisadan käsialantunnistusmalli

Käsialantunnistusputkessa on kolme mallia: Tekstialueen tunnistus, tekstirivien tunnistus ja tekstintunnistus. Mallit on koulutettu pääosin käsinkirjoitetulla Kansallisarkiston aineistolla, joka ajoittuu 1500-luvulta 1900-luvulle. Tunnistettavan kuvan voi ladata *Input image* nimiseen laatikkoon *Text content* välilehdellä. Prosessointi käynnistetään *Process image* painikkeesta, ja kun kuva on prosessoitu, tunnistettu teksti ilmaantuu oikeaan laatikkoon nimeltä *Predicted text content*. Tekstialueen ja tekstirivien tunnistuksia voi tarkastella *Text regions* ja *Text lines* välilehdiltä. Parhaimman lopputuloksen saa hyvälaatuisilla kuvilla, joissa on normaalin kirjan mukainen taitto. Huom! Tämä on demosovellus. Ympärivuorokautista toimivuutta ei luvata.
""")

    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()

    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)

    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)

    async def run_pipeline(image, request: gr.Request):
        """Segment the image, recognize its text lines and build the outputs.

        Returns:
            A dict keyed by output components (Gradio requires component
            keys, not strings). Error cases update only the textbox.
        """
        # Only files from allowed sources are processed. A missing request
        # yields no path and is therefore rejected as well.
        file_path = await get_filepath(request) if request else None
        if not is_allowed_source(file_path):
            return {textbox: 'Error: File source not allowed'}

        # get_segmenter/get_recognizer return None when model loading failed.
        if segmenter is None or recognizer is None:
            return {textbox: 'Error: Models are not available'}
        if image is None:
            return {textbox: 'Error: No input image'}

        # Predict region and line segments
        start = time.perf_counter()
        segment_predictions = segmenter.get_segmentation(image)
        print('segmentation ok')

        if not segment_predictions:
            proc_time = time.perf_counter() - start
            return {
                region_img: None,
                line_img: None,
                textbox: None,
                processing_time: f"Processing time: {proc_time:.4f}s",
            }

        region_plot = plotter.plot_regions(segment_predictions, image)
        line_plot = plotter.plot_lines(segment_predictions, image)
        text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
        print('text pred ok')
        text = "\n".join(text_predictions)
        proc_time = time.perf_counter() - start
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: f"Processing time: {proc_time:.4f}s",
        }

    button.click(
        fn=run_pipeline,
        inputs=input_img,
        outputs=[region_img, line_img, textbox, processing_time],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(show_error=True)