MikkoLipsanen commited on
Commit
c0612a1
·
verified ·
1 Parent(s): ecbc35c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +607 -177
app.py CHANGED
@@ -1,8 +1,13 @@
1
- from huggingface_hub import login, snapshot_download
 
2
  from transformers import TrOCRProcessor
 
 
3
  import gradio as gr
4
- import numpy as np
5
  import onnxruntime
 
 
6
  import torch
7
  import time
8
  import json
@@ -12,198 +17,623 @@ from plotting_functions import PlotHTR
12
  from segment_image import SegmentImage
13
  from onnx_text_recognition import TextRecognition
14
 
15
-
16
- LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
17
- REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
18
-
19
- # Download repository to cache
20
- TROCR_MODEL_PATH = snapshot_download(
21
- repo_id="Kansallisarkisto/multicentury-htr-model-small-onnx"
22
  )
 
23
 
24
- # Allowed source paths for input images
25
- ALLOWED_SOURCES = ('https://astia.narc.fi', '/tmp/gradio')
 
 
 
 
26
 
27
- login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
28
 
29
- print(f"Is CUDA available: {torch.cuda.is_available()}")
30
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- def get_segmenter():
33
- """Initialize segmentation class."""
 
34
  try:
35
- segmenter = SegmentImage(line_model_path=LINE_MODEL_PATH,
36
- device='cuda:0',
37
- line_iou=0.3,
38
- region_iou=0.5,
39
- line_overlap=0.5,
40
- line_nms_iou=0.7,
41
- region_nms_iou=0.3,
42
- line_conf_threshold=0.25,
43
- region_conf_threshold=0.5,
44
- region_model_path=REGION_MODEL_PATH,
45
- order_regions=True,
46
- region_half_precision=False,
47
- line_half_precision=False)
48
- return segmenter
49
  except Exception as e:
50
- print('Failed to initialize SegmentImage class: %s' % e)
 
51
 
52
- def get_recognizer():
53
- """Initialize text recognition class."""
 
 
 
 
 
 
 
 
54
  try:
55
- recognizer = TextRecognition(
56
- model_path = TROCR_MODEL_PATH,
57
- device = 'cuda:0',
58
- batch_size = 10
59
- )
60
- return recognizer
 
 
 
 
 
 
61
  except Exception as e:
62
- print('Failed to initialize TextRecognition class: %s' % e)
63
-
64
- segmenter = get_segmenter()
65
- recognizer = get_recognizer()
66
- plotter = PlotHTR()
67
-
68
- color_codes = """**Text region type:** <br>
69
- Paragraph ![#EE1289](https://placehold.co/15x15/EE1289/EE1289.png)
70
- Marginalia ![#00C957](https://placehold.co/15x15/00C957/00C957.png)
71
- Page number ![#0000FF](https://placehold.co/15x15/0000FF/0000FF.png)"""
72
-
73
- def merge_lines(segment_predictions):
74
- img_lines = []
75
- for region in segment_predictions:
76
- img_lines += region['lines']
77
- return img_lines
78
-
79
- def get_text_predictions(image, segment_predictions, recognizer):
80
- """Collects text prediction data into dicts based on detected text regions."""
81
- img_lines = merge_lines(segment_predictions)
82
- # Process all lines of an image
83
- texts = recognizer.process_lines(img_lines, image)
84
- return texts
85
-
86
- def is_allowed_source(file_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  """
88
- Filter function to determine if a file source is allowed.
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  """
90
- # Check allowed paths
91
- if file_path.startswith(ALLOWED_SOURCES):
92
- return True
93
- print(f"File path not allowed: {file_path}")
94
- return False
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- async def get_filepath(request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  """
98
- Function for extracting input file path from Request object.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  """
100
  try:
101
- # Get the raw request body
102
  body = await request.body()
103
- if body:
104
- body_str = body.decode('utf-8')
105
- # Try to parse as JSON
106
- try:
107
- body_json = json.loads(body_str)
108
- # Extract file path if present in the data structure
109
- if 'data' in body_json and isinstance(body_json['data'], list):
110
- for item in body_json['data']:
111
- if isinstance(item, dict) and 'path' in item:
112
- file_path = item['path']
113
- print(f"Found file path: {file_path}")
114
- return file_path
115
- except json.JSONDecodeError:
116
- print("Body is not valid JSON")
 
 
 
 
 
117
  except Exception as e:
118
- print(f"Error reading request body: {e}")
119
-
120
- # Run demo code
121
- with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
122
- gr.Markdown("# Multicentury HTR Demo")
123
- gr.Markdown("""The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition.
124
- The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents
125
- from 16th, 17th, 18th, 19th and 20th centuries.
126
-
127
- Input image can be uploaded using the *Input image* window in the *Text content* tab, and the predicted text content will appear to the window
128
- on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs.
129
- Best results are obtained when using high quality scans of documents with a regular layout.
130
-
131
- Please note that this is a demo. 24/7 functionality is not quaranteed.
132
-
133
- # Monen vuosisadan käsialantunnistusmalli
134
-
135
- Käsialantunnistusputkessa on kolme mallia: Tekstialueen tunnistus, tekstirivien tunnistus ja tekstintunnistus. Mallit on koulutettu pääosin
136
- käsinkirjoitetulla Kansallisarkiston aineistolla, joka ajoittuu 1500-luvulta 1900-luvulle.
137
-
138
- Tunnistettavan kuvan voi ladata *Input image* nimiseen laatikkoon *Text content* välilehdellä. Prosessointi käynnistetään *Process image*
139
- painikkeesta, ja kun kuva on prosessoitu, tunnistettu teksti ilmaantuu oikeaan laatikkoon nimeltä *Predicted text content*. Tekstialueen ja
140
- tekstirivien tunnistuksia voi tarkastella *Text regions* ja *Text lines* välilehdiltä. Parhaimman lopputuloksen saa hyvälaatuisilla kuvilla,
141
- joissa on normaalin kirjan mukainen taitto.
142
-
143
- Huom! Tämä on demosovellus. Ympärivuorokautista toimivuutta ei luvata.
144
- """)
145
-
146
- with gr.Tab("Text content"):
147
- with gr.Row():
148
- input_img = gr.Image(label="Input image", type="pil")
149
- textbox = gr.Textbox(label="Predicted text content", lines=10)
150
- button = gr.Button("Process image")
151
- processing_time = gr.Markdown()
152
- with gr.Tab("Text regions"):
153
- region_img = gr.Image(label="Predicted text regions", type="numpy")
154
- gr.Markdown(color_codes)
155
- with gr.Tab("Text lines"):
156
- line_img = gr.Image(label="Predicted text lines", type="numpy")
157
- gr.Markdown(color_codes)
158
-
159
- async def run_pipeline(image, request: gr.Request):
160
- if request:
161
- #print("=== Request Information ===")
162
- #print(f"Request URL: {request.url}")
163
- #print(f"Request method: {request.method}")
164
- #print(f"Client host: {request.client.host}")
165
- #print(f"Headers: {dict(request.headers)}")
166
- #print(f"Query params: {dict(request.query_params)}")
167
- file_path = await get_filepath(request)
168
- # Only files from allowed sources are processed
169
- if not is_allowed_source(file_path):
170
- return {'textbox': 'Error: File source not allowed'}
171
- else:
172
- # Predict region and line segments
173
- start = time.time()
174
- segment_predictions = segmenter.get_segmentation(image)
175
- print('segmentation ok')
176
- if segment_predictions:
177
- region_plot = plotter.plot_regions(segment_predictions, image)
178
- line_plot = plotter.plot_lines(segment_predictions, image)
179
- text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
180
- print('text pred ok')
181
- text = "\n".join(text_predictions)
182
- end = time.time()
183
- proc_time = end - start
184
- proc_time_str = f"Processing time: {proc_time:.4f}s"
185
- return {
186
- region_img: region_plot,
187
- line_img: line_plot,
188
- textbox: text,
189
- processing_time: proc_time_str
190
- }
191
- else:
192
- end = time.time()
193
- proc_time = end - start
194
- proc_time_str = f"Processing time: {proc_time:.4f}s"
195
- return {
196
- region_img: None,
197
- line_img: None,
198
- textbox: None,
199
- processing_time: proc_time_str
200
- }
201
-
202
- button.click(fn=run_pipeline,
203
- inputs=input_img,
204
- outputs=[region_img, line_img, textbox, processing_time])
205
- #api_name=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
 
207
  if __name__ == "__main__":
208
- demo.queue()
209
- demo.launch(show_error=True)
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import login, snapshot_download, hf_hub_download
2
+ from typing import Optional, Tuple, Dict, Any
3
  from transformers import TrOCRProcessor
4
+ from datetime import datetime
5
+ from pathlib import Path
6
  import gradio as gr
7
+ import numpy as np
8
  import onnxruntime
9
+ import tempfile
10
+ import logging
11
  import torch
12
  import time
13
  import json
 
17
  from segment_image import SegmentImage
18
  from onnx_text_recognition import TextRecognition
19
 
20
+ # Configure logging
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
24
+ handlers=[
25
+ logging.StreamHandler() # Explicit stdout handler for HF Spaces
26
+ ]
27
  )
28
+ logger = logging.getLogger(__name__)
29
 
30
+ # Log startup info for debugging in HF Spaces
31
+ logger.info("="*50)
32
+ logger.info("HTR Application Starting")
33
+ logger.info(f"Python version: {os.sys.version}")
34
+ logger.info(f"Running on Hugging Face Spaces: {os.getenv('SPACE_ID', 'Local')}")
35
+ logger.info("="*50)
36
 
 
37
 
38
+ # Configuration from environment variables
39
+ class Config:
40
+ """Application configuration from environment variables."""
41
+ HF_TOKEN = os.getenv("HF_TOKEN")
42
+ SEGMENTATION_MAX_SIZE = 768
43
+ RECOGNITION_BATCH_SIZE = 10
44
+ SEGMENTATION_CONFIDENCE_THRESHOLD = 0.15
45
+ SEGMENTATION_LINE_PRECENTAGE_THRESHOLD = 7e-05
46
+ SEGMENTATION_REGION_PRECENTAGE_THRESHOLD = 7e-05
47
+ SEGMENTATION_LINE_IOU = 0.3
48
+ SEGMENTATION_REGION_IOU = 0.3
49
+ SEGMENTATION_LINE_OVERLAP_THRESHOLD = 0.5
50
+ SEGMENTATION_REGION_OVERLAP_THRESHOLD = 0.5
51
+ ALLOWED_SOURCES = ("https://astia.narc.fi, /tmp/gradio")
52
+
53
+ # Model paths
54
+ TROCR_MODEL_REPO = "Kansallisarkisto/multicentury-htr-model-small-onnx"
55
+ SEGMENTATION_MODEL_REPO = "Kansallisarkisto/rfdetr_textline_textregion_detection_model"
56
+ SEGMENTATION_MODEL_FILE = "rfdetr_text_seg_model_202510.pth"
57
 
58
+
59
+ # Login to HuggingFace if token is available
60
+ if Config.HF_TOKEN:
61
  try:
62
+ login(token=Config.HF_TOKEN, add_to_git_credential=True)
63
+ logger.info("✓ Logged in to HuggingFace")
 
 
 
 
 
 
 
 
 
 
 
 
64
  except Exception as e:
65
+ logger.warning(f"Failed to login to HuggingFace: {e}")
66
+
67
 
68
+ def download_models() -> Tuple[str, str]:
69
+ """
70
+ Download required models from HuggingFace Hub.
71
+
72
+ Returns:
73
+ Tuple of (text_recognition_model_path, segmentation_model_path)
74
+
75
+ Raises:
76
+ RuntimeError: If model download fails
77
+ """
78
  try:
79
+ logger.info("Downloading text recognition model...")
80
+ trocr_path = snapshot_download(repo_id=Config.TROCR_MODEL_REPO)
81
+ logger.info(f"✓ Text recognition model downloaded to {trocr_path}")
82
+
83
+ logger.info("Downloading segmentation model...")
84
+ seg_path = hf_hub_download(
85
+ repo_id=Config.SEGMENTATION_MODEL_REPO,
86
+ filename=Config.SEGMENTATION_MODEL_FILE
87
+ )
88
+ logger.info(f"✓ Segmentation model downloaded to {seg_path}")
89
+
90
+ return trocr_path, seg_path
91
  except Exception as e:
92
+ logger.error(f"Failed to download models: {e}")
93
+ raise RuntimeError(f"Model download failed: {e}")
94
+
95
+
96
+ # Download models
97
+ TROCR_MODEL_PATH, SEGMENTATION_MODEL_PATH = download_models()
98
+
99
+ # Log CUDA availability
100
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
101
+ if torch.cuda.is_available():
102
+ logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
103
+
104
+
105
+ class HTRPipeline:
106
+ """
107
+ Handwritten Text Recognition pipeline combining segmentation and recognition.
108
+
109
+ This class manages the initialization and execution of document segmentation
110
+ and text recognition models.
111
+ """
112
+
113
+ def __init__(self,
114
+ segmentation_model_path: str,
115
+ recognition_model_path: str,
116
+ segmentation_max_size: int = 768,
117
+ recognition_batch_size: int = 10,
118
+ segmentation_confidence_threshold: float = 0.15,
119
+ segmentation_line_percentage_threshold: float = 7e-05,
120
+ segmentation_region_percentage_threshold: float = 7e-05,
121
+ segmentation_line_iou: float = 0.3,
122
+ segmentation_region_iou: float = 0.3,
123
+ segmentation_line_overlap_threshold: float = 0.5,
124
+ segmentation_region_overlap_threshold: float = 0.5
125
+ ):
126
+ """
127
+ Initialize HTR pipeline with segmentation and recognition models.
128
+
129
+ Args:
130
+ segmentation_model_path: Path to segmentation model weights
131
+ recognition_model_path: Path to recognition model directory
132
+ segmentation_max_size: Maximum image dimension for segmentation
133
+ recognition_batch_size: Batch size for text recognition
134
+ segmentation_confidence_threshold: Minimum confidence score for detections
135
+ segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
136
+ segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
137
+ segmentation_line_iou: IoU threshold for merging overlapping line polygons
138
+ segmentation_region_iou: IoU threshold for merging overlapping region polygons
139
+ segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
140
+ segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions
141
+ """
142
+ self.segmenter = self._init_segmenter(segmentation_model_path,
143
+ segmentation_max_size,
144
+ segmentation_confidence_threshold,
145
+ segmentation_line_percentage_threshold,
146
+ segmentation_region_percentage_threshold,
147
+ segmentation_line_iou,
148
+ segmentation_region_iou,
149
+ segmentation_line_overlap_threshold,
150
+ segmentation_region_overlap_threshold
151
+ )
152
+ self.recognizer = self._init_recognizer(recognition_model_path, recognition_batch_size)
153
+ self.plotter = PlotHTR()
154
+
155
+ if self.segmenter is None or self.recognizer is None:
156
+ raise RuntimeError("Failed to initialize HTR pipeline components")
157
+
158
+ def _init_segmenter(self,
159
+ model_path: str,
160
+ max_size: int,
161
+ segmentation_confidence_threshold: float,
162
+ segmentation_line_percentage_threshold: float,
163
+ segmentation_region_percentage_threshold: float,
164
+ segmentation_line_iou: float,
165
+ segmentation_region_iou: float,
166
+ segmentation_line_overlap_threshold: float,
167
+ segmentation_region_overlap_threshold: float
168
+ ) -> Optional[SegmentImage]:
169
  """
170
+ Initialize document segmentation model.
171
+
172
+ Args:
173
+ model_path: Path to segmentation model
174
+ max_size: Maximum dimension for image preprocessing
175
+ segmentation_confidence_threshold: Minimum confidence score for detections
176
+ segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
177
+ segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
178
+ segmentation_line_iou: IoU threshold for merging overlapping line polygons
179
+ segmentation_region_iou: IoU threshold for merging overlapping region polygons
180
+ segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
181
+ segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions
182
+ Returns:
183
+ Initialized SegmentImage instance or None if initialization fails
184
  """
185
+ try:
186
+ segmenter = SegmentImage(
187
+ model_path=model_path,
188
+ max_size=max_size,
189
+ confidence_threshold=segmentation_confidence_threshold,
190
+ line_percentage_threshold=segmentation_line_percentage_threshold,
191
+ region_percentage_threshold=segmentation_region_percentage_threshold,
192
+ line_iou=segmentation_line_iou,
193
+ region_iou=segmentation_region_iou,
194
+ line_overlap_threshold=segmentation_line_overlap_threshold,
195
+ region_overlap_threshold=segmentation_region_overlap_threshold
196
+ )
197
+ logger.info("✓ Segmentation model initialized")
198
+ return segmenter
199
+ except Exception as e:
200
+ logger.error(f"Failed to initialize segmentation model: {e}")
201
+ return None
202
 
203
+ def _init_recognizer(self, model_path: str, batch_size: int) -> Optional[TextRecognition]:
204
+ """
205
+ Initialize text recognition model.
206
+
207
+ Args:
208
+ model_path: Path to recognition model directory
209
+ batch_size: Number of text lines to process in parallel
210
+
211
+ Returns:
212
+ Initialized TextRecognition instance or None if initialization fails
213
+ """
214
+ try:
215
+ recognizer = TextRecognition(
216
+ model_path=model_path,
217
+ device='cuda:0' if torch.cuda.is_available() else 'cpu',
218
+ batch_size=batch_size
219
+ )
220
+ logger.info("✓ Text recognition model initialized")
221
+ return recognizer
222
+ except Exception as e:
223
+ logger.error(f"Failed to initialize text recognition model: {e}")
224
+ return None
225
+
226
+ def _merge_lines(self, segment_predictions: list) -> list:
227
+ """
228
+ Merge text lines from all regions into a single list.
229
+
230
+ Args:
231
+ segment_predictions: List of region dictionaries containing line data
232
+
233
+ Returns:
234
+ Flat list of all text line polygons
235
+ """
236
+ return [line for region in segment_predictions for line in region.get('lines', [])]
237
+
238
+ def process_image(self, image) -> Dict[str, Any]:
239
+ """
240
+ Process a document image through the complete HTR pipeline.
241
+
242
+ Args:
243
+ image: PIL Image object or numpy array
244
+
245
+ Returns:
246
+ Dictionary containing:
247
+ - success: bool indicating if processing succeeded
248
+ - segment_predictions: List of detected regions and lines
249
+ - text_predictions: List of recognized text strings
250
+ - processing_time: Time taken in seconds
251
+ - error: Error message if success is False
252
+ """
253
+ start_time = time.time()
254
+ result = {
255
+ 'success': False,
256
+ 'segment_predictions': None,
257
+ 'text_predictions': None,
258
+ 'processing_time': 0.0,
259
+ 'error': None
260
+ }
261
+
262
+ try:
263
+ # Convert PIL image to numpy if needed
264
+ if not isinstance(image, np.ndarray):
265
+ image = np.array(image.convert('RGB'))
266
+
267
+ # Run segmentation
268
+ segment_predictions = self.segmenter.get_segmentation(image)
269
+
270
+ if not segment_predictions:
271
+ result['error'] = "No text lines detected in the image"
272
+ result['processing_time'] = time.time() - start_time
273
+ return result
274
+
275
+ logger.info("✓ Segmentation completed")
276
+
277
+ # Extract all lines for recognition
278
+ img_lines = self._merge_lines(segment_predictions)
279
+
280
+ # Run text recognition
281
+ text_predictions = self.recognizer.process_lines(img_lines, image)
282
+ logger.info("✓ Text recognition completed")
283
+
284
+ result['success'] = True
285
+ result['segment_predictions'] = segment_predictions
286
+ result['text_predictions'] = text_predictions
287
+
288
+ except Exception as e:
289
+ logger.error(f"Error during image processing: {e}", exc_info=True)
290
+ result['error'] = str(e)
291
+
292
+ finally:
293
+ result['processing_time'] = time.time() - start_time
294
+
295
+ return result
296
+
297
+ def is_allowed_source(file_path: Optional[str]) -> bool:
298
  """
299
+ Check if a file path is from an allowed source.
300
+
301
+ This security measure prevents processing of files from untrusted sources,
302
+ limiting uploads to specific domains and temporary directories.
303
+
304
+ Args:
305
+ file_path: Path to the uploaded file
306
+
307
+ Returns:
308
+ True if source is allowed, False otherwise
309
+ """
310
+ if not file_path:
311
+ logger.warning("No file path provided")
312
+ return False
313
+
314
+ # Check if path starts with any allowed source
315
+ is_allowed = any(file_path.startswith(source) for source in Config.ALLOWED_SOURCES)
316
+
317
+ if not is_allowed:
318
+ logger.warning(f"File path not allowed: {file_path}")
319
+
320
+ return is_allowed
321
+
322
+
323
+ async def extract_filepath_from_request(request: gr.Request) -> Optional[str]:
324
+ """
325
+ Extract file path from Gradio request object.
326
+
327
+ Args:
328
+ request: Gradio Request object
329
+
330
+ Returns:
331
+ File path string or None if not found
332
  """
333
  try:
 
334
  body = await request.body()
335
+ if not body:
336
+ return None
337
+
338
+ body_str = body.decode('utf-8')
339
+ body_json = json.loads(body_str)
340
+
341
+ # Navigate through Gradio's request structure
342
+ if 'data' in body_json and isinstance(body_json['data'], list):
343
+ for item in body_json['data']:
344
+ if isinstance(item, dict) and 'path' in item:
345
+ file_path = item['path']
346
+ logger.info(f"Extracted file path: {file_path}")
347
+ return file_path
348
+
349
+ return None
350
+
351
+ except json.JSONDecodeError:
352
+ logger.warning("Request body is not valid JSON")
353
+ return None
354
  except Exception as e:
355
+ logger.error(f"Error extracting file path: {e}")
356
+ return None
357
+
358
+
359
+ # Initialize HTR pipeline
360
+ try:
361
+ pipeline = HTRPipeline(
362
+ segmentation_model_path=SEGMENTATION_MODEL_PATH,
363
+ recognition_model_path=TROCR_MODEL_PATH,
364
+ segmentation_max_size=Config.SEGMENTATION_MAX_SIZE,
365
+ recognition_batch_size=Config.RECOGNITION_BATCH_SIZE,
366
+ segmentation_confidence_threshold = Config.SEGMENTATION_CONFIDENCE_THRESHOLD,
367
+ segmentation_line_percentage_threshold = Config.SEGMENTATION_LINE_PRECENTAGE_THRESHOLD,
368
+ segmentation_region_percentage_threshold = Config.SEGMENTATION_REGION_PRECENTAGE_THRESHOLD,
369
+ segmentation_line_iou = Config.SEGMENTATION_LINE_IOU,
370
+ segmentation_region_iou = Config.SEGMENTATION_REGION_IOU,
371
+ segmentation_line_overlap_threshold = Config.SEGMENTATION_LINE_OVERLAP_THRESHOLD,
372
+ segmentation_region_overlap_threshold = Config.SEGMENTATION_REGION_OVERLAP_THRESHOLD
373
+ )
374
+ logger.info("✓ HTR Pipeline initialized successfully")
375
+ except Exception as e:
376
+ logger.error(f"Failed to initialize HTR pipeline: {e}")
377
+ raise
378
+
379
+
380
+ def create_demo() -> gr.Blocks:
381
+ """
382
+ Create and configure the Gradio demo interface.
383
+
384
+ Returns:
385
+ Configured Gradio Blocks interface
386
+ """
387
+
388
+ with gr.Blocks(
389
+ theme=gr.themes.Monochrome(),
390
+ title="Multicentury HTR Demo"
391
+ ) as demo:
392
+
393
+ gr.Image("logo.png",
394
+ width=200,
395
+ height=100,
396
+ show_label=False,
397
+ show_download_button=False,
398
+ show_fullscreen_button=False,
399
+ container=False,
400
+ interactive=False
401
+ )
402
+
403
+ gr.Markdown("# 📜 Multicentury Handwritten Text Recognition")
404
+
405
+ with gr.Tabs():
406
+ # English documentation
407
+ with gr.Tab("English"):
408
+ gr.Markdown("""
409
+ ## About this demo
410
+
411
+ This HTR (Handwritten Text Recognition) pipeline combines two machine learning models:
412
+
413
+ 1. **Text Region & Line Detection**: Identifies text regions and individual lines in document images
414
+ 2. **Handwritten Text Recognition**: Transcribes the detected text lines
415
+
416
+ The models have been trained by the National Archives of Finland in autumn 2025 using handwritten documents
417
+ from the 16th to 20th centuries.
418
+
419
+ ### How to use
420
+
421
+ 1. Upload an image in the **Text Content** tab
422
+ 2. Click **Process Image**
423
+ 3. View results: transcribed text, detected regions, and text lines
424
+
425
+ ### To obtain best results
426
+
427
+ - Use high-quality scans
428
+ - Ensure good contrast between text and background
429
+ - Note that regular document layouts work best
430
+
431
+ ⚠️ **Note**: This is a demo application. 24/7 availability is not guaranteed.
432
+ """)
433
+
434
+ # Finnish documentation
435
+ with gr.Tab("Suomeksi"):
436
+ gr.Markdown("""
437
+ ## Tietoa demosta
438
+
439
+ Käsialantunnistusputki sisältää kaksi koneoppimismallia:
440
+
441
+ 1. **Tekstialueiden ja -rivien tunnistus**: Tunnistaa tekstialueet ja yksittäiset rivit dokumenttikuvista
442
+ 2. **Käsinkirjoitetun tekstin tunnistus**: Litteroi tunnistetut tekstirivit
443
+
444
+ Mallit on koulutettu Kansallisarkistossa syksyllä 2025 käsinkirjoitetulla aineistolla,
445
+ joka ajoittuu 1500-luvulta 1900-luvulle.
446
+
447
+ ### Käyttöohje
448
+
449
+ 1. Lataa kuva **Text Content** -välilehdellä
450
+ 2. Paina **Process Image** -painiketta
451
+ 3. Tarkastele tuloksia: litteroitu teksti, tunnistetut alueet ja tekstirivit
452
+
453
+ ### Parhaat tulokset saat kun
454
+
455
+ - Käytät korkealaatuisia skannauksia
456
+ - Varmistat hyvän kontrastin tekstin ja taustan välillä
457
+ - Huomioit että monimutkaiset rakenteet (esim. taulukot) voivat vaikeuttaa tunnistusta
458
+
459
+ ⚠️ **Huom**: Tämä on demosovellus. Ympärivuorokautista toimivuutta ei luvata.
460
+ """)
461
+
462
+ gr.Markdown("---")
463
+
464
+ with gr.Tabs():
465
+ with gr.Tab("📄 Text Content"):
466
+ with gr.Row():
467
+ with gr.Column(scale=1):
468
+ input_img = gr.Image(
469
+ label="Input Image",
470
+ type="pil",
471
+ height=400
472
+ )
473
+ with gr.Row():
474
+ process_btn = gr.Button(
475
+ "🚀 Process Image",
476
+ variant="primary",
477
+ size="lg"
478
+ )
479
+ clear_btn = gr.ClearButton(
480
+ components=[input_img],
481
+ value="🗑️ Clear"
482
+ )
483
+
484
+ with gr.Column(scale=1):
485
+ textbox = gr.Textbox(
486
+ label="Recognized Text",
487
+ lines=15,
488
+ max_lines=30,
489
+ show_copy_button=True,
490
+ placeholder="Processed text will appear here..."
491
+ )
492
+ download_text_file = gr.File(
493
+ label="💾 Download Text",
494
+ visible=False,
495
+ interactive=False
496
+ )
497
+
498
+ processing_time = gr.Markdown(
499
+ "",
500
+ elem_classes="processing-time"
501
+ )
502
+ status_message = gr.Markdown(
503
+ "",
504
+ elem_classes="error-message"
505
+ )
506
+
507
+ with gr.Tab("🗺️ Text Regions"):
508
+ region_img = gr.Image(
509
+ label="Detected Text Regions",
510
+ type="numpy",
511
+ height=500
512
+ )
513
+ region_info = gr.Markdown("Upload and process an image to see detected regions")
514
+
515
+ with gr.Tab("📝 Text Lines"):
516
+ line_img = gr.Image(
517
+ label="Detected Text Lines",
518
+ type="numpy",
519
+ height=500
520
+ )
521
+ line_info = gr.Markdown("Upload and process an image to see detected text lines")
522
+
523
+ async def process_pipeline(image, request: gr.Request):
524
+ """
525
+ Main processing function for the Gradio interface.
526
+
527
+ Validates input, checks file source, runs HTR pipeline, and formats results.
528
+ """
529
+ # Reset outputs
530
+ outputs = {
531
+ region_img: None,
532
+ line_img: None,
533
+ textbox: "",
534
+ processing_time: "",
535
+ status_message: "",
536
+ download_text_file: gr.update(visible=False, value=None),
537
+ region_info: "",
538
+ line_info: ""
539
+ }
540
+
541
+ # Check file source (security measure)
542
+ if request:
543
+ file_path = await extract_filepath_from_request(request)
544
+ if file_path and not is_allowed_source(file_path):
545
+ outputs[status_message] = "❌ **Error**: File source not allowed for security reasons"
546
+ yield tuple(outputs.values())
547
+ return
548
+
549
+ # Show processing status
550
+ outputs[status_message] = "⏳ Processing image..."
551
+ yield tuple(outputs.values())
552
+
553
+ # Run HTR pipeline
554
+ result = pipeline.process_image(image)
555
+
556
+ # Format processing time
557
+ time_str = f"⏱️ Processing time: {result['processing_time']:.2f}s"
558
+ outputs[processing_time] = time_str
559
+
560
+ if not result['success']:
561
+ error = result['error'] or "Unknown error occurred"
562
+ outputs[status_message] = f"❌ **Error**: {error}"
563
+ yield tuple(outputs.values())
564
+ return
565
+
566
+ # Process successful results
567
+ try:
568
+ segment_predictions = result['segment_predictions']
569
+ text_predictions = result['text_predictions']
570
+
571
+ # Generate visualizations
572
+ region_plot = pipeline.plotter.plot_regions(segment_predictions, image)
573
+ line_plot = pipeline.plotter.plot_lines(segment_predictions, image)
574
+
575
+ # Format text output
576
+ recognized_text = "\n".join(text_predictions) if text_predictions else ""
577
+
578
+ # Update outputs
579
+ outputs[region_img] = region_plot
580
+ outputs[line_img] = line_plot
581
+ outputs[textbox] = recognized_text
582
+ outputs[status_message] = f"Recognized {len(text_predictions)} text lines"
583
+
584
+ ## Create downloadable text file if text was recognized
585
+ if recognized_text:
586
+ # Create temporary file with proper filename
587
+ temp_dir = tempfile.gettempdir()
588
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
589
+ filename = f"htr_result_{timestamp}.txt"
590
+ filepath = os.path.join(temp_dir, filename)
591
+
592
+ # Write text to file
593
+ with open(filepath, 'w', encoding='utf-8') as f:
594
+ f.write(recognized_text)
595
+
596
+ outputs[download_text_file] = gr.update(visible=True, value=filepath)
597
+
598
+ # Update info sections
599
+ num_regions = len(segment_predictions)
600
+ outputs[region_info] = f"Detected **{num_regions}** text region(s)"
601
+ outputs[line_info] = f"Detected **{len(text_predictions)}** text line(s)"
602
+
603
+ except Exception as e:
604
+ logger.error(f"Error formatting results: {e}", exc_info=True)
605
+ outputs[status_message] = f"❌ **Error**: Failed to format results - {e}"
606
+
607
+ yield tuple(outputs.values())
608
+
609
+ # Connect button to processing function
610
+ process_btn.click(
611
+ fn=process_pipeline,
612
+ inputs=[input_img],
613
+ outputs=[
614
+ region_img,
615
+ line_img,
616
+ textbox,
617
+ processing_time,
618
+ status_message,
619
+ download_text_file,
620
+ region_info,
621
+ line_info
622
+ ],
623
+ api_name=False # Disable API endpoint for security
624
+ )
625
+
626
+ return demo
627
+
628
 
629
+ # Create and launch demo
630
  if __name__ == "__main__":
631
+ demo = create_demo()
632
+ demo.queue(
633
+ max_size=30, # 30 users can queue without being rejected
634
+ default_concurrency_limit=1 # Only one image processes at a time
635
+ )
636
+ demo.launch(
637
+ show_error=True,
638
+ max_threads=2 # Minimal threads: 1 for processing + 1 for queue management
639
+ )