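"""Streamlit app: invoice extraction.

A public Donut checkpoint (Hugging Face) extracts key fields; Tesseract
word boxes drive a simple line-item table heuristic. PDFs are rasterized
with poppler/pdf2image, falling back to PyMuPDF when poppler is missing.
"""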
import os, io, re, json, shutil
from typing import List

import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageFilter

import streamlit as st

import pytesseract
from pytesseract import Output
from pdf2image import convert_from_bytes

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

st.set_page_config(
    page_title="Invoice Extraction — Donut (public) + Tesseract tables",
    layout="wide",
)

device = "cuda" if torch.cuda.is_available() else "cpu"

st.sidebar.header("Model (Hugging Face — public)")
model_id = st.sidebar.text_input(
    "HF model id",
    value="naver-clova-ix/donut-base-finetuned-cord-v2",
    help="Use a public model id; this one works without a token.",
)
task_prompt = st.sidebar.text_input(
    "Task prompt (Donut)",
    value="<s_cord-v2>",
    help="Keep the default for CORD-style invoices.",
)
det_lang = st.sidebar.text_input("Tesseract language(s) — detection only", value="eng")
show_boxes = st.sidebar.checkbox("Show word boxes (debug)", value=False)
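
# --- PDF rasterization helpers ---------------------------------------------
# pdf2image needs the poppler binaries (pdfinfo/pdftoppm). Check PATH first,
# then a few common install prefixes; if poppler is unusable, load_pages()
# falls back to PyMuPDF.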
def _find_poppler_path():
    if shutil.which("pdfinfo") and shutil.which("pdftoppm"):
        return None  # already on PATH; pdf2image will find the binaries itself
    for p in ["/usr/bin", "/usr/local/bin", "/usr/share/bin"]:
        if os.path.exists(os.path.join(p, "pdfinfo")) and os.path.exists(os.path.join(p, "pdftoppm")):
            return p
    return None


def _pages_via_pdf2image(file_bytes: bytes) -> List[Image.Image]:
    poppler_path = _find_poppler_path()
    if poppler_path:
        return convert_from_bytes(file_bytes, dpi=300, poppler_path=poppler_path)
    return convert_from_bytes(file_bytes, dpi=300)


def _pages_via_pymupdf(file_bytes: bytes) -> List[Image.Image]:
    import fitz  # PyMuPDF; imported lazily since it is only needed as a fallback

    doc = fitz.open(stream=file_bytes, filetype="pdf")
    pages = []
    # Render at ~300 dpi; PyMuPDF defaults to 72 dpi, which is too coarse for
    # OCR and would not match the pdf2image path above.
    zoom = fitz.Matrix(300 / 72, 300 / 72)
    for page in doc:
        pix = page.get_pixmap(matrix=zoom)
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        pages.append(img)
    return pages


def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
    name = (name or "").lower()
    if name.endswith(".pdf"):
        try:
            return _pages_via_pdf2image(file_bytes)
        except Exception:
            # poppler missing or rasterization failed; try PyMuPDF instead
            return _pages_via_pymupdf(file_bytes)
    return [Image.open(io.BytesIO(file_bytes)).convert("RGB")]


def preprocess_for_detection(img: Image.Image) -> Image.Image:
    g = ImageOps.grayscale(img)
    g = ImageOps.autocontrast(g)
    g = g.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))
    return g
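

# --- Donut (OCR-free key-field extraction) ----------------------------------
# The checkpoint is cached across Streamlit reruns; inference follows the
# pattern documented on the donut-base-finetuned-cord-v2 model card.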
@st.cache_resource(show_spinner=True)
def load_donut(model_id: str):
    # The parameter must not start with an underscore: st.cache_resource skips
    # underscore-prefixed arguments when hashing, so a changed sidebar model id
    # would otherwise keep returning the previously cached model.
    processor = DonutProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    model.to(device).eval()
    return processor, model


def donut_infer(img: Image.Image, processor: DonutProcessor, model: VisionEncoderDecoderModel, prompt: str):
    # The image goes through the processor; the task prompt goes through the
    # tokenizer as decoder_input_ids (the processor's text= argument is for
    # training labels, not for seeding generation).
    pixel_values = processor(img, return_tensors="pt").pixel_values.to(device)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=1024,
            num_beams=1,
            use_cache=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
    # Decode without skip_special_tokens: Donut's field tags are special tokens
    # and token2json needs them. Strip eos/pad and the leading task token.
    seq = processor.batch_decode(outputs)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    try:
        parsed = processor.token2json(seq)
    except Exception:
        parsed = None
    # token2json returns {"text_sequence": ...} when it finds no structure;
    # treat that as unparsed so the regex fallback kicks in.
    if not isinstance(parsed, dict) or not parsed or set(parsed) == {"text_sequence"}:
        parsed = None
    return seq, parsed
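

# --- Regex fallback heuristics ------------------------------------------------
# Used only when Donut returns nothing parseable. Example (hypothetical input):
#   "Invoice No: INV-2024-0042"  -> inv="INV-2024-0042"
#   "Grand Total $1,234.50"      -> curr="$", amt="1,234.50"
# MONEY assumes comma thousands separators and a dot decimal point.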
CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
MONEY = rf"{CURRENCY}\s?(?P<amt>\d{{1,3}}(?:[,]\d{{3}})*(?:[.]\d{{2}})?)"
DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
INV_PAT = r"(?:invoice\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>[A-Z0-9\-_/]{4,}))"
PO_PAT = r"(?:po\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>[A-Z0-9\-_/]{3,}))"
TOTAL_PAT = rf"(?:\b(?:total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"
SUBTOTAL_PAT = rf"(?:\bsub\s*total\b.*?{MONEY})"
TAX_PAT = rf"(?:\b(?:tax|gst|vat|hst)\b.*?{MONEY})"


def parse_fields_regex(fulltext: str):
    t = re.sub(r"[ \t]+", " ", fulltext)
    t = re.sub(r"\n{2,}", "\n", t)
    out = {
        "invoice_number": None, "invoice_date": None, "po_number": None,
        "subtotal": None, "tax": None, "total": None, "currency": None,
    }

    m = re.search(INV_PAT, t, re.I)
    out["invoice_number"] = m.group("inv") if m else None
    m = re.search(PO_PAT, t, re.I)
    out["po_number"] = m.group("po") if m else None

    # Prefer a date labelled "invoice date"; otherwise the first date-like string.
    m = re.search(rf"(?:invoice\s*date[:\-\s]*){DATE}", t, re.I) or re.search(DATE, t, re.I)
    out["invoice_date"] = m.group("date") if m else None

    for field, pat in [("subtotal", SUBTOTAL_PAT), ("tax", TAX_PAT), ("total", TOTAL_PAT)]:
        m = re.search(pat, t, re.I | re.S)
        if m:
            out[field] = m.group("amt").replace(",", "")
            out["currency"] = m.group("curr") or out["currency"]

    if out["currency"] in ["$", "C$", "€", "£"]:
        out["currency"] = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP"}[out["currency"]]
    return out
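

# --- Donut output normalization -------------------------------------------------
# CORD-style output is nested JSON, e.g. (hypothetical)
#   {"menu": [{"nm": "Widget", "price": "9.99"}], "total": {"total_price": "9.99"}}
# search_keys() walks dicts/lists depth-first and returns the first string value
# whose key contains one of the target substrings.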
def normalize_kv_from_donut(parsed: dict):
    out = {k: None for k in ["invoice_number", "invoice_date", "po_number", "subtotal", "tax", "total", "currency"]}

    def search_keys(obj, key_list):
        if isinstance(obj, dict):
            for k, v in obj.items():
                kl = k.lower()
                if any(kk in kl for kk in key_list) and isinstance(v, str):
                    return v
                # Recurse even when a matching key holds a nested value
                # (e.g. {"total": {"total_price": ...}}), instead of bailing out.
                found = search_keys(v, key_list)
                if found is not None:
                    return found
        elif isinstance(obj, list):
            for it in obj:
                found = search_keys(it, key_list)
                if found is not None:
                    return found
        return None

    mapping = {
        "invoice_number": ["invoice_number", "invoice no", "invoice_no", "invoice", "inv_no"],
        "invoice_date": ["invoice_date", "date", "bill_date"],
        "po_number": ["po_number", "po", "purchase_order"],
        "subtotal": ["subtotal", "sub_total"],
        "tax": ["tax", "gst", "vat", "hst"],
        "total": ["total", "amount_total", "amount_due", "grand_total"],
    }
    for k, keys in mapping.items():
        val = search_keys(parsed, keys)
        if isinstance(val, str):
            out[k] = val.strip()

    # Currency: look for a code or symbol anywhere in the serialized payload.
    txt = json.dumps(parsed, ensure_ascii=False)
    m = re.search(r"(USD|CAD|EUR|GBP|\$|C\$|€|£)", txt, re.I)
    if m:
        sym = m.group(1)
        out["currency"] = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP"}.get(sym, sym.upper())
    return out
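

# --- Line-item table heuristic ---------------------------------------------------
# Group Tesseract words into lines, pick the line that looks most like a table
# header (>= 2 header keywords), use the header words' left edges as column
# anchors, then bucket every word below into the nearest column until a totals
# line appears. Expected image_to_data frame shape (hypothetical values):
#   block_num par_num line_num left top width height conf text
#           2       1        1  120 310   180     22 96.1 Description
#           2       1        1  520 310    60     22 94.8 Qty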
def items_from_words_simple(tsv: pd.DataFrame) -> pd.DataFrame:
    HEAD_CANDIDATES = ["description", "item", "qty", "quantity", "price", "unit", "rate", "amount", "total"]
    if tsv.empty:
        return pd.DataFrame()

    # Group words into visual lines and score each line by header keywords.
    lines = []
    for (b, p, l), g in tsv.groupby(["block_num", "par_num", "line_num"]):
        text = " ".join([w for w in g["text"].astype(str).tolist() if w.strip()])
        if text.strip():
            lines.append({
                "block_num": b, "par_num": p, "line_num": l,
                "text": text.lower(),
                "top": g["top"].min(), "bottom": (g["top"] + g["height"]).max(),
                "left": g["left"].min(), "right": (g["left"] + g["width"]).max(),
            })
    L = pd.DataFrame(lines)
    if L.empty:
        return pd.DataFrame()
    L["header_score"] = L["text"].apply(lambda s: sum(1 for h in HEAD_CANDIDATES if h in s))
    hdrs = L[L["header_score"] >= 2].sort_values(["header_score", "top"], ascending=[False, True])
    if hdrs.empty:
        return pd.DataFrame()

    H = hdrs.iloc[0]
    header_top, header_bottom = H["top"], H["bottom"]

    # The header words' left edges become the column anchors.
    header_words = tsv[(tsv["top"] >= header_top - 5) & ((tsv["top"] + tsv["height"]) <= header_bottom + 5)]
    header_words = header_words.sort_values("left")
    if header_words.empty:
        return pd.DataFrame()
    xs = header_words["left"].tolist()
    hdr_tokens = [t.lower() for t in header_words["text"].tolist()]

    # Keep only words below the header; stop at the first totals/balance line.
    below = tsv[tsv["top"] > header_bottom + 5].copy()
    totals_mask = below["text"].str.lower().str.contains(
        r"(?:sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False
    )
    if totals_mask.any():
        stop_y = below.loc[totals_mask, "top"].min()
        below = below[below["top"] < stop_y - 4]
    if below.empty:
        return pd.DataFrame()

    # Bucket each word into the nearest column anchor, line by line.
    rows = []
    for (b, p, l), g in below.groupby(["block_num", "par_num", "line_num"]):
        g = g.sort_values("left")
        buckets = {i: [] for i in range(len(xs))}
        for _, w in g.iterrows():
            if not str(w["text"]).strip():
                continue
            idx = int(np.abs(np.array(xs) - w["left"]).argmin())
            buckets[idx].append(str(w["text"]))
        vals = [" ".join(buckets[i]).strip() for i in range(len(xs))]
        if any(vals):
            rows.append(vals)
    if not rows:
        return pd.DataFrame()

    # Name columns from the header tokens, falling back to col_<i>.
    df_rows = pd.DataFrame(rows).fillna("")
    names = []
    for i in range(df_rows.shape[1]):
        wl = hdr_tokens[i] if i < len(hdr_tokens) else f"col_{i}"
        if "desc" in wl or wl in ["item", "description"]:
            names.append("description")
        elif wl in ["qty", "quantity"]:
            names.append("quantity")
        elif "unit" in wl or "rate" in wl or "price" in wl:
            names.append("unit_price")
        elif "amount" in wl or "total" in wl:
            names.append("line_total")
        else:
            names.append(f"col_{i}")
    df_rows.columns = names
    # Drop rows whose cells are all empty.
    df_rows = df_rows[~(df_rows.fillna("").apply(lambda r: "".join(r.values), axis=1).str.strip() == "")]
    return df_rows.reset_index(drop=True)
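

# --- Main UI -----------------------------------------------------------------------
# Left column: page preview plus the preprocessed detection view.
# Right column: Donut key fields, Tesseract line items, JSON/CSV downloads.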
st.title("Invoice Extraction — Donut (public) + Tesseract tables")

up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf", "png", "jpg", "jpeg"])
if not up:
    st.info("Upload a scanned invoice to begin.")
    st.stop()

with st.spinner(f"Loading model '{model_id}' from Hugging Face…"):
    processor, donut_model = load_donut(model_id)

pages = load_pages(up.read(), up.name)
page_idx = 0
if len(pages) > 1:
    page_idx = st.number_input("Page", 1, len(pages), 1) - 1
img = pages[page_idx]

col1, col2 = st.columns([1.1, 1.3], gap="large")

with col1:
    st.subheader("Preview")
    st.image(img, use_column_width=True)
    det_img = preprocess_for_detection(img)
    with st.expander("Detection view (preprocessed for boxes)"):
        st.image(det_img, use_column_width=True)

with col2:
    st.subheader("OCR & Extraction")

    with st.spinner("Running Donut…"):
        seq, parsed = donut_infer(img, processor, donut_model, task_prompt)

    # Prefer Donut's structured output; fall back to regex over the raw text.
    if parsed:
        key_fields = normalize_kv_from_donut(parsed)
        donut_payload = parsed
    else:
        key_fields = parse_fields_regex(seq)
        donut_payload = {"generated_text": seq}

    k1, k2, k3 = st.columns(3)
    with k1:
        st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
        st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or '—'}")
    with k2:
        st.write(f"**PO #:** {key_fields.get('po_number') or '—'}")
        st.write(f"**Subtotal:** {key_fields.get('subtotal') or '—'}")
    with k3:
        st.write(f"**Tax:** {key_fields.get('tax') or '—'}")
        tot = key_fields.get('total') or '—'
        cur = key_fields.get('currency') or ''
        st.write(f"**Total:** {tot} {cur}".strip())
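
    # Tesseract supplies word-level geometry only (image_to_data TSV as a
    # DataFrame); Donut handles the key fields above.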
with st.spinner("Detecting words with Tesseract (for table)…"): |
|
|
tsv = pytesseract.image_to_data(det_img, lang=det_lang, output_type=Output.DATAFRAME) |
|
|
tsv = tsv.dropna(subset=["text"]).reset_index(drop=True) |
|
|
tsv["x2"] = tsv["left"] + tsv["width"] |
|
|
tsv["y2"] = tsv["top"] + tsv["height"] |
|
|
|
|
|
st.markdown("**Line Items**") |
|
|
items = items_from_words_simple(tsv) |
|
|
if items.empty: |
|
|
st.caption("No line items confidently detected.") |
|
|
else: |
|
|
st.dataframe(items, use_container_width=True) |
|
|
|
|
|

    result = {
        "file": up.name,
        "page": page_idx + 1,
        "key_fields": key_fields,
        "items": items.to_dict(orient="records") if not items.empty else [],
        "donut_raw": donut_payload,
    }
    st.download_button(
        "Download JSON",
        data=json.dumps(result, indent=2),
        file_name="invoice_extraction.json",
        mime="application/json",
    )
    if not items.empty:
        st.download_button(
            "Download Items CSV",
            data=items.to_csv(index=False),
            file_name="invoice_items.csv",
            mime="text/csv",
        )

    if show_boxes:
        st.caption("First 20 Tesseract word boxes")
        st.dataframe(tsv[["left", "top", "width", "height", "text", "conf"]].head(20), use_container_width=True)