Seth0330 committed on
Commit
e0f5793
·
verified ·
1 Parent(s): 07c6631

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +425 -112
app.py CHANGED
@@ -1,16 +1,16 @@
1
  # app.py
2
- # Invoice -> JSON (Paste Text Only)
3
- # Pipeline:
4
- # 1) User pastes invoice text into a textbox (no OCR calls).
5
- # 2) Extract naive key:value candidates.
6
- # 3) Map candidates to your static schema headers with all-MiniLM-L6-v2.
7
- # 4) Ask MD2JSON-T5-small-V1 to emit STRICT JSON using your schema (nulls if missing).
8
- # If the model returns invalid JSON, we fall back to a schema-shaped object filled from mapped data (null where missing).
9
-
10
- import os
11
  import re
12
  import json
13
- from typing import List, Dict
 
14
 
15
  import numpy as np
16
  import streamlit as st
@@ -18,11 +18,11 @@ import torch
18
  from transformers import pipeline
19
  from sentence_transformers import SentenceTransformer, util
20
 
21
- st.set_page_config(page_title="Invoice → JSON (Paste Text) · MiniLM + MD2JSON", layout="wide")
22
- st.title("Invoice → JSON (Paste Text Only)")
23
 
24
- # ----------------------------- Your schema (fixed) -----------------------------
25
- SCHEMA_JSON = {
26
  "invoice_header": {
27
  "car_number": None,
28
  "shipment_number": None,
@@ -90,65 +90,330 @@ SCHEMA_JSON = {
90
  }
91
  ]
92
  }
93
-
94
  STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
95
 
96
- # ----------------------------- Sidebar controls -----------------------------
97
  st.sidebar.header("Settings")
98
  threshold = st.sidebar.slider("Semantic match threshold (cosine)", 0.0, 1.0, 0.60, 0.01)
99
  max_new_tokens = st.sidebar.slider("Max new tokens (MD2JSON)", 128, 2048, 512, 32)
100
- show_intermediates = st.sidebar.checkbox("Show intermediate mapping/hints", value=False)
101
 
102
- # ----------------------------- Model loaders (cached) -----------------------------
103
  @st.cache_resource(show_spinner=True)
104
  def load_models():
105
  sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
106
  json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
107
  return sentence_model, json_converter
108
-
109
  sentence_model, json_converter = load_models()
110
 
111
- # ----------------------------- Helpers -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def extract_candidates(text: str) -> Dict[str, str]:
113
  """
114
- Very simple key:value candidate extraction from lines containing a colon.
115
- Extend with regexes for phones/emails/totals as needed.
 
 
116
  """
117
- out: Dict[str, str] = {}
 
 
118
  for raw in text.splitlines():
119
- line = raw.strip()
120
- if not line or ":" not in line:
121
  continue
122
- k, v = line.split(":", 1)
123
- k, v = k.strip(), v.strip()
124
- if k and v:
125
- out[k] = v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  return out
127
 
128
- def map_to_schema_headers(candidates: Dict[str, str],
129
- static_headers: List[str],
130
- thresh: float) -> Dict[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  if not candidates:
132
  return {}
133
  cand_keys = list(candidates.keys())
134
- # Embed all at once
135
- cand_emb = sentence_model.encode(cand_keys, normalize_embeddings=True)
136
- head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
137
- cos = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy() # [Nc, Nh]
138
  mapped: Dict[str, str] = {}
139
- for i, key in enumerate(cand_keys):
140
- j = int(np.argmax(cos[i]))
141
- score = float(cos[i][j])
142
- if score >= thresh:
143
- mapped[static_headers[j]] = candidates[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  return mapped
145
 
146
- def build_prompt(schema_text: str, invoice_text: str, mapped_hints: Dict[str, str]) -> str:
147
- """
148
- Construct a strong instruction prompt for MD2JSON.
149
- Includes your schema, strict rules, raw text, and optional hints from semantic mapping.
150
- """
151
- # Your strict instruction block:
152
  instruction = (
153
  'Use this schema:\n'
154
  '{\n'
@@ -223,84 +488,132 @@ def build_prompt(schema_text: str, invoice_text: str, mapped_hints: Dict[str, st
223
  'Do not invent fields. Do not add any header or shipment data to any line item. '
224
  'Return ONLY the JSON object, no explanation.\n'
225
  )
226
- hint_block = ""
227
  if mapped_hints:
228
- # Provide soft hints to help the T5 model place values correctly (optional but useful)
229
- hints_serialized = " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
230
- hint_block = f"\nHints:\n{hints_serialized}\n"
231
-
232
- return (
233
- instruction +
234
- "\nInvoice Text:\n" +
235
- invoice_text.strip() +
236
- hint_block
237
- )
238
 
239
- def strict_json_or_fallback(text: str, mapped: Dict[str, str]) -> Dict:
240
- """Try to parse model output as JSON. If invalid, return a schema-shaped fallback using mapped values."""
241
- # Try full parse
242
  try:
243
  return json.loads(text)
244
- except Exception:
245
  pass
246
- # Try to extract the biggest {...} block
247
  start = text.find("{")
248
  end = text.rfind("}")
249
  if start != -1 and end != -1 and end > start:
250
  try:
251
  return json.loads(text[start:end+1])
252
- except Exception:
253
  pass
254
- # Fallback: shape like schema and fill mapped header fields; empty line_items
255
- fallback = json.loads(json.dumps(SCHEMA_JSON)) # deep copy
256
- for k, v in mapped.items():
257
- if k in fallback["invoice_header"]:
258
- fallback["invoice_header"][k] = v
259
- # Leave line_items as a single template row with nulls
260
- return fallback
261
-
262
- # ----------------------------- UI: paste text -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  invoice_text = st.text_area(
264
- "Paste the invoice text here (unstructured).",
265
  height=320,
266
  placeholder="Paste the invoice content (OCR/plain text) ..."
267
  )
268
 
269
- colA, colB = st.columns([1,1])
270
-
271
- with colA:
272
- if st.button("Generate JSON", type="primary", use_container_width=True):
273
- if not invoice_text.strip():
274
- st.error("Please paste the invoice text first.")
275
- st.stop()
276
-
277
- with st.spinner("Extracting candidates..."):
278
- candidates = extract_candidates(invoice_text)
279
-
280
- with st.spinner("Semantic mapping with all-MiniLM-L6-v2..."):
281
- mapped = map_to_schema_headers(candidates, STATIC_HEADERS, threshold)
282
-
283
- if show_intermediates:
284
- st.subheader("Candidates (first 20)")
285
- st.json(dict(list(candidates.items())[:20]))
286
- st.subheader("Mapped headers")
287
- st.json(mapped)
288
-
289
- with st.spinner("Generating structured JSON with MD2JSON-T5-small-V1..."):
290
- prompt = build_prompt(json.dumps(SCHEMA_JSON, indent=2), invoice_text, mapped)
291
- gen = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
292
- out = gen(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
293
- final_json = strict_json_or_fallback(out, mapped)
294
-
295
- st.subheader("Final JSON")
296
- st.json(final_json)
297
- st.download_button("Download JSON", data=json.dumps(final_json, indent=2),
298
- file_name="invoice.json", mime="application/json", use_container_width=True)
299
-
300
- with colB:
301
- st.caption("Tips")
302
- st.markdown(
303
- "- You can tune the **semantic threshold** in the sidebar to make header mapping more/less strict.\n"
304
- "- If the T5 model outputs invalid JSON, the app returns a **schema-shaped fallback** filled from the mapped fields.\n"
305
- "- Extend `extract_candidates()` to add regexes for emails, phones, totals, etc., to boost recall."
306
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
+ # Invoice -> JSON (Paste Text Only) with better accuracy:
3
+ # - Pipe-table aware parsing
4
+ # - Regex extractors for common headers (Invoice No, Dates, PO, totals, taxes, GSTIN, etc.)
5
+ # - Line-item table parser (SNO, Description, Qty, UOM, Rate, Total Value)
6
+ # - Synonym dictionary -> canonical schema keys
7
+ # - Semantic mapping (MiniLM) for leftovers
8
+ # - MD2JSON prompt with strong hints; final schema = RULES MODEL (model cannot remove found values)
9
+
 
10
  import re
11
  import json
12
+ from typing import List, Dict, Any, Tuple
13
+ import copy
14
 
15
  import numpy as np
16
  import streamlit as st
 
18
  from transformers import pipeline
19
  from sentence_transformers import SentenceTransformer, util
20
 
21
+ st.set_page_config(page_title="Invoice → JSON (Paste Text) · Accurate v2", layout="wide")
22
+ st.title("Invoice → JSON (Paste Text) — Accurate v2")
23
 
24
+ # ----------------------------- Schema -----------------------------
25
+ SCHEMA_JSON: Dict[str, Any] = {
26
  "invoice_header": {
27
  "car_number": None,
28
  "shipment_number": None,
 
90
  }
91
  ]
92
  }
 
93
  STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
94
 
95
+ # ----------------------------- Sidebar -----------------------------
96
  st.sidebar.header("Settings")
97
  threshold = st.sidebar.slider("Semantic match threshold (cosine)", 0.0, 1.0, 0.60, 0.01)
98
  max_new_tokens = st.sidebar.slider("Max new tokens (MD2JSON)", 128, 2048, 512, 32)
99
+ show_intermediates = st.sidebar.checkbox("Show intermediates", value=True)
100
 
101
+ # ----------------------------- Models (cached) -----------------------------
102
  @st.cache_resource(show_spinner=True)
103
  def load_models():
104
  sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
105
  json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
106
  return sentence_model, json_converter
 
107
  sentence_model, json_converter = load_models()
108
 
109
+ # ----------------------------- Synonym map -> schema keys -----------------------------
110
# Maps lower-cased, punctuation-stripped candidate keys to canonical schema
# header names. Matching is by substring containment (see the synonym pass in
# semantic_map_candidates), so shorter entries like "bank" can also fire on
# longer keys — order the dict with the more specific phrases in mind.
SYN2KEY: Dict[str, str] = {
    # direct header synonyms
    "invoice no": "invoice_number",
    "invoice number": "invoice_number",
    "invoice#": "invoice_number",
    "inv no": "invoice_number",
    "inv#": "invoice_number",

    "invoice date": "invoice_date",
    "date of invoice": "invoice_date",

    "po no": "purchase_order_number",
    "po number": "purchase_order_number",
    "purchase order": "purchase_order_number",
    "order no": "order_number",
    "order number": "order_number",
    "sales order": "sales_order_number",
    "customer order": "customer_order_number",
    "our order": "our_order_number",

    "due date": "due_date",
    "date of supply": "order_date",

    "gstin": "supplier_tax_id",
    "gstin no": "supplier_tax_id",
    "tax id": "tax_id",
    "vat number": "vat_number",
    "tax registration number": "tax_registration_number",

    "place of supply": "shipping_point",
    "state code": "additional_info",  # keep if you prefer a specific field

    "taxable value": "total_before_tax",
    "total value": "total_due",
    "total amount": "total_due",
    "amount due": "total_due",

    "bank": "bank_account_number",  # value corrected later by bank-block parsing
    "account no": "bank_account_number",
    "account number": "bank_account_number",
    # NOTE(review): "ifs code" maps to swift_code here, while extract_bank_block
    # routes IFSC codes into payment_reference — confirm which is intended.
    "ifs code": "swift_code",  # India: really IFSC
    "ifsc": "payment_reference",
    "swift code": "swift_code",
    "iban": "iban",

    "e-way bill no": "reference_number",
    "eway bill": "reference_number",

    "dispatched via": "additional_info",
    "documents dispatched through": "additional_info",
    "kind attn": "contact_person",

    # parties
    "billed to": "bill_to_name",
    "receiver": "bill_to_name",
    "shipped to": "ship_to_name",
    "consignee": "ship_to_name",
}
168
+
169
+ # ----------------------------- Utilities -----------------------------
170
def norm(s: str) -> str:
    """Collapse every whitespace run in *s* to a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", s)
    return collapsed.strip()
172
+
173
def to_lower(s: str) -> str:
    """Lower-case *s*, then strip surrounding whitespace."""
    lowered = s.lower()
    return lowered.strip()
175
+
176
def deep_copy_schema() -> Dict[str, Any]:
    """Return a fresh deep copy of SCHEMA_JSON that is safe to mutate per request."""
    # copy.deepcopy avoids the serialize/parse round-trip of
    # json.loads(json.dumps(...)) while producing the same structure
    # (SCHEMA_JSON holds only dicts/lists/strings/None, all JSON-safe).
    return copy.deepcopy(SCHEMA_JSON)
178
+
179
+ # ----------------------------- Pipe-table aware candidate extractor -----------------------------
180
def extract_candidates(text: str) -> Dict[str, str]:
    """
    Build raw key/value candidates from the pasted invoice text.

    Sources:
      1) colon lines: "Key: Value"; when a line also contains "|" table cells,
         each cell is split and parsed for its own "Key: Value" pair;
      2) total lines that appear WITHOUT a colon, e.g. "Taxable Value 201801.60"
         (colon variants are already covered by pass 1).

    Returns {raw key: raw value}, both whitespace-normalized via norm().
    Later duplicates of the same key overwrite earlier ones.
    """
    cands: Dict[str, str] = {}

    # 1) colon lines (pipe-table aware)
    for raw in text.splitlines():
        line = raw.strip().strip("|").strip()
        if not line:
            continue
        if ":" in line:
            if "|" in raw:
                # Row contains table cells: parse every cell that has a colon.
                for cell in (p.strip() for p in raw.split("|") if p.strip()):
                    if ":" in cell:
                        k, v = cell.split(":", 1)
                        cands[norm(k)] = norm(v)
            else:
                k, v = line.split(":", 1)
                cands[norm(k)] = norm(v)

    # (Removed: a loop over colon-free "|" rows that only computed a local
    # `parts` list and never used it — it had no observable effect. Building
    # key:value pairs from header-row/data-row table pairs would need real
    # table alignment; that lives in parse_line_items.)

    # 2) totals without a colon, e.g. "Taxable Value 201801.60"
    for raw in text.splitlines():
        m = re.search(
            r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)",
            raw, re.I)
        if m:
            cands[norm(m.group(1))] = norm(m.group(2))

    return cands
222
+
223
+ # ----------------------------- Regex “hard extractors” -----------------------------
224
def regex_extract_all(text: str) -> Dict[str, str]:
    """
    Deterministic "hard" regex extractors for common invoice header fields
    (invoice number/date, PO, GSTIN, taxable value, CGST/SGST, e-way bill).

    Returns {schema_key: string value} for every pattern that matched; keys
    with no match are simply absent (never set to None).
    """
    out: Dict[str, str] = {}

    # Invoice number
    m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m: out["invoice_number"] = m.group(1)

    # Invoice date (DD-MM-YYYY or similar)
    m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m: out["invoice_date"] = m.group(1)

    # PO number + date
    m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m: out["purchase_order_number"] = m.group(1)
    m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m: out["order_date"] = m.group(1)

    # Date of Supply -> order_date (only if PO Date didn't already set it)
    if "order_date" not in out:
        m = re.search(r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
        if m: out["order_date"] = m.group(1)

    # Place of Supply -> shipping_point
    m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
    if m: out["shipping_point"] = m.group(1).strip(" |")

    # GSTIN (take the first 15-char id found)
    m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
    if m: out["supplier_tax_id"] = m.group(1)

    # Taxable Value -> total_before_tax (thousands separators stripped)
    m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if m: out["total_before_tax"] = m.group(1).replace(",", "")

    # CGST/SGST values -> tax_amount (sum); matching percentages -> combined tax_rate
    cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if cgst and sgst:
        try:
            tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
            out["tax_amount"] = f"{tax_total:.2f}"
            # Tax rate (if both % available, set the combined rate)
            cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            if cgstp and sgstp:
                try:
                    rate = float(cgstp.group(1)) + float(sgstp.group(1))
                    out["tax_rate"] = f"{rate:g}"
                except ValueError:  # narrowed from bare except: only float() can raise here
                    pass
        except ValueError:  # narrowed from bare except: only float() can raise here
            pass

    # E-Way bill -> reference_number
    m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
    if m: out["reference_number"] = m.group(1).strip()

    return out
282
 
283
+ # ----------------------------- Bank block parsing -----------------------------
284
def extract_bank_block(text: str) -> Dict[str, str]:
    """
    Pull bank-payment details out of the invoice text.

    Returns a dict keyed by schema header names: supplier_name (account name),
    bank_account_number, payment_reference (IFSC), swift_code, plus an
    additional_info string collecting bank name / branch / MICR.
    """
    details: Dict[str, str] = {}

    def search(pattern: str):
        # Case-insensitive search over the full text.
        return re.search(pattern, text, re.I)

    acct_name = search(r"\bAccount\s*Name\s*:\s*(.+)")
    if acct_name:
        details["supplier_name"] = acct_name.group(1).strip()

    acct_no = search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)")
    if acct_no:
        details["bank_account_number"] = acct_no.group(1).strip()

    bank_name = search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)")
    if bank_name:
        # Bank *name* goes to additional_info so it can't clobber the number.
        details["additional_info"] = "Bank: " + bank_name.group(1).strip()

    ifsc = search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)")
    if ifsc:
        details["payment_reference"] = ifsc.group(1).strip()

    swift = search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)")
    if swift:
        details["swift_code"] = swift.group(1).strip()

    # Branch / MICR are informational only -> appended to additional_info.
    trailing = []
    branch = search(r"\bBranch\s*:\s*(.+)")
    if branch:
        trailing.append("Branch: " + branch.group(1).strip())
    micr = search(r"\bMICR\s*Code\s*:\s*([0-9]+)")
    if micr:
        trailing.append("MICR: " + micr.group(1).strip())
    if trailing:
        prefix = details.get("additional_info")
        details["additional_info"] = (prefix + " | " if prefix else "") + " | ".join(trailing)
    return details
317
+
318
+ # ----------------------------- Line-item parser (from table) -----------------------------
319
def parse_line_items(text: str) -> List[Dict[str, Any]]:
    """
    Parse a classic pipe-delimited item table such as:

        | SNO | Description | HSN/SAC | Qty | UOM | Rate | ... | Total Value |

    Returns a list of schema-shaped rows (quantity/units/description/footage/
    price/amount/notes), or [] when no recognizable header row is found.

    NOTE(review): header cells and body cells are filtered independently
    (blank and all-dash cells dropped), so indices can mis-align on rows with
    empty cells — verify on real samples.
    """
    items: List[Dict[str, Any]] = []
    lines = [ln for ln in text.splitlines() if ln.strip()]

    # Locate the header row by required column keywords.
    header_idx = -1
    for i, ln in enumerate(lines):
        if ("|" in ln and "Description" in ln and ("Qty" in ln or "QTY" in ln)
                and ("Rate" in ln or "Price" in ln) and "Total" in ln):
            header_idx = i
            break
    if header_idx == -1:
        return items

    # Header cells, lower-cased; drop empties and pure-dash separator cells.
    headers = [c.strip().lower() for c in lines[header_idx].split("|")]
    headers = [h for h in headers if h and set(h) - set("-")]

    def idx_of(name_parts: List[str]) -> int:
        """Index of the first header containing any of the substrings, else -1."""
        for k, h in enumerate(headers):
            if any(p in h for p in name_parts):
                return k
        return -1

    def cell_at(cells: List[str], i: int) -> str:
        """Cell i of a row, or "" when the index is out of range."""
        return cells[i] if 0 <= i < len(cells) else ""

    # Column positions are loop-invariant: resolve them ONCE here instead of
    # redefining/recomputing them for every body row as the original did.
    i_desc = idx_of(["description", "item"])
    i_qty = idx_of(["qty", "quantity"])
    i_uom = idx_of(["uom", "unit"])
    i_rate = idx_of(["rate", "price"])
    i_amt = idx_of(["total value", "amount", "total"])

    # Walk body rows until the first non-table line after the header
    # (a non-table line immediately after the header is tolerated).
    for j in range(header_idx + 1, len(lines)):
        row = lines[j]
        if row.strip().startswith("|") and row.count("|") >= 2:
            cells = [c.strip() for c in row.split("|")]
            cells = [c for c in cells if c and set(c) - set("-")]
            if len(cells) < 3:
                continue  # separator row or too sparse to be a real item
            rowd: Dict[str, Any] = {"quantity": None, "units": None, "description": None,
                                    "footage": None, "price": None, "amount": None, "notes": None}
            if i_desc != -1: rowd["description"] = cell_at(cells, i_desc) or None
            if i_qty != -1: rowd["quantity"] = cell_at(cells, i_qty) or None
            if i_uom != -1: rowd["units"] = cell_at(cells, i_uom) or None
            if i_rate != -1: rowd["price"] = cell_at(cells, i_rate) or None
            if i_amt != -1: rowd["amount"] = cell_at(cells, i_amt) or None
            # Convenience: combine qty + unit into the schema's "footage" slot.
            if rowd["units"] and rowd["quantity"]:
                rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
            items.append(rowd)
        else:
            # Stop at the first non-table line after the header region.
            if j > header_idx + 1:
                break
    return items
381
+
382
+ # ----------------------------- Semantic mapping for leftovers -----------------------------
383
def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float) -> Dict[str, str]:
    """
    Map extracted key/value candidates onto schema header names.

    Two passes:
      1) synonym pass — a normalized candidate key containing a SYN2KEY entry
         maps directly to that canonical schema key;
      2) semantic pass — leftovers are embedded with MiniLM and assigned to
         the closest schema header when cosine similarity >= *thresh*.

    Returns {schema_header: value}; a later match for the same schema key
    overwrites an earlier one.
    """
    if not candidates:
        return {}
    mapped: Dict[str, str] = {}
    leftovers: Dict[str, str] = {}

    # Pass 1: cheap deterministic synonym lookup (substring containment on a
    # lower-cased, punctuation-stripped key). The original also built an
    # unused cand_keys list here — removed.
    for k, v in candidates.items():
        lk_norm = re.sub(r"[^a-z0-9]+", " ", k.lower()).strip()
        hit = None
        for syn, key in SYN2KEY.items():
            if syn in lk_norm:
                hit = key
                break
        if hit:
            mapped[hit] = v
        else:
            leftovers[k] = v

    # Pass 2: semantic fallback via MiniLM embeddings for unmapped keys.
    if leftovers:
        cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
        head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
        M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
        for i, ck in enumerate(leftovers.keys()):
            j = int(np.argmax(M[i]))
            score = float(M[i][j])
            if score >= thresh:
                mapped[static_headers[j]] = leftovers[ck]
    return mapped
414
 
415
+ # ----------------------------- Build MD2JSON prompt -----------------------------
416
+ def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
 
 
 
 
417
  instruction = (
418
  'Use this schema:\n'
419
  '{\n'
 
488
  'Do not invent fields. Do not add any header or shipment data to any line item. '
489
  'Return ONLY the JSON object, no explanation.\n'
490
  )
491
+ hints = ""
492
  if mapped_hints:
493
+ hints += "\nHints (header):\n" + " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
494
+ if items_hints:
495
+ try:
496
+ hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
497
+ except:
498
+ pass
499
+
500
+ return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
 
 
501
 
502
def strict_json(text: str) -> Dict[str, Any]:
    """
    Parse the model's output as JSON, or raise ValueError.

    Tries a direct json.loads first; on failure retries on the widest {...}
    span in the text (models often wrap the JSON in extra prose).

    Raises:
        ValueError: if no parsable JSON object can be found.
    """
    # Direct parse.
    try:
        return json.loads(text)
    except json.JSONDecodeError:  # narrowed from bare except
        pass
    # Retry on the largest brace-delimited substring.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start:end+1])
        except json.JSONDecodeError:  # narrowed from bare except
            pass
    raise ValueError("Model did not return valid JSON.")
517
+
518
+ # ----------------------------- Final merge policy -----------------------------
519
def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge the deterministic (rule-based) JSON with the model's JSON.

    Policy: RULES WIN — deterministically extracted values are never
    overwritten; the model may only fill header slots that are still empty,
    and may only supply line_items when the rules produced none.
    """
    EMPTY = (None, "", "null")
    merged = copy.deepcopy(rule_json)

    # Header: fill only still-empty slots from the model output.
    header = merged["invoice_header"]
    model_header = model_json.get("invoice_header") or {}
    for field in header.keys():
        if header[field] in EMPTY:
            candidate = model_header.get(field, None)
            if candidate not in EMPTY:
                header[field] = candidate

    # Line items: keep rule-parsed rows when at least one cell is populated;
    # otherwise fall back to the model's list (or leave the null template).
    rows = merged["line_items"]
    rules_have_items = bool(rows) and any(
        any(cell for cell in row.values() if cell not in EMPTY) for row in rows
    )
    if not rules_have_items:
        model_items = model_json.get("line_items")
        if isinstance(model_items, list) and model_items:
            merged["line_items"] = model_items

    return merged
546
+
547
# ----------------------------- UI -----------------------------
# Main input: raw invoice text pasted by the user (no OCR performed here).
invoice_text = st.text_area(
    "Paste the invoice text here.",
    height=320,
    placeholder="Paste the invoice content (OCR/plain text) ..."
)

if st.button("Generate JSON", type="primary", use_container_width=True):
    if not invoice_text.strip():
        st.error("Please paste the invoice text first.")
        st.stop()

    txt = invoice_text

    # 1) Deterministic extraction
    # 1a) candidates (pipe-table aware)
    candidates = extract_candidates(txt)

    # 1b) regex “hard” fields
    hard = regex_extract_all(txt)

    # 1c) bank block
    bank = extract_bank_block(txt)

    # 1d) line items from table
    items = parse_line_items(txt)

    # 1e) map candidates (synonyms + semantic) to schema headers
    sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold)

    # 1f) combine deterministic header fields
    # NOTE: update order defines precedence — bank fields override regex hits,
    # which override semantic matches, because dict.update overwrites.
    header_found: Dict[str, Any] = {}
    header_found.update(sem_mapped)
    header_found.update(hard)
    header_found.update(bank)

    # 2) Build RULE JSON (schema-shaped, rules filled)
    rule_json = deep_copy_schema()
    for k, v in header_found.items():
        if k in rule_json["invoice_header"]:
            rule_json["invoice_header"][k] = v
    # line items
    if items:
        rule_json["line_items"] = items

    if show_intermediates:
        # Debug panels: show each extraction stage's raw output.
        st.subheader("Candidates (first 20)")
        st.json(dict(list(candidates.items())[:20]))
        st.subheader("Regex/Hard fields")
        st.json(hard)
        st.subheader("Bank block")
        st.json(bank)
        st.subheader("Semantic-mapped headers")
        st.json(sem_mapped)
        st.subheader("Line items (parsed)")
        st.json(items)

    # 3) MD2JSON generation with strong hints
    with st.spinner("Generating structured JSON with MD2JSON-T5-small-V1..."):
        prompt = build_prompt(txt, header_found, items)
        gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
        # NOTE(review): bare except — consider narrowing to ValueError, which
        # is what strict_json raises on unparsable model output.
        try:
            model_json = strict_json(gen)
        except:
            model_json = deep_copy_schema()  # model failed; keep empty shape

    # 4) Final merge (rules win)
    final_json = merge_schema(rule_json, model_json)

    st.subheader("Final JSON")
    st.json(final_json)
    st.download_button("Download JSON", data=json.dumps(final_json, indent=2),
                       file_name="invoice.json", mime="application/json", use_container_width=True)