yashita13 commited on
Commit
b36eae9
·
verified ·
1 Parent(s): ba2d57a

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +231 -227
pipeline.py CHANGED
@@ -1,227 +1,231 @@
1
- import regex as re
2
- import unicodedata
3
- from collections import Counter
4
- from pathlib import Path
5
- import io
6
- import argparse
7
- import os
8
- import pymupdf
9
- from PIL import Image, ImageOps, ImageFilter
10
- import pytesseract
11
- import spacy
12
- import torch
13
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
14
- from pytesseract import Output
15
- from ner_functions import ner_extraction_multilingual, get_deadline, get_financial_details
16
- import gen_ai1
17
-
18
- # ==================== CONFIGURATION ====================
19
- MAX_CHUNK_TOKENS = 256
20
- CHUNK_TOKEN_OVERLAP = 40
21
- CLASSIFICATION_MODEL_NAME = "models/final" # Updated model path
22
-
23
- CURRENT_DIR = Path(__file__).resolve().parent
24
- LOCAL_CLF_DIR = CURRENT_DIR / "models" / "final"
25
-
26
- classification_dept_map = {
27
- 0: "Engineering",
28
- 1: "Finance",
29
- 2: "HR",
30
- 3: "Maintenance",
31
- 4: "Operations",
32
- }
33
-
34
- device = "cuda" if torch.cuda.is_available() else "cpu"
35
- NLP_MODEL = spacy.load("en_core_web_md")
36
-
37
-
38
- def clean_text_multilingual(text):
39
- # text = unicodedata.normalize("NFKC", text)
40
- # #text = re.sub(r"[^A-Za-z0-9\s.,;:!?()'\-\"@%$&]", " ", text)
41
- # text = re.sub(r"[^\p{L}\p{N}\s.,;:!?()'\-\"@%$&]", " ", text)
42
- text = re.sub(r'\s+', ' ', text)
43
- return text.strip()
44
-
45
-
46
- def extract_page_text(page, doc):
47
- raw_text = ""
48
- for block in page.get_text("blocks"):
49
- txt = block[4].strip()
50
- if txt:
51
- raw_text += " " + txt
52
-
53
- images = page.get_images(full=True)
54
- for img in images:
55
- xref = img[0]
56
- try:
57
- img_data = doc.extract_image(xref)
58
- image = Image.open(io.BytesIO(img_data["image"]))
59
- except Exception:
60
- print(f"[WARNING] Image extraction failed")
61
- continue
62
-
63
- filtered = image.filter(ImageFilter.MedianFilter(size=3))
64
- gray = ImageOps.grayscale(filtered)
65
- scale = 300 / 72
66
- base_w = min(int(gray.width * scale), 1000)
67
- base_h = min(int(gray.height * scale), 1000)
68
- gray_resized = gray.resize((base_w, base_h), Image.LANCZOS)
69
-
70
- try:
71
- ocr_text = pytesseract.image_to_string(gray_resized, lang="eng+mal")
72
- raw_text += " " + ocr_text
73
- except Exception:
74
- continue
75
-
76
- return raw_text.strip()
77
-
78
-
79
- def chunk_text_tokenwise(text, tokenizer, max_tokens=MAX_CHUNK_TOKENS, overlap=CHUNK_TOKEN_OVERLAP):
80
- token_ids = tokenizer.encode(text, add_special_tokens=False)
81
- chunks = []
82
- start = 0
83
- while start < len(token_ids):
84
- end = min(start + max_tokens, len(token_ids))
85
- chunk_ids = token_ids[start:end]
86
- chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
87
- chunks.append(chunk_text)
88
- start += max_tokens - overlap
89
- return chunks
90
-
91
-
92
- def load_classification_model():
93
- if LOCAL_CLF_DIR.exists():
94
- print(f"[INFO] Loading classification model from local cache: {LOCAL_CLF_DIR}")
95
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_CLF_DIR)
96
- model = AutoModelForSequenceClassification.from_pretrained(LOCAL_CLF_DIR).to(device)
97
- else:
98
- print(f"[INFO] Downloading classification model from: {CLASSIFICATION_MODEL_NAME}")
99
- tokenizer = AutoTokenizer.from_pretrained(CLASSIFICATION_MODEL_NAME)
100
- model = AutoModelForSequenceClassification.from_pretrained(CLASSIFICATION_MODEL_NAME).to(device)
101
- LOCAL_CLF_DIR.mkdir(parents=True, exist_ok=True)
102
- tokenizer.save_pretrained(LOCAL_CLF_DIR)
103
- model.save_pretrained(LOCAL_CLF_DIR)
104
- return tokenizer, model
105
-
106
-
107
- def classify_text_chunk(chunk, tokenizer, model):
108
- inputs = tokenizer(chunk, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
109
- with torch.no_grad():
110
- outputs = model(**inputs)
111
- pred = outputs.logits.argmax(dim=-1).cpu().item()
112
- return classification_dept_map.get(pred, "Unknown")
113
-
114
- _loaded_models = {}
115
-
116
- def load_all_models():
117
- global _loaded_models
118
- if _loaded_models:
119
- return _loaded_models['tokenizer'], _loaded_models['model'], _loaded_models['nlp_model']
120
-
121
- clf_tokenizer, clf_model = load_classification_model()
122
- nlp_model = NLP_MODEL
123
-
124
- _loaded_models = {
125
- 'tokenizer': clf_tokenizer,
126
- 'model': clf_model,
127
- 'nlp_model': nlp_model
128
- }
129
- return clf_tokenizer, clf_model, nlp_model
130
-
131
- def highlight_text(pdf_path, terms, output_path="highlighted.pdf"):
132
- doc = pymupdf.open(pdf_path)
133
-
134
- for page_num, page in enumerate(doc):
135
- # ----- 1. Native text search highlighting -----
136
- for term in terms:
137
- text_instances = page.search_for(term)
138
- for inst in text_instances:
139
- highlight = page.add_highlight_annot(inst)
140
- highlight.update()
141
-
142
- # ----- 2. OCR highlighting for scanned PDFs -----
143
- # Render page as image
144
- pix = page.get_pixmap()
145
- img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
146
-
147
- # OCR with bounding boxes
148
- ocr_data = pytesseract.image_to_data(img, output_type=Output.DICT)
149
- for i, word in enumerate(ocr_data['text']):
150
- for term in terms:
151
- if term.lower() in word.lower():
152
- x, y, w, h = (ocr_data['left'][i], ocr_data['top'][i],
153
- ocr_data['width'][i], ocr_data['height'][i])
154
- # Convert OCR coords to PDF coordinates
155
- rect = pymupdf.Rect(x, y, x + w, y + h)
156
- highlight = page.add_highlight_annot(rect)
157
- highlight.update()
158
-
159
- # Save PDF once after all highlights
160
- doc.save(output_path)
161
- doc.close()
162
- return output_path
163
-
164
- def pipeline_process_pdf(pdf_path, clf_tokenizer, clf_model, nlp_model):
165
- pdf_id = os.path.splitext(os.path.basename(pdf_path))[0]
166
- doc = pymupdf.open(pdf_path)
167
-
168
- dept_votes = []
169
- deadlines_all = []
170
- financials_all = []
171
-
172
- for page_number, page in enumerate(doc, start=1):
173
- raw_text = extract_page_text(page, doc)
174
-
175
- if not raw_text:
176
- continue
177
- # print("cleaned_text\n")
178
- cleaned_text = clean_text_multilingual(raw_text)
179
- # print(cleaned_text)
180
- if not cleaned_text:
181
- continue
182
-
183
- ner_results = ner_extraction_multilingual(cleaned_text)
184
- deadlines_all.extend(ner_results.get("deadlines", []))
185
- financials_all.extend(ner_results.get("financials", []))
186
-
187
- chunks = chunk_text_tokenwise(cleaned_text, tokenizer=clf_tokenizer)
188
- gen_ai1.encode(pdf_id, page_number,chunks)
189
-
190
- # print("chunks\n")
191
- for chunk in chunks:
192
- # print(chunk)
193
- dept = classify_text_chunk(chunk, clf_tokenizer, clf_model)
194
- dept_votes.append(dept)
195
-
196
- dominant_dept = Counter(dept_votes).most_common(1)[0][0] if dept_votes else "Unknown"
197
-
198
- summary = gen_ai1.create_summary(pdf_id)
199
- #print(summary)
200
- terms = deadlines_all + financials_all
201
- output_path = highlight_text(pdf_path, terms=terms, output_path=f"{pdf_id}_highlighted.pdf")
202
- return {
203
- "department": dominant_dept,
204
- "summary": summary,
205
- "deadlines": deadlines_all,
206
- "financials": financials_all,
207
- "highlighted_pdf": output_path
208
- }
209
-
210
-
211
- if __name__ == "__main__":
212
- parser = argparse.ArgumentParser(description="Unified PDF Processing Pipeline")
213
- parser.add_argument("pdf_file", help="Path to input PDF file")
214
- args = parser.parse_args()
215
-
216
- print("[INFO] Loading all models...")
217
- tokenizer, model, nlp_model = load_all_models()
218
-
219
- print("[INFO] Processing PDF through pipeline...")
220
- results = pipeline_process_pdf(args.pdf_file, tokenizer, model, nlp_model)
221
-
222
- print("\n================ Pipeline Output ================\n")
223
- #print(f"Dominant Department: {results['department']}")
224
- print(f"\nSummary:\n{results['summary']}")
225
- print(f"\nDeadlines found: {results['deadlines']}")
226
- print(f"\nFinancial terms found: {results['financials']}")
227
-
 
 
 
 
 
1
+ import regex as re
2
+ import unicodedata
3
+ from collections import Counter
4
+ from pathlib import Path
5
+ import io
6
+ import argparse
7
+ import os
8
+ import pymupdf
9
+ from PIL import Image, ImageOps, ImageFilter
10
+ import pytesseract
11
+ import spacy
12
+ import torch
13
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
14
+ from pytesseract import Output
15
+ from ner_functions import ner_extraction_multilingual, get_deadline, get_financial_details
16
+ import gen_ai1
17
+
18
+ # ==================== CONFIGURATION ====================
19
+ MAX_CHUNK_TOKENS = 256
20
+ CHUNK_TOKEN_OVERLAP = 40
21
+ CLASSIFICATION_MODEL_NAME = "Shrut04/Fine_tunned_indic_bert_on_documents" # Updated model path
22
+
23
+ #CURRENT_DIR = Path(__file__).resolve().parent
24
+ #LOCAL_CLF_DIR = CURRENT_DIR / "models" / "final"
25
+
26
+ classification_dept_map = {
27
+ 0: "Engineering",
28
+ 1: "Finance",
29
+ 2: "HR",
30
+ 3: "Maintenance",
31
+ 4: "Operations",
32
+ }
33
+
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
+ NLP_MODEL = spacy.load("en_core_web_md")
36
+
37
+
38
+ def clean_text_multilingual(text):
39
+ # text = unicodedata.normalize("NFKC", text)
40
+ # #text = re.sub(r"[^A-Za-z0-9\s.,;:!?()'\-\"@%$&]", " ", text)
41
+ # text = re.sub(r"[^\p{L}\p{N}\s.,;:!?()'\-\"@%$&]", " ", text)
42
+ text = re.sub(r'\s+', ' ', text)
43
+ return text.strip()
44
+
45
+
46
+ def extract_page_text(page, doc):
47
+ raw_text = ""
48
+ for block in page.get_text("blocks"):
49
+ txt = block[4].strip()
50
+ if txt:
51
+ raw_text += " " + txt
52
+
53
+ images = page.get_images(full=True)
54
+ for img in images:
55
+ xref = img[0]
56
+ try:
57
+ img_data = doc.extract_image(xref)
58
+ image = Image.open(io.BytesIO(img_data["image"]))
59
+ except Exception:
60
+ print(f"[WARNING] Image extraction failed")
61
+ continue
62
+
63
+ filtered = image.filter(ImageFilter.MedianFilter(size=3))
64
+ gray = ImageOps.grayscale(filtered)
65
+ scale = 300 / 72
66
+ base_w = min(int(gray.width * scale), 1000)
67
+ base_h = min(int(gray.height * scale), 1000)
68
+ gray_resized = gray.resize((base_w, base_h), Image.LANCZOS)
69
+
70
+ try:
71
+ ocr_text = pytesseract.image_to_string(gray_resized, lang="eng+mal")
72
+ raw_text += " " + ocr_text
73
+ except Exception:
74
+ continue
75
+
76
+ return raw_text.strip()
77
+
78
+
79
+ def chunk_text_tokenwise(text, tokenizer, max_tokens=MAX_CHUNK_TOKENS, overlap=CHUNK_TOKEN_OVERLAP):
80
+ token_ids = tokenizer.encode(text, add_special_tokens=False)
81
+ chunks = []
82
+ start = 0
83
+ while start < len(token_ids):
84
+ end = min(start + max_tokens, len(token_ids))
85
+ chunk_ids = token_ids[start:end]
86
+ chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
87
+ chunks.append(chunk_text)
88
+ start += max_tokens - overlap
89
+ return chunks
90
+
91
+
92
+ def load_classification_model():
93
+ # if LOCAL_CLF_DIR.exists():
94
+ # print(f"[INFO] Loading classification model from hugging face: {CLASSIFICATION_MODEL_NAME}")
95
+ # tokenizer = AutoTokenizer.from_pretrained(CLASSIFICATION_MODEL_NAME)
96
+ # model = AutoModelForSequenceClassification.from_pretrained(CLASSIFICATION_MODEL_NAME).to(device)
97
+ # else:
98
+ # print(f"[INFO] Downloading classification model from: {CLASSIFICATION_MODEL_NAME}")
99
+ # tokenizer = AutoTokenizer.from_pretrained(CLASSIFICATION_MODEL_NAME)
100
+ # model = AutoModelForSequenceClassification.from_pretrained(CLASSIFICATION_MODEL_NAME).to(device)
101
+ # LOCAL_CLF_DIR.mkdir(parents=True, exist_ok=True)
102
+ # tokenizer.save_pretrained(LOCAL_CLF_DIR)
103
+ # model.save_pretrained(LOCAL_CLF_DIR)
104
+ print(f"[INFO] Loading classification model from hugging face: {CLASSIFICATION_MODEL_NAME}")
105
+ tokenizer = AutoTokenizer.from_pretrained(CLASSIFICATION_MODEL_NAME)
106
+ model = AutoModelForSequenceClassification.from_pretrained(CLASSIFICATION_MODEL_NAME).to(device)
107
+
108
+ return tokenizer, model
109
+
110
+
111
+ def classify_text_chunk(chunk, tokenizer, model):
112
+ inputs = tokenizer(chunk, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
113
+ with torch.no_grad():
114
+ outputs = model(**inputs)
115
+ pred = outputs.logits.argmax(dim=-1).cpu().item()
116
+ return classification_dept_map.get(pred, "Unknown")
117
+
118
+ _loaded_models = {}
119
+
120
+ def load_all_models():
121
+ global _loaded_models
122
+ if _loaded_models:
123
+ return _loaded_models['tokenizer'], _loaded_models['model'], _loaded_models['nlp_model']
124
+
125
+ clf_tokenizer, clf_model = load_classification_model()
126
+ nlp_model = NLP_MODEL
127
+
128
+ _loaded_models = {
129
+ 'tokenizer': clf_tokenizer,
130
+ 'model': clf_model,
131
+ 'nlp_model': nlp_model
132
+ }
133
+ return clf_tokenizer, clf_model, nlp_model
134
+
135
+ def highlight_text(pdf_path, terms, output_path="highlighted.pdf"):
136
+ doc = pymupdf.open(pdf_path)
137
+
138
+ for page_num, page in enumerate(doc):
139
+ # ----- 1. Native text search highlighting -----
140
+ for term in terms:
141
+ text_instances = page.search_for(term)
142
+ for inst in text_instances:
143
+ highlight = page.add_highlight_annot(inst)
144
+ highlight.update()
145
+
146
+ # ----- 2. OCR highlighting for scanned PDFs -----
147
+ # Render page as image
148
+ pix = page.get_pixmap()
149
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
150
+
151
+ # OCR with bounding boxes
152
+ ocr_data = pytesseract.image_to_data(img, output_type=Output.DICT)
153
+ for i, word in enumerate(ocr_data['text']):
154
+ for term in terms:
155
+ if term.lower() in word.lower():
156
+ x, y, w, h = (ocr_data['left'][i], ocr_data['top'][i],
157
+ ocr_data['width'][i], ocr_data['height'][i])
158
+ # Convert OCR coords to PDF coordinates
159
+ rect = pymupdf.Rect(x, y, x + w, y + h)
160
+ highlight = page.add_highlight_annot(rect)
161
+ highlight.update()
162
+
163
+ # Save PDF once after all highlights
164
+ doc.save(output_path)
165
+ doc.close()
166
+ return output_path
167
+
168
+ def pipeline_process_pdf(pdf_path, clf_tokenizer, clf_model, nlp_model):
169
+ pdf_id = os.path.splitext(os.path.basename(pdf_path))[0]
170
+ doc = pymupdf.open(pdf_path)
171
+
172
+ dept_votes = []
173
+ deadlines_all = []
174
+ financials_all = []
175
+
176
+ for page_number, page in enumerate(doc, start=1):
177
+ raw_text = extract_page_text(page, doc)
178
+
179
+ if not raw_text:
180
+ continue
181
+ # print("cleaned_text\n")
182
+ cleaned_text = clean_text_multilingual(raw_text)
183
+ # print(cleaned_text)
184
+ if not cleaned_text:
185
+ continue
186
+
187
+ ner_results = ner_extraction_multilingual(cleaned_text)
188
+ deadlines_all.extend(ner_results.get("deadlines", []))
189
+ financials_all.extend(ner_results.get("financials", []))
190
+
191
+ chunks = chunk_text_tokenwise(cleaned_text, tokenizer=clf_tokenizer)
192
+ gen_ai1.encode(pdf_id, page_number,chunks)
193
+
194
+ # print("chunks\n")
195
+ for chunk in chunks:
196
+ # print(chunk)
197
+ dept = classify_text_chunk(chunk, clf_tokenizer, clf_model)
198
+ dept_votes.append(dept)
199
+
200
+ dominant_dept = Counter(dept_votes).most_common(1)[0][0] if dept_votes else "Unknown"
201
+
202
+ summary = gen_ai1.create_summary(pdf_id)
203
+ #print(summary)
204
+ terms = deadlines_all + financials_all
205
+ output_path = highlight_text(pdf_path, terms=terms, output_path=f"{pdf_id}_highlighted.pdf")
206
+ return {
207
+ "department": dominant_dept,
208
+ "summary": summary,
209
+ "deadlines": deadlines_all,
210
+ "financials": financials_all,
211
+ "highlighted_pdf": output_path
212
+ }
213
+
214
+
215
+ if __name__ == "__main__":
216
+ parser = argparse.ArgumentParser(description="Unified PDF Processing Pipeline")
217
+ parser.add_argument("pdf_file", help="Path to input PDF file")
218
+ args = parser.parse_args()
219
+
220
+ print("[INFO] Loading all models...")
221
+ tokenizer, model, nlp_model = load_all_models()
222
+
223
+ print("[INFO] Processing PDF through pipeline...")
224
+ results = pipeline_process_pdf(args.pdf_file, tokenizer, model, nlp_model)
225
+
226
+ print("\n================ Pipeline Output ================\n")
227
+ #print(f"Dominant Department: {results['department']}")
228
+ print(f"\nSummary:\n{results['summary']}")
229
+ print(f"\nDeadlines found: {results['deadlines']}")
230
+ print(f"\nFinancial terms found: {results['financials']}")
231
+