IFMedTechdemo commited on
Commit
c1b0ad4
·
verified ·
1 Parent(s): acb99cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -180
app.py CHANGED
@@ -1,200 +1,314 @@
1
  import gradio as gr
2
- import pickle
3
- import os
4
  import cv2
5
- import pandas as pd
 
 
6
  import re
 
 
 
7
  from symspellpy import SymSpell, Verbosity
8
- from rapidocr import RapidOCR, EngineType, LangDet, LangRec, ModelType, OCRVersion
9
- import logging
10
-
11
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
-
17
- # Constants - Match both separated and merged prefixes
18
- ANCHOR_PATTERN = re.compile(
19
- r"\b(tab\.?|cap\.?|t\.?)\s*([a-zA-Z0-9]+)",
20
- re.IGNORECASE
21
- )
22
-
23
-
24
  # ============================================================================
25
- # GLOBAL SINGLETONS
26
  # ============================================================================
27
- _ocr_engine = None
28
- _drug_db = None
29
- _sym_spell = None
30
- _cache_path = os.path.join(os.path.dirname(__file__), "cache","Final_Medibot_Database_Cleaned_pickleFile.pkl")#"Final_Tata_Kaggle_merged_pickleFile.pkl")# "database_cache.pkl")
31
-
32
-
33
- def ensure_cache_dir():
34
- """Ensure cache directory exists."""
35
- cache_dir = os.path.dirname(_cache_path)
36
- if not os.path.exists(cache_dir):
37
- os.makedirs(cache_dir, exist_ok=True)
38
-
39
 
40
  def initialize_database():
41
- """
42
- Load drug database and SymSpell once.
43
- Uses cache if available to skip expensive recomputation.
44
- """
45
- global _drug_db, _sym_spell
46
-
47
- ensure_cache_dir()
48
-
49
- # Try to load from cache
50
- if os.path.exists(_cache_path):
51
- logger.info("Loading database from cache...")
52
- try:
53
- with open(_cache_path, 'rb') as f:
54
- cache_data = pickle.load(f)
55
- _drug_db = cache_data['drug_db']
56
- _sym_spell = cache_data['sym_spell']
57
- logger.info(f"Cache loaded: {len(_drug_db)} drugs")
58
- return
59
- except Exception as e:
60
- logger.warning(f"Cache load failed: {e}. Recomputing...")
61
-
62
- # Compute from scratch
63
- logger.info("Initializing database...")
64
- _drug_db = {}
65
-
66
  try:
67
- df = pd.read_csv("Dataset.csv")
68
- for idx, row in df.iterrows():
69
- drug_name = str(row.get('drug_name', '')).strip().lower()
70
- if drug_name:
71
- _drug_db[drug_name] = True
72
- except Exception as e:
73
- logger.warning(f"Dataset loading failed: {e}. Using minimal DB.")
74
- _drug_db = {"aspirin": True, "paracetamol": True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Initialize SymSpell with drug DB
77
- _sym_spell = SymSpell(max_dictionary_edit_distance=1)
78
- for drug in _drug_db:
79
- _sym_spell.create_dictionary_entry(drug, 1000)
 
 
 
80
 
81
- # Cache for next startup
82
- try:
83
- cache_data = {
84
- 'drug_db': _drug_db,
85
- 'sym_spell': _sym_spell
86
- }
87
- with open(_cache_path, 'wb') as f:
88
- pickle.dump(cache_data, f)
89
- logger.info(f"Database cached: {len(_drug_db)} drugs")
90
- except Exception as e:
91
- logger.warning(f"Cache save failed: {e}")
92
-
93
-
94
- def get_ocr_engine():
95
- """Get or create the RapidOCR engine with MOBILE + ONNX optimization."""
96
- global _ocr_engine
97
- if _ocr_engine is None:
98
- logger.info("Initializing RapidOCR engine with MOBILE models...")
99
- _ocr_engine = RapidOCR(
100
- params={
101
- "Global.max_side_len": 1280,
102
- "Det.engine_type": EngineType.ONNXRUNTIME,
103
- "Det.lang_type": LangDet.CH,
104
- "Det.model_type": ModelType.MOBILE,
105
- "Det.ocr_version": OCRVersion.PPOCRV5,
106
- "Rec.engine_type": EngineType.ONNXRUNTIME,
107
- "Rec.lang_type": LangRec.CH,
108
- "Rec.model_type": ModelType.MOBILE,
109
- "Rec.ocr_version": OCRVersion.PPOCRV5,
110
- }
111
- )
112
- return _ocr_engine
113
-
114
-
115
- def validate_drug_match(term: str) -> str:
116
- """Map term to canonical database drug, or None if noise."""
117
- term = term.lower()
118
-
119
- if term in _drug_db:
120
- return term
121
 
122
- # Skip SymSpell for very short or long words
123
- if len(term) < 3 or len(term) > 15:
124
- return None
125
 
126
- # Fuzzy match via SymSpell
127
- suggestions = _sym_spell.lookup(term, Verbosity.CLOSEST, max_edit_distance=1)
128
- if suggestions and suggestions[0].term in _drug_db:
129
- return suggestions[0].term
 
 
 
130
 
 
 
 
 
 
 
 
 
 
131
  return None
132
 
133
-
134
- def extract_drugs_from_line(line_text: str):
135
  """
136
- Extract drug names that follow or are merged with tab/cap/t prefixes.
 
 
 
137
  """
138
- drugs = []
139
- matches = ANCHOR_PATTERN.finditer(line_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- for match in matches:
142
- prefix = match.group(1).lower()
143
- next_word = match.group(2)
144
- canonical = validate_drug_match(next_word)
145
- if canonical:
146
- drugs.append(canonical)
147
 
148
- return drugs
 
 
 
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- def process_image_ocr(image):
152
- """Fast OCR + drug extraction using MOBILE models and ONNX runtime."""
153
- logger.info("Processing image...")
154
 
155
- ocr_engine = get_ocr_engine()
156
 
157
- # Preprocess: resize if too large
158
- height, width = image.shape[:2]
159
- if width > 1280:
160
- scale = 1280 / width
161
- image = cv2.resize(image, None, fx=scale, fy=scale)
162
 
163
- # Run OCR with optimized settings
164
- try:
165
- ocr_result = ocr_engine(
166
- image,
167
- use_det=True,
168
- use_cls=False,
169
- use_rec=True,
170
- )
171
- except Exception as e:
172
- logger.error(f"OCR failed: {e}")
173
- return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- # Handle new RapidOCR return format
176
- if not ocr_result or not hasattr(ocr_result, 'txts'):
177
- return {"drugs": [], "raw_lines": []}
 
 
 
 
 
 
 
 
178
 
179
- # Extract drugs
180
- drugs_found = set()
181
- raw_lines = []
182
 
183
- for line_text in ocr_result.txts:
184
- if not line_text:
 
 
 
 
 
185
  continue
186
 
187
- raw_lines.append(line_text)
188
- line_drugs = extract_drugs_from_line(line_text)
189
- drugs_found.update(line_drugs)
190
 
191
- return {
192
- "drugs": sorted(list(drugs_found)),
193
- "raw_lines": raw_lines,
194
- "drugs_count": len(drugs_found),
195
- "elapse": f"{ocr_result.elapse:.3f}s"
196
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
 
 
 
198
 
199
  def process_input(image_input):
200
  """Gradio interface handler."""
@@ -202,26 +316,39 @@ def process_input(image_input):
202
  return "Please upload an image.", {}
203
 
204
  try:
205
- # Convert Gradio RGB to BGR for OpenCV
206
- image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
207
- result = process_image_ocr(image)
208
-
209
- if "error" in result:
210
- return f"Error: {result['error']}", {}
211
-
 
 
 
 
 
 
 
 
 
 
 
 
212
  # Summary text
213
- summary = f"Found {result['drugs_count']} medication(s) in {result.get('elapse', 'N/A')}"
214
 
215
  # JSON output with all medications
216
  medications_json = {
217
- "total_medications": result["drugs_count"],
218
- "processing_time": result.get("elapse", "N/A"),
219
  "medications": [
220
  {
221
  "id": idx + 1,
222
- "name": drug.title()
 
223
  }
224
- for idx, drug in enumerate(result["drugs"])
225
  ]
226
  }
227
 
@@ -230,15 +357,7 @@ def process_input(image_input):
230
  logger.error(f"Processing error: {e}")
231
  return f"Error: {str(e)}", {}
232
 
233
-
234
- # ============================================================================
235
- # Gradio Interface
236
- # ============================================================================
237
-
238
  logger.info("Starting Medibot...")
239
- initialize_database()
240
- logger.info("Database initialized. Ready for inference.")
241
-
242
 
243
  with gr.Blocks(title="Medibot - Fast OCR") as demo:
244
  gr.Markdown("# Medibot: Prescription OCR")
 
1
  import gradio as gr
 
 
2
  import cv2
3
+ import time
4
+ import logging
5
+ import os
6
  import re
7
+ import pickle
8
+ import json
9
+ import pandas as pd
10
  from symspellpy import SymSpell, Verbosity
11
+ from rapidocr import RapidOCR, EngineType, LangCls, LangDet, LangRec, ModelType, OCRVersion
 
 
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
 
 
 
 
 
 
 
 
17
  # ============================================================================
18
+ # Database Initialization (from src/database/init.py)
19
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def initialize_database():
22
+ # Assuming data/Dataset.csv is relative to the current script or fixed path
23
+ # Adjust path if necessary. app.py is in root, data is in ./data
24
+ data_path = os.path.join(os.path.dirname(__file__), "data/Dataset.csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  try:
26
+ df = pd.read_csv(data_path, encoding='utf-8')
27
+ except UnicodeDecodeError:
28
+ df = pd.read_csv(data_path, encoding='latin1')
29
+ drug_db = set(df["Combined_Drugs"].astype(str).str.lower().str.strip())
30
+
31
+ sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
32
+
33
+ for drug in drug_db:
34
+ d = drug.lower()
35
+ sym_spell.create_dictionary_entry(d, 100000)
36
+ parts = d.split()
37
+ if len(parts) > 1:
38
+ for p in parts:
39
+ sym_spell.create_dictionary_entry(p, 100000)
40
+
41
+ drug_token_index = {}
42
+ for full in drug_db:
43
+ toks = full.split()
44
+ for tok in toks:
45
+ drug_token_index.setdefault(tok, set()).add(full)
46
+
47
+ ANCHOR_PREFIXES = ["tab", "cap"]
48
+
49
+ ANCHORS = [
50
+ r"tab\.?", r"cap\.?"
51
+ ]
52
+ ANCHOR_PATTERN = re.compile(r"\b(" + "|".join(ANCHORS) + r")\b", re.IGNORECASE)
53
 
54
+ return {
55
+ 'drug_db': drug_db,
56
+ 'sym_spell': sym_spell,
57
+ 'drug_token_index': drug_token_index,
58
+ 'ANCHOR_PREFIXES': ANCHOR_PREFIXES,
59
+ 'ANCHOR_PATTERN': ANCHOR_PATTERN
60
+ }
61
 
62
+ # Initialize Database Globally
63
+ logger.info("Initializing database...")
64
+ cache_path = os.path.join(os.path.dirname(__file__), "cache/database_cache.pkl")
65
+ try:
66
+ with open(cache_path, 'rb') as f:
67
+ cache = pickle.load(f)
68
+ drug_db = cache['drug_db']
69
+ sym_spell = cache['sym_spell']
70
+ drug_token_index = cache['drug_token_index']
71
+ ANCHOR_PREFIXES = cache['ANCHOR_PREFIXES']
72
+ ANCHOR_PATTERN = cache['ANCHOR_PATTERN']
73
+ logger.info("Database loaded from cache.")
74
+ except FileNotFoundError:
75
+ logger.info("Cache not found. Initializing from CSV...")
76
+ cache = initialize_database()
77
+ drug_db = cache['drug_db']
78
+ sym_spell = cache['sym_spell']
79
+ drug_token_index = cache['drug_token_index']
80
+ ANCHOR_PREFIXES = cache['ANCHOR_PREFIXES']
81
+ ANCHOR_PATTERN = cache['ANCHOR_PATTERN']
82
+ # Save cache
83
+ os.makedirs(os.path.dirname(cache_path), exist_ok=True)
84
+ with open(cache_path, 'wb') as f:
85
+ pickle.dump(cache, f)
86
+ logger.info("Database initialized and cached.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ # ============================================================================
89
+ # Helper Functions (from src/utils/drug_matching.py)
90
+ # ============================================================================
91
 
92
+ def is_potential_med_line(text: str, ANCHOR_PATTERN) -> bool:
93
+ t = text.lower()
94
+ # User requested ONLY anchor check
95
+ anchor_match = ANCHOR_PATTERN.search(t)
96
+ if anchor_match:
97
+ return True
98
+ return False
99
 
100
+ def validate_drug_match(term: str, drug_db, drug_token_index):
101
+ """
102
+ Map SymSpell term -> canonical database drug, or None if noise.
103
+ """
104
+ if term in drug_db:
105
+ return term
106
+ if term in drug_token_index:
107
+ # pick one canonical name; you can change selection logic if needed
108
+ return sorted(drug_token_index[term])[0]
109
  return None
110
 
111
+ def normalize_anchored_tokens(raw_text: str, ANCHOR_PREFIXES):
 
112
  """
113
+ Use TAB/CAP/T. as anchors, not something to delete:
114
+ - 'TABCLOPITAB75MG TAB' -> ['clopitab']
115
+ - 'TAB SOBISISTAB' -> ['sobisistab']
116
+ - 'TABSTARPRESSXL25MGTAB' -> ['starpressxl']
117
  """
118
+ t = raw_text.lower()
119
+ # Remove dosage and numbers but keep anchor letters
120
+ t = re.sub(r"\d+\s*(mg|ml|gm|%|u|mcg)", " ", t)
121
+ t = re.sub(r"\d+", " ", t)
122
+
123
+ # Remove punctuation including full-width parentheses
124
+ t = re.sub(r"[^\w\s]", " ", t)
125
+
126
+ tokens = t.split()
127
+
128
+ normalized = []
129
+ skip_next = False
130
+
131
+ for i, tok in enumerate(tokens):
132
+ if skip_next:
133
+ skip_next = False
134
+ continue
135
 
136
+ base = tok
 
 
 
 
 
137
 
138
+ # Case 1: token starts with anchor as prefix (no space)
139
+ for pref in ANCHOR_PREFIXES:
140
+ if base.startswith(pref) and len(base) > len(pref):
141
+ base = base[len(pref):]
142
+ break
143
 
144
+ # Case 2: token is pure anchor and should attach to next token
145
+ if base in ["tab", "cap", "t"]:
146
+ if i + 1 < len(tokens):
147
+ merged = tokens[i + 1]
148
+ for pref in ANCHOR_PREFIXES:
149
+ if merged.startswith(pref) and len(merged) > len(pref):
150
+ merged = merged[len(pref):]
151
+ break
152
+ base = merged
153
+ skip_next = True
154
+ else:
155
+ continue
156
 
157
+ base = base.strip()
158
+ normalized.append(base)
 
159
 
160
+ return normalized
161
 
162
+ # ============================================================================
163
+ # OCR Processor (from src/ocr/processor.py)
164
+ # ============================================================================
 
 
165
 
166
+ def process_image_ocr(image_path):
167
+ # Load image using cv2
168
+ img = cv2.imread(image_path)
169
+ if img is None:
170
+ raise ValueError(f"Could not load image from {image_path}")
171
+
172
+ # Create RapidOCR engine with default parameters
173
+ ocr_engine = RapidOCR(
174
+ params={
175
+ "Global.max_side_len": 2000,
176
+ "Det.engine_type": EngineType.ONNXRUNTIME,
177
+ "Det.lang_type": LangDet.EN,
178
+ "Det.model_type": ModelType.MOBILE,
179
+ "Det.ocr_version": OCRVersion.PPOCRV4,
180
+ "Cls.engine_type": EngineType.ONNXRUNTIME,
181
+ "Cls.lang_type": LangCls.CH,
182
+ "Cls.model_type": ModelType.MOBILE,
183
+ "Cls.ocr_version": OCRVersion.PPOCRV4,
184
+ "Rec.engine_type": EngineType.ONNXRUNTIME,
185
+ "Rec.lang_type": LangRec.EN,
186
+ "Rec.model_type": ModelType.MOBILE,
187
+ "Rec.ocr_version": OCRVersion.PPOCRV4,
188
+ }
189
+ )
190
 
191
+ # Run OCR
192
+ ocr_result = ocr_engine(
193
+ img,
194
+ use_det=True,
195
+ use_cls=True,
196
+ use_rec=True,
197
+ text_score=0.3,
198
+ box_thresh=0.3,
199
+ unclip_ratio=2.0,
200
+ return_word_box=False,
201
+ )
202
 
203
+ ocr_data = ocr_result.txts
 
 
204
 
205
+ found_meds_with_originals = {}
206
+
207
+ for item in ocr_data:
208
+ text_lower = item.lower()
209
+
210
+ # Simplified line-level gate: ONLY check anchors
211
+ if not is_potential_med_line(text_lower, ANCHOR_PATTERN):
212
  continue
213
 
214
+ # Skip doctor name lines
215
+ if "dr." in text_lower or "dr " in text_lower:
216
+ continue
217
 
218
+ # Anchor-aware tokens
219
+ candidate_tokens = normalize_anchored_tokens(item, ANCHOR_PREFIXES)
220
+
221
+ # Save original normalized text for exact match checking
222
+ normalized_text_str = " ".join(candidate_tokens)
223
+
224
+ # Optional SymSpell segmentation on normalized tokens
225
+ if candidate_tokens:
226
+ segmentation = sym_spell.word_segmentation(" ".join(candidate_tokens))
227
+ corrected_string = segmentation.corrected_string
228
+ candidate_tokens = corrected_string.split()
229
+
230
+ line_matches = []
231
+ i = 0
232
+ n = len(candidate_tokens)
233
+
234
+ while i < n:
235
+ match_found = False
236
+ # Greedy longest match: try phrases of length 5 down to 1
237
+ for length in range(min(5, n - i), 0, -1):
238
+ phrase_tokens = candidate_tokens[i : i + length]
239
+ phrase = " ".join(phrase_tokens)
240
+
241
+ # Check exact phrase in DB
242
+ if phrase in drug_db:
243
+ # Found a multi-word (or single-word) drug match!
244
+ if phrase in normalized_text_str:
245
+ line_matches.append((phrase, "exact", phrase))
246
+ else:
247
+ line_matches.append((phrase, "fuzzy", phrase))
248
+
249
+ i += length
250
+ match_found = True
251
+ break
252
+
253
+ if match_found:
254
+ continue
255
+
256
+ # Fallback: Single token processing (Fuzzy / Partial)
257
+ word = candidate_tokens[i]
258
+ i += 1
259
+
260
+ # Check for exact match first (as a single token)
261
+ canonical = validate_drug_match(word, drug_db, drug_token_index)
262
+ if canonical:
263
+ # Coverage check: detected word must cover a significant portion of the canonical name
264
+ if len(word) / len(canonical) < 0.6:
265
+ continue
266
+
267
+ if word in normalized_text_str:
268
+ line_matches.append((canonical, "exact", word))
269
+ else:
270
+ line_matches.append((canonical, "fuzzy", word))
271
+ continue
272
+
273
+ # Fuzzy matching
274
+ if len(word) < 3:
275
+ continue
276
+
277
+ suggestions = sym_spell.lookup(
278
+ word, Verbosity.CLOSEST, max_edit_distance=1
279
+ )
280
+ if not suggestions:
281
+ continue
282
+
283
+ cand = suggestions[0].term
284
+ canonical = validate_drug_match(cand, drug_db, drug_token_index)
285
+ if canonical:
286
+ # Coverage check for fuzzy match too
287
+ if len(word) / len(canonical) < 0.6:
288
+ continue
289
+ line_matches.append((canonical, "fuzzy", word))
290
+
291
+ # Filter matches for this line:
292
+ exact_matches = [m for m in line_matches if m[1] == "exact"]
293
+ if exact_matches:
294
+ final_matches = exact_matches
295
+ else:
296
+ final_matches = line_matches
297
+
298
+ for match in final_matches:
299
+ canonical = match[0]
300
+ original_text = match[2]
301
+
302
+ if canonical not in found_meds_with_originals:
303
+ found_meds_with_originals[canonical] = []
304
+ if item not in found_meds_with_originals[canonical]:
305
+ found_meds_with_originals[canonical].append(item)
306
+
307
+ return found_meds_with_originals
308
 
309
+ # ============================================================================
310
+ # Gradio Interface
311
+ # ============================================================================
312
 
313
  def process_input(image_input):
314
  """Gradio interface handler."""
 
316
  return "Please upload an image.", {}
317
 
318
  try:
319
+ temp_path = "temp_upload.jpg"
320
+ # Convert RGB (Gradio) to BGR (OpenCV)
321
+ image_bgr = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
322
+ cv2.imwrite(temp_path, image_bgr)
323
+
324
+ start_time = time.time()
325
+
326
+ # Use the robust processor
327
+ found_meds_dict = process_image_ocr(temp_path)
328
+
329
+ elapsed_time = time.time() - start_time
330
+
331
+ # Cleanup
332
+ if os.path.exists(temp_path):
333
+ os.remove(temp_path)
334
+
335
+ drugs_list = sorted(found_meds_dict.keys())
336
+ drugs_count = len(drugs_list)
337
+
338
  # Summary text
339
+ summary = f"Found {drugs_count} medication(s) in {elapsed_time:.3f}s"
340
 
341
  # JSON output with all medications
342
  medications_json = {
343
+ "total_medications": drugs_count,
344
+ "processing_time": f"{elapsed_time:.3f}s",
345
  "medications": [
346
  {
347
  "id": idx + 1,
348
+ "name": drug.title(),
349
+ "original_text": found_meds_dict[drug]
350
  }
351
+ for idx, drug in enumerate(drugs_list)
352
  ]
353
  }
354
 
 
357
  logger.error(f"Processing error: {e}")
358
  return f"Error: {str(e)}", {}
359
 
 
 
 
 
 
360
  logger.info("Starting Medibot...")
 
 
 
361
 
362
  with gr.Blocks(title="Medibot - Fast OCR") as demo:
363
  gr.Markdown("# Medibot: Prescription OCR")