# app.py
# Invoice -> JSON (Paste Text Only)
#
# Pipeline:
#   1) User pastes invoice text into a textbox (no OCR calls).
#   2) Extract naive "key: value" candidates.
#   3) Map candidates to the static schema headers with all-MiniLM-L6-v2.
#   4) Ask MD2JSON-T5-small-V1 to emit STRICT JSON using the schema (nulls if missing).
#
# If the model output contains no parseable JSON object, we fall back to a
# schema-shaped object filled from the mapped data (null where missing).

import copy
import os
import re
import json
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import streamlit as st
import torch
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

st.set_page_config(page_title="Invoice → JSON (Paste Text) · MiniLM + MD2JSON", layout="wide")
st.title("Invoice → JSON (Paste Text Only)")

# ----------------------------- Your schema (fixed) -----------------------------
# Every leaf is None; the fallback path deep-copies this and fills header fields.
SCHEMA_JSON = {
    "invoice_header": dict.fromkeys([
        "car_number", "shipment_number", "shipping_point", "currency",
        "invoice_number", "invoice_date", "order_number", "customer_order_number",
        "our_order_number", "sales_order_number", "purchase_order_number", "order_date",
        "supplier_name", "supplier_address", "supplier_phone", "supplier_email",
        "supplier_tax_id", "customer_name", "customer_address", "customer_phone",
        "customer_email", "customer_tax_id", "ship_to_name", "ship_to_address",
        "bill_to_name", "bill_to_address", "remit_to_name", "remit_to_address",
        "tax_id", "tax_registration_number", "vat_number", "payment_terms",
        "payment_method", "payment_reference", "bank_account_number", "iban",
        "swift_code", "total_before_tax", "tax_amount", "tax_rate",
        "shipping_charges", "discount", "total_due", "amount_paid",
        "balance_due", "due_date", "invoice_status", "reference_number",
        "project_code", "department", "contact_person", "notes",
        "additional_info",
    ]),
    "line_items": [
        dict.fromkeys([
            "quantity", "units", "description", "footage", "price", "amount", "notes",
        ]),
    ],
}

STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())


def _typed_schema(node: Any) -> Any:
    """Return a copy of *node* with every leaf replaced by the string "string or null".

    Dicts and lists keep their keys, order and nesting, so the rendered prompt schema
    always mirrors SCHEMA_JSON.
    """
    if isinstance(node, dict):
        return {key: _typed_schema(value) for key, value in node.items()}
    if isinstance(node, list):
        return [_typed_schema(item) for item in node]
    return "string or null"


# Schema text shown to the model. It is derived from SCHEMA_JSON so adding or renaming
# a field in one place cannot leave the prompt out of sync.
SCHEMA_PROMPT_TEXT: str = json.dumps(_typed_schema(SCHEMA_JSON), indent=2)

# ----------------------------- Sidebar controls -----------------------------
st.sidebar.header("Settings")
threshold = st.sidebar.slider("Semantic match threshold (cosine)", 0.0, 1.0, 0.60, 0.01)
max_new_tokens = st.sidebar.slider("Max new tokens (MD2JSON)", 128, 2048, 512, 32)
show_intermediates = st.sidebar.checkbox("Show intermediate mapping/hints", value=False)


# ----------------------------- Model loaders (cached) -----------------------------
@st.cache_resource(show_spinner=True)
def load_models():
    """Load the sentence encoder and the MD2JSON pipeline once (cached by Streamlit).

    Returns:
        Tuple ``(sentence_model, json_converter)``: the all-MiniLM-L6-v2
        SentenceTransformer and a ``text2text-generation`` pipeline for
        MD2JSON-T5-small-V1.
    """
    sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
    return sentence_model, json_converter


sentence_model, json_converter = load_models()


# ----------------------------- Helpers -----------------------------
def extract_candidates(text: str) -> Dict[str, str]:
    """Extract naive ``key: value`` pairs from lines that contain a colon.

    Each line is split on its FIRST colon, and both sides are stripped. Lines with an
    empty key or empty value are skipped. If a key repeats, the last occurrence wins.
    Extend with regexes for phones/emails/totals as needed.

    Args:
        text: Raw pasted invoice text.

    Returns:
        Mapping of raw key text to raw value text.
    """
    out: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if ":" not in line:
            continue
        key, _, value = line.partition(":")
        key, value = key.strip(), value.strip()
        if key and value:
            out[key] = value
    return out


def map_to_schema_headers(
    candidates: Dict[str, str], static_headers: List[str], thresh: float
) -> Dict[str, str]:
    """Map raw candidate keys onto schema headers by MiniLM cosine similarity.

    Each candidate key is assigned to its single most similar header. The assignment
    is kept only if the cosine score is >= ``thresh``. When several candidates land
    on the same header, the highest-scoring one wins; on equal scores the earlier
    candidate wins.

    Args:
        candidates: Raw ``key -> value`` pairs (e.g. from ``extract_candidates``).
        static_headers: Schema header names to match against.
        thresh: Minimum cosine similarity for a match to be accepted.

    Returns:
        Mapping of schema header name to the candidate value assigned to it.
    """
    if not candidates or not static_headers:
        return {}

    cand_keys = list(candidates.keys())
    # Embed everything in two batched calls; normalized vectors make cosine well-scaled.
    cand_emb = sentence_model.encode(cand_keys, normalize_embeddings=True)
    head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
    cos = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()  # [Nc, Nh]

    best: Dict[str, Tuple[float, str]] = {}
    for i, key in enumerate(cand_keys):
        j = int(np.argmax(cos[i]))
        score = float(cos[i][j])
        if score < thresh:
            continue
        header = static_headers[j]
        # Keep the strongest match per header instead of letting later keys overwrite it.
        if header not in best or score > best[header][0]:
            best[header] = (score, candidates[key])
    return {header: value for header, (_, value) in best.items()}


_PROMPT_RULES = (
    'If a field is missing for a line item or header, use null. '
    'Do not invent fields. Do not add any header or shipment data to any line item. '
    'Return ONLY the JSON object, no explanation.\n'
)


def build_prompt(schema_text: str, invoice_text: str, mapped_hints: Dict[str, str]) -> str:
    """Construct the instruction prompt for MD2JSON.

    The prompt contains:
      - the typed schema (``SCHEMA_PROMPT_TEXT``, derived from SCHEMA_JSON);
      - the strict output rules;
      - the stripped invoice text;
      - optionally, a "Hints" line of ``#header: value`` pairs from semantic mapping.

    Args:
        schema_text: Accepted for backward compatibility but not used. The prompt
            always embeds ``SCHEMA_PROMPT_TEXT``, which marks each field "string or null".
        invoice_text: Raw pasted invoice text.
        mapped_hints: Header -> value pairs from ``map_to_schema_headers``.

    Returns:
        The full prompt string.
    """
    instruction = "Use this schema:\n" + SCHEMA_PROMPT_TEXT + "\n" + _PROMPT_RULES

    hint_block = ""
    if mapped_hints:
        # Soft hints to help the T5 model place values correctly.
        hints_serialized = " ".join(f"#{k}: {v}" for k, v in mapped_hints.items())
        hint_block = f"\nHints:\n{hints_serialized}\n"

    return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hint_block


def _first_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object found in *text*, or None if there is none.

    First tries to parse the whole text. If that fails, or the result is not a JSON
    object, decodes starting at each ``{`` in turn so stray prose or braces around
    the object do not prevent recovery.
    """
    try:
        whole = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(whole, dict):
            return whole

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    return None


def strict_json_or_fallback(text: str, mapped: Dict[str, str]) -> Dict:
    """Parse model output as a JSON object, or build a schema-shaped fallback.

    Args:
        text: Raw generated text from the MD2JSON pipeline.
        mapped: Header -> value pairs from semantic mapping, used only for the fallback.

    Returns:
        The first JSON object found in ``text``. If there is none, returns a deep copy
        of SCHEMA_JSON with the header fields present in ``mapped`` filled in. All other
        fields stay null, and ``line_items`` is the single all-null template row.
    """
    parsed = _first_json_object(text)
    if parsed is not None:
        return parsed

    fallback = copy.deepcopy(SCHEMA_JSON)
    header = fallback["invoice_header"]
    for key, value in mapped.items():
        if key in header:
            header[key] = value
    return fallback


# ----------------------------- UI: paste text -----------------------------
invoice_text = st.text_area(
    "Paste the invoice text here (unstructured).",
    height=320,
    placeholder="Paste the invoice content (OCR/plain text) ...",
)

colA, colB = st.columns([1, 1])

with colA:
    if st.button("Generate JSON", type="primary", use_container_width=True):
        if not invoice_text.strip():
            st.error("Please paste the invoice text first.")
            st.stop()

        with st.spinner("Extracting candidates..."):
            candidates = extract_candidates(invoice_text)

        with st.spinner("Semantic mapping with all-MiniLM-L6-v2..."):
            mapped = map_to_schema_headers(candidates, STATIC_HEADERS, threshold)

        if show_intermediates:
            st.subheader("Candidates (first 20)")
            st.json(dict(list(candidates.items())[:20]))
            st.subheader("Mapped headers")
            st.json(mapped)

        with st.spinner("Generating structured JSON with MD2JSON-T5-small-V1..."):
            prompt = build_prompt(json.dumps(SCHEMA_JSON, indent=2), invoice_text, mapped)
            # Reuse the cached pipeline from load_models(); building a new one here
            # reloaded the model weights on every click.
            out = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]

        final_json = strict_json_or_fallback(out, mapped)

        st.subheader("Final JSON")
        st.json(final_json)
        st.download_button(
            "Download JSON",
            data=json.dumps(final_json, indent=2),
            file_name="invoice.json",
            mime="application/json",
            use_container_width=True,
        )
with colB:
    # Right-hand column: short usage notes shown next to the main action column.
    st.caption("Tips")
    tip_lines = [
        "- You can tune the **semantic threshold** in the sidebar to make header mapping more/less strict.",
        "- If the T5 model outputs invalid JSON, the app returns a **schema-shaped fallback** filled from the mapped fields.",
        "- Extend `extract_candidates()` to add regexes for emails, phones, totals, etc., to boost recall.",
    ]
    st.markdown("\n".join(tip_lines))