Spaces:
Sleeping
Sleeping
| from optimum.onnxruntime import ORTModelForVision2Seq | |
| from transformers import TrOCRProcessor | |
| from huggingface_hub import login | |
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime | |
| import torch | |
| import time | |
| import os | |
| from plotting_functions import PlotHTR | |
| from segment_image import SegmentImage | |
| from onnx_text_recognition import TextRecognition | |
# Hugging Face Hub repository ids handed to the segmentation and text
# recognition wrappers below.
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"

# Authenticate against the Hub with the token from the HF_TOKEN environment
# variable.
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

# Startup diagnostics. The device name is only queried when CUDA is usable:
# torch.cuda.current_device() raises on a host without a GPU, which would
# abort the whole app before the UI is even built.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
def get_segmenter():
    """Create the region/line segmentation wrapper.

    Returns:
        A configured ``SegmentImage`` instance, or ``None`` when construction
        fails for any reason. The failure is printed, not raised, so callers
        must be prepared for ``None``.
    """
    try:
        # Keyword arguments are collected first so the model locations,
        # runtime options and thresholds can be read at a glance.
        settings = dict(
            line_model_path=LINE_MODEL_PATH,
            region_model_path=REGION_MODEL_PATH,
            device='cuda:0',
            order_regions=True,
            line_half_precision=False,
            region_half_precision=False,
            line_iou=0.3,
            region_iou=0.5,
            line_overlap=0.5,
            line_nms_iou=0.7,
            region_nms_iou=0.3,
            line_conf_threshold=0.25,
            region_conf_threshold=0.5,
        )
        return SegmentImage(**settings)
    except Exception as err:
        print('Failed to initialize SegmentImage class: %s' % err)
        return None
def get_recognizer():
    """Create the text recognition wrapper.

    Returns:
        A configured ``TextRecognition`` instance, or ``None`` when
        construction fails for any reason. The failure is printed, not
        raised, so callers must be prepared for ``None``.
    """
    try:
        settings = dict(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cuda:0',
            half_precision=True,
            line_threshold=10,
        )
        return TextRecognition(**settings)
    except Exception as err:
        print('Failed to initialize TextRecognition class: %s' % err)
        return None
# Markdown legend rendered under the region and line plots. Built from
# explicit pieces so the trailing space after each label stays visible.
# NOTE(review): the name suggests a colour swatch followed each label; none is
# present in this text - confirm against the deployed app.
color_codes = (
    "**Text region type:** <br>\n"
    "Paragraph \n"
    "Marginalia \n"
    "Page number "
)

# Pipeline components are created once at import time and shared by every
# request. Either factory may hand back None if its model failed to load.
segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()
def merge_lines(segment_predictions):
    """Flatten the per-region line lists into a single new list.

    Regions are visited in the order given and each region's ``'lines'``
    entry is appended in its own order. The input is not modified.
    """
    return [
        line
        for region in segment_predictions
        for line in region['lines']
    ]
def get_text_predictions(image, segment_predictions, recognizer):
    """Run text recognition over every detected line of one image.

    Args:
        image: Page image, passed through unchanged to
            ``recognizer.process_lines``.
        segment_predictions: Sequence of region predictions. Each entry
            provides its lines under ``'lines'``; the first entry's
            ``'img_shape'`` is unpacked as ``(height, width)``.
        recognizer: Object exposing
            ``process_lines(lines, image, height, width)``.

    Returns:
        The result of ``recognizer.process_lines`` for the merged lines, or
        an empty list when there are no region predictions.
    """
    # Nothing was segmented: there is no first region to read the image
    # shape from, so return early instead of raising IndexError.
    if not segment_predictions:
        return []

    height, width = segment_predictions[0]['img_shape']
    img_lines = merge_lines(segment_predictions)
    # Process all lines of the image in a single call.
    return recognizer.process_lines(img_lines, image, height, width)
# Build the demo UI: an intro text, three result tabs, and the callback that
# runs segmentation and text recognition when the button is pressed.
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    gr.Markdown("""The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition.
The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents
from 17th, 18th, 19th and 20th centuries.
Input image can be uploaded using the *Input image* window in the *Text content* tab, and the predicted text content will appear to the window
on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs.
Best results are obtained when using high quality scans of documents with a regular layout.
Please note that this is a demo. 24/7 functionality is not guaranteed.
# Monen vuosisadan käsialantunnistus malli
Käsialantunnistus putkessa on kolme mallia: Tekstialueen tunnistus, tekstirivien tunnistus ja tekstintunnistus. Mallit on koulutettu pääosin
käsinkirjoitetulla Kansallisarkiston aineistolla, joka ajoittuu 1600-luvulta 1900-luvulle.
Tunnistettavan kuvan voi ladata *Input image* nimiseen laatikkoon *Text content* välilehdellä. Prosessointi käynnistetään *Process image*
painikkeesta ja kuva on prosessoitu tunnistettu teksti ilmaantuu oikeaan laatikkoon nimeltä *Predicted text content*. Tekstialueen ja
tekstirivien tunnistuksia voi tarkastella *Text regions* ja *Text lines* välilehdiltä. Parhaimman lopputuloksen saa hyvälaatuisilla kuvilla,
joissa on normaalin kirjan mukainen taitto.
Huom! Tämä on demo sovellus. Ympärivuorokautista toimivuutta ei luvata.
""")

    # NOTE(review): widget nesting below is reconstructed (image and textbox
    # side by side in a row, button and timing beneath) - confirm the layout.
    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)

    def run_pipeline(image):
        """Segment the uploaded image, recognize its text and time the run.

        Args:
            image: The value of the *Input image* component, or ``None``
                when the button is pressed before anything is uploaded.

        Returns:
            A dict keyed by output component: the region plot, the line
            plot, the recognized text joined with newlines, and a
            processing-time message. The plots and text are ``None`` when
            nothing was segmented or no image was supplied.

        Raises:
            gr.Error: If the segmenter or recognizer is ``None``, i.e. its
                model could not be loaded at startup.
        """
        if image is None:
            # No upload yet: clear the outputs instead of handing None to
            # the segmenter.
            return {
                region_img: None,
                line_img: None,
                textbox: None,
                processing_time: "No input image provided.",
            }
        if segmenter is None or recognizer is None:
            # The model factories return None on failure; report that
            # clearly rather than failing with an AttributeError below.
            raise gr.Error("The HTR models failed to load; see the server log for details.")

        # Predict region and line segments
        start = time.time()
        segment_predictions = segmenter.get_segmentation(image)
        print('segmentation ok')

        region_plot = line_plot = text = None
        if segment_predictions:
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
            print('text pred ok')
            text = "\n".join(text_predictions)

        # One timing and return path for both outcomes.
        proc_time = time.time() - start
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: f"Processing time: {proc_time:.4f}s",
        }

    button.click(fn=run_pipeline,
                 inputs=input_img,
                 outputs=[region_img, line_img, textbox, processing_time])
if __name__ == "__main__":
    # Enable request queuing, then start the server; show_error surfaces
    # callback exceptions in the browser.
    demo.queue().launch(show_error=True)