Seth0330 committed on
Commit
d12949e
·
verified ·
1 Parent(s): 59560f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +292 -325
app.py CHANGED
@@ -1,339 +1,306 @@
1
  # app.py
2
- # Invoice Extraction Donut (public HF model) + Tesseract tables
3
- # Robust PDF handling:
4
- # 1) Try pdf2image with Poppler path detection (Fix A)
5
- # 2) If Poppler is missing, auto-fallback to PyMuPDF (no Poppler required)
 
 
 
 
 
 
 
 
6
 
7
- import os, io, re, json, shutil
8
- from typing import List
9
  import numpy as np
10
- import pandas as pd
11
- from PIL import Image, ImageOps, ImageFilter
12
-
13
  import streamlit as st
14
-
15
- # OCR (detection only) and PDF->image
16
- import pytesseract
17
- from pytesseract import Output
18
- from pdf2image import convert_from_bytes
19
-
20
- # HF Donut (public model)
21
  import torch
22
- from transformers import DonutProcessor, VisionEncoderDecoderModel
23
-
24
- # ------------------------------------------------------------------
25
- st.set_page_config(
26
- page_title="Invoice Extraction Donut (public) + Tesseract tables",
27
- layout="wide"
28
- )
29
-
30
- device = "cuda" if torch.cuda.is_available() else "cpu"
31
-
32
- # ----------------------------- Sidebar -----------------------------
33
- st.sidebar.header("Model (Hugging Face — public)")
34
- model_id = st.sidebar.text_input(
35
- "HF model id",
36
- value="naver-clova-ix/donut-base-finetuned-cord-v2",
37
- help="Use a public model id; this one works without token."
38
- )
39
- task_prompt = st.sidebar.text_input(
40
- "Task prompt (Donut)",
41
- value="<s_cord-v2>",
42
- help="Keep default for CORD-style invoices."
43
- )
44
- det_lang = st.sidebar.text_input("Tesseract language(s) — detection only", value="eng")
45
- show_boxes = st.sidebar.checkbox("Show word boxes (debug)", value=False)
46
-
47
- # ----------------------------- PDF loader (Fix A + fallback) -----------------------------
48
- def _find_poppler_path():
49
- # Return a folder containing pdfinfo/pdftoppm if not on PATH
50
- if shutil.which("pdfinfo") and shutil.which("pdftoppm"):
51
- return None
52
- for p in ["/usr/bin", "/usr/local/bin", "/usr/share/bin"]:
53
- if os.path.exists(os.path.join(p, "pdfinfo")) and os.path.exists(os.path.join(p, "pdftoppm")):
54
- return p
55
- return None
56
-
57
- def _pages_via_pdf2image(file_bytes: bytes) -> List[Image.Image]:
58
- poppler_path = _find_poppler_path()
59
- if poppler_path:
60
- return convert_from_bytes(file_bytes, dpi=300, poppler_path=poppler_path)
61
- else:
62
- return convert_from_bytes(file_bytes, dpi=300)
63
-
64
- def _pages_via_pymupdf(file_bytes: bytes) -> List[Image.Image]:
65
- import fitz # PyMuPDF
66
- doc = fitz.open(stream=file_bytes, filetype="pdf")
67
- pages = []
68
- for page in doc:
69
- # Use a mild upscale for better OCR if you want: matrix = fitz.Matrix(2, 2)
70
- pix = page.get_pixmap() # or: page.get_pixmap(matrix=matrix)
71
- img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
72
- pages.append(img)
73
- return pages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
76
- name = (name or "").lower()
77
- if name.endswith(".pdf"):
78
- # Try Poppler route first
79
- try:
80
- return _pages_via_pdf2image(file_bytes)
81
- except Exception:
82
- # Fallback: PyMuPDF (no Poppler required)
83
- return _pages_via_pymupdf(file_bytes)
84
- return [Image.open(io.BytesIO(file_bytes)).convert("RGB")]
85
 
86
- def preprocess_for_detection(img: Image.Image) -> Image.Image:
87
- g = ImageOps.grayscale(img)
88
- g = ImageOps.autocontrast(g)
89
- g = g.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))
90
- return g
91
 
92
- # ----------------------------- Donut loader -----------------------------
93
  @st.cache_resource(show_spinner=True)
94
- def load_donut(_model_id: str):
95
- processor = DonutProcessor.from_pretrained(_model_id)
96
- model = VisionEncoderDecoderModel.from_pretrained(_model_id)
97
- model.to(device).eval()
98
- return processor, model
99
-
100
- def donut_infer(img: Image.Image, processor: DonutProcessor, model: VisionEncoderDecoderModel, prompt: str):
101
- inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
102
- with torch.no_grad():
103
- outputs = model.generate(**inputs, max_length=1024, num_beams=1, early_stopping=True)
104
- seq = processor.batch_decode(outputs, skip_special_tokens=True)[0]
105
- parsed = None
106
- try:
107
- start = seq.find("{")
108
- end = seq.rfind("}")
109
- if start != -1 and end != -1 and end > start:
110
- parsed = json.loads(seq[start:end+1])
111
- except Exception:
112
- parsed = None
113
- return seq, parsed
114
-
115
- # ----------------------------- Key fields & tables -----------------------------
116
- CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
117
- MONEY = rf"{CURRENCY}\s?(?P<amt>\d{{1,3}}(?:[,]\d{{3}})*(?:[.]\d{{2}})?)"
118
- DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
119
- INV_PAT = r"(?:invoice\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>[A-Z0-9\-_/]{4,}))"
120
- PO_PAT = r"(?:po\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>[A-Z0-9\-_/]{3,}))"
121
- TOTAL_PAT = rf"(?:\b(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"
122
- SUBTOTAL_PAT = rf"(?:\bsub\s*total\b.*?{MONEY})"
123
- TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"
124
-
125
- def parse_fields_regex(fulltext: str):
126
- t = re.sub(r"[ \t]+", " ", fulltext)
127
- t = re.sub(r"\n{2,}", "\n", t)
128
- out = {"invoice_number":None,"invoice_date":None,"po_number":None,"subtotal":None,"tax":None,"total":None,"currency":None}
129
- m = re.search(INV_PAT, t, re.I); out["invoice_number"] = m.group("inv") if m else None
130
- m = re.search(PO_PAT, t, re.I); out["po_number"] = m.group("po") if m else None
131
- m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I)
132
- out["invoice_date"] = (m.group("date") if m else (re.search(DATE, t, re.I).group("date") if re.search(DATE, t, re.I) else None))
133
- m = re.search(SUBTOTAL_PAT, t, re.I|re.S);
134
- if m: out["subtotal"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
135
- m = re.search(TAX_PAT, t, re.I|re.S);
136
- if m: out["tax"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
137
- m = re.search(TOTAL_PAT, t, re.I|re.S);
138
- if m:
139
- out["total"], out["currency"] = m.group("amt").replace(",", ""), m.group("curr") or out["currency"]
140
- if out["currency"] in ["$", "C$", "€", "£"]:
141
- out["currency"] = {"$":"USD", "C$":"CAD", "€":"EUR", "£":"GBP"}[out["currency"]]
142
- return out
143
-
144
- def normalize_kv_from_donut(parsed: dict):
145
- out = {k: None for k in ["invoice_number","invoice_date","po_number","subtotal","tax","total","currency"]}
146
-
147
- def search_keys(obj, key_list):
148
- if isinstance(obj, dict):
149
- for k, v in obj.items():
150
- kl = k.lower()
151
- if any(kk in kl for kk in key_list):
152
- return v if isinstance(v, str) else None
153
- found = search_keys(v, key_list)
154
- if found is not None:
155
- return found
156
- elif isinstance(obj, list):
157
- for it in obj:
158
- found = search_keys(it, key_list)
159
- if found is not None:
160
- return found
161
- return None
162
-
163
- mapping = {
164
- "invoice_number": ["invoice_number","invoice no","invoice_no","invoice","inv_no"],
165
- "invoice_date": ["invoice_date","date","bill_date"],
166
- "po_number": ["po_number","po","purchase_order"],
167
- "subtotal": ["subtotal","sub_total"],
168
- "tax": ["tax","gst","vat","hst"],
169
- "total": ["total","amount_total","amount_due","grand_total"],
170
- }
171
- for k, keys in mapping.items():
172
- val = search_keys(parsed, keys)
173
- if isinstance(val, str):
174
- out[k] = val.strip()
175
-
176
- txt = json.dumps(parsed, ensure_ascii=False)
177
- m = re.search(r"(USD|CAD|EUR|GBP|\$|C\$|€|£)", txt, re.I)
178
- if m:
179
- sym = m.group(1)
180
- out["currency"] = {"$":"USD","C$":"CAD","€":"EUR","£":"GBP"}.get(sym, sym.upper())
181
  return out
182
 
183
- def items_from_words_simple(tsv: pd.DataFrame) -> pd.DataFrame:
184
- HEAD_CANDIDATES = ["description","item","qty","quantity","price","unit","rate","amount","total"]
185
- if tsv.empty:
186
- return pd.DataFrame()
187
-
188
- lines = []
189
- for (b,p,l), g in tsv.groupby(["block_num","par_num","line_num"]):
190
- text = " ".join([w for w in g["text"].astype(str).tolist() if w.strip()])
191
- if text.strip():
192
- lines.append({
193
- "block_num": b, "par_num": p, "line_num": l,
194
- "text": text.lower(),
195
- "top": g["top"].min(), "bottom": (g["top"]+g["height"]).max(),
196
- "left": g["left"].min(), "right": (g["left"]+g["width"]).max()
197
- })
198
- L = pd.DataFrame(lines)
199
- if L.empty:
200
- return pd.DataFrame()
201
- L["header_score"] = L["text"].apply(lambda s: sum(1 for h in HEAD_CANDIDATES if h in s))
202
- hdrs = L[L["header_score"] >= 2].sort_values(["header_score","top"], ascending=[False,True])
203
- if hdrs.empty:
204
- return pd.DataFrame()
205
-
206
- H = hdrs.iloc[0]
207
- header_top, header_bottom = H["top"], H["bottom"]
208
-
209
- header_words = tsv[(tsv["top"] >= header_top - 5) & ((tsv["top"] + tsv["height"]) <= header_bottom + 5)]
210
- header_words = header_words.sort_values("left")
211
- if header_words.empty:
212
- return pd.DataFrame()
213
- xs = header_words["left"].tolist()
214
- hdr_tokens = [t.lower() for t in header_words["text"].tolist()]
215
-
216
- below = tsv[tsv["top"] > header_bottom + 5].copy()
217
- totals_mask = below["text"].str.lower().str.contains(
218
- r"(sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  )
220
- if totals_mask.any():
221
- stop_y = below.loc[totals_mask, "top"].min()
222
- below = below[below["top"] < stop_y - 4]
223
- if below.empty:
224
- return pd.DataFrame()
225
-
226
- rows = []
227
- for (b,p,l), g in below.groupby(["block_num","par_num","line_num"]):
228
- g = g.sort_values("left")
229
- buckets = {i:[] for i in range(len(xs))}
230
- for _, w in g.iterrows():
231
- if not str(w["text"]).strip():
232
- continue
233
- idx = int(np.abs(np.array(xs) - w["left"]).argmin())
234
- buckets[idx].append(str(w["text"]))
235
- vals = [" ".join(buckets[i]).strip() for i in range(len(xs))]
236
- if any(vals):
237
- rows.append(vals)
238
- if not rows:
239
- return pd.DataFrame()
240
-
241
- df_rows = pd.DataFrame(rows).fillna("")
242
- names = []
243
- for i in range(df_rows.shape[1]):
244
- wl = hdr_tokens[i] if i < len(hdr_tokens) else f"col_{i}"
245
- if "desc" in wl or wl in ["item","description"]:
246
- names.append("description")
247
- elif wl in ["qty","quantity"]:
248
- names.append("quantity")
249
- elif "unit" in wl or "rate" in wl or "price" in wl:
250
- names.append("unit_price")
251
- elif "amount" in wl or "total" in wl:
252
- names.append("line_total")
253
- else:
254
- names.append(f"col_{i}")
255
- df_rows.columns = names
256
- df_rows = df_rows[~(df_rows.fillna("").apply(lambda r: "".join(r.values), axis=1).str.strip()=="")]
257
- return df_rows.reset_index(drop=True)
258
-
259
- # ----------------------------- App -----------------------------
260
- st.title("Invoice Extraction — Donut (public) + Tesseract tables")
261
-
262
- up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf","png","jpg","jpeg"])
263
- if not up:
264
- st.info("Upload a scanned invoice to begin.")
265
- st.stop()
266
-
267
- with st.spinner(f"Loading model '{model_id}' from Hugging Face…"):
268
- processor, donut_model = load_donut(model_id)
269
-
270
- pages = load_pages(up.read(), up.name)
271
- page_idx = 0
272
- if len(pages) > 1:
273
- page_idx = st.number_input("Page", 1, len(pages), 1) - 1
274
- img = pages[page_idx]
275
-
276
- col1, col2 = st.columns([1.1, 1.3], gap="large")
277
-
278
- with col1:
279
- st.subheader("Preview")
280
- st.image(img, use_column_width=True)
281
- det_img = preprocess_for_detection(img)
282
- with st.expander("Detection view (preprocessed for boxes)"):
283
- st.image(det_img, use_column_width=True)
284
-
285
- with col2:
286
- st.subheader("OCR & Extraction")
287
-
288
- with st.spinner("Running Donut…"):
289
- seq, parsed = donut_infer(img, processor, donut_model, task_prompt)
290
-
291
- if parsed:
292
- key_fields = normalize_kv_from_donut(parsed)
293
- donut_payload = parsed
294
- else:
295
- key_fields = parse_fields_regex(seq)
296
- donut_payload = {"generated_text": seq}
297
-
298
- k1,k2,k3 = st.columns(3)
299
- with k1:
300
- st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
301
- st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or '—'}")
302
- with k2:
303
- st.write(f"**PO #:** {key_fields.get('po_number') or '—'}")
304
- st.write(f"**Subtotal:** {key_fields.get('subtotal') or '—'}")
305
- with k3:
306
- st.write(f"**Tax:** {key_fields.get('tax') or '—'}")
307
- tot = key_fields.get('total') or '—'
308
- cur = key_fields.get('currency') or ''
309
- st.write(f"**Total:** {tot} {cur}".strip())
310
-
311
- with st.spinner("Detecting words with Tesseract (for table)…"):
312
- tsv = pytesseract.image_to_data(det_img, lang=det_lang, output_type=Output.DATAFRAME)
313
- tsv = tsv.dropna(subset=["text"]).reset_index(drop=True)
314
- tsv["x2"] = tsv["left"] + tsv["width"]
315
- tsv["y2"] = tsv["top"] + tsv["height"]
316
-
317
- st.markdown("**Line Items**")
318
- items = items_from_words_simple(tsv)
319
- if items.empty:
320
- st.caption("No line items confidently detected.")
321
- else:
322
- st.dataframe(items, use_container_width=True)
323
 
324
- result = {
325
- "file": up.name,
326
- "page": page_idx + 1,
327
- "key_fields": key_fields,
328
- "items": items.to_dict(orient="records") if not items.empty else [],
329
- "donut_raw": donut_payload,
330
- }
331
- st.download_button("Download JSON", data=json.dumps(result, indent=2),
332
- file_name="invoice_extraction.json", mime="application/json")
333
- if not items.empty:
334
- st.download_button("Download Items CSV", data=items.to_csv(index=False),
335
- file_name="invoice_items.csv", mime="text/csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
- if show_boxes:
338
- st.caption("First 20 Tesseract word boxes")
339
- st.dataframe(tsv[["left","top","width","height","text","conf"]].head(20), use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
+ # Invoice -> JSON (Paste Text Only)
3
+ # Pipeline:
4
+ # 1) User pastes invoice text into a textbox (no OCR calls).
5
+ # 2) Extract naive key:value candidates.
6
+ # 3) Map candidates to your static schema headers with all-MiniLM-L6-v2.
7
+ # 4) Ask MD2JSON-T5-small-V1 to emit STRICT JSON using your schema (nulls if missing).
8
+ # If the model returns invalid JSON, we fall back to a schema-shaped object filled from mapped data (null where missing).
9
+
10
+ import os
11
+ import re
12
+ import json
13
+ from typing import List, Dict
14
 
 
 
15
  import numpy as np
 
 
 
16
  import streamlit as st
 
 
 
 
 
 
 
17
  import torch
18
+ from transformers import pipeline
19
+ from sentence_transformers import SentenceTransformer, util
20
+
21
+ st.set_page_config(page_title="Invoice → JSON (Paste Text) · MiniLM + MD2JSON", layout="wide")
22
+ st.title("Invoice JSON (Paste Text Only)")
23
+
24
+ # ----------------------------- Your schema (fixed) -----------------------------
25
+ SCHEMA_JSON = {
26
+ "invoice_header": {
27
+ "car_number": None,
28
+ "shipment_number": None,
29
+ "shipping_point": None,
30
+ "currency": None,
31
+ "invoice_number": None,
32
+ "invoice_date": None,
33
+ "order_number": None,
34
+ "customer_order_number": None,
35
+ "our_order_number": None,
36
+ "sales_order_number": None,
37
+ "purchase_order_number": None,
38
+ "order_date": None,
39
+ "supplier_name": None,
40
+ "supplier_address": None,
41
+ "supplier_phone": None,
42
+ "supplier_email": None,
43
+ "supplier_tax_id": None,
44
+ "customer_name": None,
45
+ "customer_address": None,
46
+ "customer_phone": None,
47
+ "customer_email": None,
48
+ "customer_tax_id": None,
49
+ "ship_to_name": None,
50
+ "ship_to_address": None,
51
+ "bill_to_name": None,
52
+ "bill_to_address": None,
53
+ "remit_to_name": None,
54
+ "remit_to_address": None,
55
+ "tax_id": None,
56
+ "tax_registration_number": None,
57
+ "vat_number": None,
58
+ "payment_terms": None,
59
+ "payment_method": None,
60
+ "payment_reference": None,
61
+ "bank_account_number": None,
62
+ "iban": None,
63
+ "swift_code": None,
64
+ "total_before_tax": None,
65
+ "tax_amount": None,
66
+ "tax_rate": None,
67
+ "shipping_charges": None,
68
+ "discount": None,
69
+ "total_due": None,
70
+ "amount_paid": None,
71
+ "balance_due": None,
72
+ "due_date": None,
73
+ "invoice_status": None,
74
+ "reference_number": None,
75
+ "project_code": None,
76
+ "department": None,
77
+ "contact_person": None,
78
+ "notes": None,
79
+ "additional_info": None
80
+ },
81
+ "line_items": [
82
+ {
83
+ "quantity": None,
84
+ "units": None,
85
+ "description": None,
86
+ "footage": None,
87
+ "price": None,
88
+ "amount": None,
89
+ "notes": None
90
+ }
91
+ ]
92
+ }
93
 
94
+ STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
 
 
 
 
 
 
 
 
 
95
 
96
+ # ----------------------------- Sidebar controls -----------------------------
97
+ st.sidebar.header("Settings")
98
+ threshold = st.sidebar.slider("Semantic match threshold (cosine)", 0.0, 1.0, 0.60, 0.01)
99
+ max_new_tokens = st.sidebar.slider("Max new tokens (MD2JSON)", 128, 2048, 512, 32)
100
+ show_intermediates = st.sidebar.checkbox("Show intermediate mapping/hints", value=False)
101
 
102
+ # ----------------------------- Model loaders (cached) -----------------------------
103
  @st.cache_resource(show_spinner=True)
104
+ def load_models():
105
+ sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
106
+ json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
107
+ return sentence_model, json_converter
108
+
109
+ sentence_model, json_converter = load_models()
110
+
111
+ # ----------------------------- Helpers -----------------------------
112
+ def extract_candidates(text: str) -> Dict[str, str]:
113
+ """
114
+ Very simple key:value candidate extraction from lines containing a colon.
115
+ Extend with regexes for phones/emails/totals as needed.
116
+ """
117
+ out: Dict[str, str] = {}
118
+ for raw in text.splitlines():
119
+ line = raw.strip()
120
+ if not line or ":" not in line:
121
+ continue
122
+ k, v = line.split(":", 1)
123
+ k, v = k.strip(), v.strip()
124
+ if k and v:
125
+ out[k] = v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  return out
127
 
128
+ def map_to_schema_headers(candidates: Dict[str, str],
129
+ static_headers: List[str],
130
+ thresh: float) -> Dict[str, str]:
131
+ if not candidates:
132
+ return {}
133
+ cand_keys = list(candidates.keys())
134
+ # Embed all at once
135
+ cand_emb = sentence_model.encode(cand_keys, normalize_embeddings=True)
136
+ head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
137
+ cos = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy() # [Nc, Nh]
138
+ mapped: Dict[str, str] = {}
139
+ for i, key in enumerate(cand_keys):
140
+ j = int(np.argmax(cos[i]))
141
+ score = float(cos[i][j])
142
+ if score >= thresh:
143
+ mapped[static_headers[j]] = candidates[key]
144
+ return mapped
145
+
146
+ def build_prompt(schema_text: str, invoice_text: str, mapped_hints: Dict[str, str]) -> str:
147
+ """
148
+ Construct a strong instruction prompt for MD2JSON.
149
+ Includes your schema, strict rules, raw text, and optional hints from semantic mapping.
150
+ """
151
+ # Your strict instruction block:
152
+ instruction = (
153
+ 'Use this schema:\n'
154
+ '{\n'
155
+ ' "invoice_header": {\n'
156
+ ' "car_number": "string or null",\n'
157
+ ' "shipment_number": "string or null",\n'
158
+ ' "shipping_point": "string or null",\n'
159
+ ' "currency": "string or null",\n'
160
+ ' "invoice_number": "string or null",\n'
161
+ ' "invoice_date": "string or null",\n'
162
+ ' "order_number": "string or null",\n'
163
+ ' "customer_order_number": "string or null",\n'
164
+ ' "our_order_number": "string or null",\n'
165
+ ' "sales_order_number": "string or null",\n'
166
+ ' "purchase_order_number": "string or null",\n'
167
+ ' "order_date": "string or null",\n'
168
+ ' "supplier_name": "string or null",\n'
169
+ ' "supplier_address": "string or null",\n'
170
+ ' "supplier_phone": "string or null",\n'
171
+ ' "supplier_email": "string or null",\n'
172
+ ' "supplier_tax_id": "string or null",\n'
173
+ ' "customer_name": "string or null",\n'
174
+ ' "customer_address": "string or null",\n'
175
+ ' "customer_phone": "string or null",\n'
176
+ ' "customer_email": "string or null",\n'
177
+ ' "customer_tax_id": "string or null",\n'
178
+ ' "ship_to_name": "string or null",\n'
179
+ ' "ship_to_address": "string or null",\n'
180
+ ' "bill_to_name": "string or null",\n'
181
+ ' "bill_to_address": "string or null",\n'
182
+ ' "remit_to_name": "string or null",\n'
183
+ ' "remit_to_address": "string or null",\n'
184
+ ' "tax_id": "string or null",\n'
185
+ ' "tax_registration_number": "string or null",\n'
186
+ ' "vat_number": "string or null",\n'
187
+ ' "payment_terms": "string or null",\n'
188
+ ' "payment_method": "string or null",\n'
189
+ ' "payment_reference": "string or null",\n'
190
+ ' "bank_account_number": "string or null",\n'
191
+ ' "iban": "string or null",\n'
192
+ ' "swift_code": "string or null",\n'
193
+ ' "total_before_tax": "string or null",\n'
194
+ ' "tax_amount": "string or null",\n'
195
+ ' "tax_rate": "string or null",\n'
196
+ ' "shipping_charges": "string or null",\n'
197
+ ' "discount": "string or null",\n'
198
+ ' "total_due": "string or null",\n'
199
+ ' "amount_paid": "string or null",\n'
200
+ ' "balance_due": "string or null",\n'
201
+ ' "due_date": "string or null",\n'
202
+ ' "invoice_status": "string or null",\n'
203
+ ' "reference_number": "string or null",\n'
204
+ ' "project_code": "string or null",\n'
205
+ ' "department": "string or null",\n'
206
+ ' "contact_person": "string or null",\n'
207
+ ' "notes": "string or null",\n'
208
+ ' "additional_info": "string or null"\n'
209
+ ' },\n'
210
+ ' "line_items": [\n'
211
+ ' {\n'
212
+ ' "quantity": "string or null",\n'
213
+ ' "units": "string or null",\n'
214
+ ' "description": "string or null",\n'
215
+ ' "footage": "string or null",\n'
216
+ ' "price": "string or null",\n'
217
+ ' "amount": "string or null",\n'
218
+ ' "notes": "string or null"\n'
219
+ ' }\n'
220
+ ' ]\n'
221
+ '}\n'
222
+ 'If a field is missing for a line item or header, use null. '
223
+ 'Do not invent fields. Do not add any header or shipment data to any line item. '
224
+ 'Return ONLY the JSON object, no explanation.\n'
225
+ )
226
+ hint_block = ""
227
+ if mapped_hints:
228
+ # Provide soft hints to help the T5 model place values correctly (optional but useful)
229
+ hints_serialized = " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
230
+ hint_block = f"\nHints:\n{hints_serialized}\n"
231
+
232
+ return (
233
+ instruction +
234
+ "\nInvoice Text:\n" +
235
+ invoice_text.strip() +
236
+ hint_block
237
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
+ def strict_json_or_fallback(text: str, mapped: Dict[str, str]) -> Dict:
240
+ """Try to parse model output as JSON. If invalid, return a schema-shaped fallback using mapped values."""
241
+ # Try full parse
242
+ try:
243
+ return json.loads(text)
244
+ except Exception:
245
+ pass
246
+ # Try to extract the biggest {...} block
247
+ start = text.find("{")
248
+ end = text.rfind("}")
249
+ if start != -1 and end != -1 and end > start:
250
+ try:
251
+ return json.loads(text[start:end+1])
252
+ except Exception:
253
+ pass
254
+ # Fallback: shape like schema and fill mapped header fields; empty line_items
255
+ fallback = json.loads(json.dumps(SCHEMA_JSON)) # deep copy
256
+ for k, v in mapped.items():
257
+ if k in fallback["invoice_header"]:
258
+ fallback["invoice_header"][k] = v
259
+ # Leave line_items as a single template row with nulls
260
+ return fallback
261
+
262
+ # ----------------------------- UI: paste text -----------------------------
263
+ invoice_text = st.text_area(
264
+ "Paste the invoice text here (unstructured).",
265
+ height=320,
266
+ placeholder="Paste the invoice content (OCR/plain text) ..."
267
+ )
268
 
269
+ colA, colB = st.columns([1,1])
270
+
271
+ with colA:
272
+ if st.button("Generate JSON", type="primary", use_container_width=True):
273
+ if not invoice_text.strip():
274
+ st.error("Please paste the invoice text first.")
275
+ st.stop()
276
+
277
+ with st.spinner("Extracting candidates..."):
278
+ candidates = extract_candidates(invoice_text)
279
+
280
+ with st.spinner("Semantic mapping with all-MiniLM-L6-v2..."):
281
+ mapped = map_to_schema_headers(candidates, STATIC_HEADERS, threshold)
282
+
283
+ if show_intermediates:
284
+ st.subheader("Candidates (first 20)")
285
+ st.json(dict(list(candidates.items())[:20]))
286
+ st.subheader("Mapped headers")
287
+ st.json(mapped)
288
+
289
+ with st.spinner("Generating structured JSON with MD2JSON-T5-small-V1..."):
290
+ prompt = build_prompt(json.dumps(SCHEMA_JSON, indent=2), invoice_text, mapped)
291
+ gen = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
292
+ out = gen(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
293
+ final_json = strict_json_or_fallback(out, mapped)
294
+
295
+ st.subheader("Final JSON")
296
+ st.json(final_json)
297
+ st.download_button("Download JSON", data=json.dumps(final_json, indent=2),
298
+ file_name="invoice.json", mime="application/json", use_container_width=True)
299
+
300
+ with colB:
301
+ st.caption("Tips")
302
+ st.markdown(
303
+ "- You can tune the **semantic threshold** in the sidebar to make header mapping more/less strict.\n"
304
+ "- If the T5 model outputs invalid JSON, the app returns a **schema-shaped fallback** filled from the mapped fields.\n"
305
+ "- Extend `extract_candidates()` to add regexes for emails, phones, totals, etc., to boost recall."
306
+ )