samarth09healthPM committed on
Commit
02f41b6
·
1 Parent(s): f3bad0e

Fix duplicate key error with session state

Browse files
Files changed (1) hide show
  1. deid_pipeline.py +9 -33
deid_pipeline.py CHANGED
@@ -1,6 +1,6 @@
 
1
  import json
2
  import os
3
- from pathlib import Path
4
  from dataclasses import dataclass
5
  from typing import List, Dict, Any, Tuple
6
 
@@ -25,16 +25,7 @@ analyzer_config = {
25
 
26
  # NLP for optional section detection
27
  import spacy
28
-
29
- # If using medspacy, uncomment (preferred for clinical):
30
- # import medspacy
31
- # from medspacy.sectionizer import Sectionizer
32
-
33
- # If not using medspacy, optional lightweight section tagging:
34
- # We'll use regex on common headers as a fallback
35
  import re
36
-
37
- # Encryption
38
  from cryptography.fernet import Fernet
39
 
40
  @dataclass
@@ -46,7 +37,6 @@ class PHISpan:
46
  section: str
47
 
48
  SECTION_HEADERS = [
49
- # Common clinical sections; customize as needed
50
  "HPI", "History of Present Illness",
51
  "PMH", "Past Medical History",
52
  "Medications", "Allergies",
@@ -73,17 +63,18 @@ class DeidPipeline:
73
  """
74
  De-identification pipeline using Microsoft Presidio
75
  """
76
- def __init__(self, fernet_key_path="secure_store/fernet.key"):
77
  """
78
  Initialize de-identification pipeline with Presidio
79
 
80
  Args:
81
- fernet_key_path: Path to Fernet encryption key
82
  """
83
- self.secure_dir = Path(secure_dir)
84
- self.secure_dir.mkdir(exist_ok=True)
85
- import os
86
- from cryptography.fernet import Fernet
 
87
 
88
  # Initialize encryption
89
  try:
@@ -96,7 +87,6 @@ class DeidPipeline:
96
  key = Fernet.generate_key()
97
  # Try to save it (might fail on read-only filesystems)
98
  try:
99
- os.makedirs(os.path.dirname(fernet_key_path), exist_ok=True)
100
  with open(fernet_key_path, "wb") as f:
101
  f.write(key)
102
  except (PermissionError, OSError):
@@ -129,11 +119,9 @@ class DeidPipeline:
129
  Lightweight section finder:
130
  Return list of (section_title, start_idx, end_idx_of_section_block)
131
  """
132
- # Find headers by regex, map their start positions
133
  headers = []
134
  for m in SECTION_PATTERN.finditer(text):
135
  headers.append((m.group("header"), m.start()))
136
- # Add end sentinel
137
  headers.append(("[END]", len(text)))
138
 
139
  sections = []
@@ -142,7 +130,6 @@ class DeidPipeline:
142
  next_title, next_pos = headers[i+1]
143
  sections.append((title.strip(), start_pos, next_pos))
144
  if not sections:
145
- # Single default section if none found
146
  sections = [("DOCUMENT", 0, len(text))]
147
  return sections
148
 
@@ -153,9 +140,7 @@ class DeidPipeline:
153
  return "DOCUMENT"
154
 
155
  def analyze(self, text: str) -> List[Dict[str, Any]]:
156
- # Detect entities
157
  results = self.analyzer.analyze(text=text, language="en")
158
- # Convert to dict for consistency
159
  detections = []
160
  for r in results:
161
  detections.append({
@@ -170,10 +155,8 @@ class DeidPipeline:
170
  """
171
  Replace spans with tags safely (right-to-left to maintain indices).
172
  """
173
- # Determine sections for context
174
  sections = self._detect_sections(text)
175
 
176
- # Build PHI span records
177
  spans: List[PHISpan] = []
178
  for d in detections:
179
  entity = d["entity_type"]
@@ -183,7 +166,6 @@ class DeidPipeline:
183
  section = self._find_section_for_span(sections, start)
184
  spans.append(PHISpan(entity_type=entity, start=start, end=end, text=original, section=section))
185
 
186
- # Replace from the end to avoid index shifting
187
  masked = text
188
  for d in sorted(detections, key=lambda x: x["start"], reverse=True):
189
  entity = d["entity_type"]
@@ -207,7 +189,6 @@ class DeidPipeline:
207
  detections = self.analyze(text)
208
  masked, spans = self.mask(text, detections)
209
 
210
- # Encrypt span map
211
  token = self.encrypt_span_map(
212
  spans=spans,
213
  meta={"note_id": note_id}
@@ -219,19 +200,16 @@ class DeidPipeline:
219
  }
220
 
221
  def _read_text_with_fallback(path: str) -> str:
222
- # 1) Try UTF-8 (preferred for cross-platform)
223
  try:
224
  with open(path, "r", encoding="utf-8") as f:
225
  return f.read()
226
  except UnicodeDecodeError:
227
  pass
228
- # 2) Try Windows-1252 (common for Notepad/docx copy-paste on Windows)
229
  try:
230
  with open(path, "r", encoding="cp1252") as f:
231
  return f.read()
232
  except UnicodeDecodeError:
233
  pass
234
- # 3) Last resort: decode with replacement to avoid crashing; preserves structure
235
  with open(path, "r", encoding="utf-8", errors="replace") as f:
236
  return f.read()
237
 
@@ -242,15 +220,13 @@ def run_file(input_path: str, outputs_dir: str = "data/outputs", secure_dir: str
242
  note_id = os.path.splitext(os.path.basename(input_path))[0]
243
  text = _read_text_with_fallback(input_path)
244
 
245
- pipeline = DeidPipeline()
246
  result = pipeline.run_on_text(text, note_id=note_id)
247
 
248
- # Save masked text normalized to UTF-8
249
  out_txt = os.path.join(outputs_dir, f"{note_id}.deid.txt")
250
  with open(out_txt, "w", encoding="utf-8", newline="\n") as f:
251
  f.write(result["masked_text"])
252
 
253
- # Save encrypted span map (binary)
254
  out_bin = os.path.join(secure_dir, f"{note_id}.spanmap.enc")
255
  with open(out_bin, "wb") as f:
256
  f.write(result["encrypted_span_map"])
 
1
+ # deid_pipeline.py
2
  import json
3
  import os
 
4
  from dataclasses import dataclass
5
  from typing import List, Dict, Any, Tuple
6
 
 
25
 
26
  # NLP for optional section detection
27
  import spacy
 
 
 
 
 
 
 
28
  import re
 
 
29
  from cryptography.fernet import Fernet
30
 
31
  @dataclass
 
37
  section: str
38
 
39
  SECTION_HEADERS = [
 
40
  "HPI", "History of Present Illness",
41
  "PMH", "Past Medical History",
42
  "Medications", "Allergies",
 
63
  """
64
  De-identification pipeline using Microsoft Presidio
65
  """
66
+ def __init__(self, secure_dir="./secure_store"):
67
  """
68
  Initialize de-identification pipeline with Presidio
69
 
70
  Args:
71
+ secure_dir: Directory path to store encryption key (NOT the key file path)
72
  """
73
+ # Ensure secure_dir exists
74
+ os.makedirs(secure_dir, exist_ok=True)
75
+
76
+ # Build full path to key file
77
+ fernet_key_path = os.path.join(secure_dir, "fernet.key")
78
 
79
  # Initialize encryption
80
  try:
 
87
  key = Fernet.generate_key()
88
  # Try to save it (might fail on read-only filesystems)
89
  try:
 
90
  with open(fernet_key_path, "wb") as f:
91
  f.write(key)
92
  except (PermissionError, OSError):
 
119
  Lightweight section finder:
120
  Return list of (section_title, start_idx, end_idx_of_section_block)
121
  """
 
122
  headers = []
123
  for m in SECTION_PATTERN.finditer(text):
124
  headers.append((m.group("header"), m.start()))
 
125
  headers.append(("[END]", len(text)))
126
 
127
  sections = []
 
130
  next_title, next_pos = headers[i+1]
131
  sections.append((title.strip(), start_pos, next_pos))
132
  if not sections:
 
133
  sections = [("DOCUMENT", 0, len(text))]
134
  return sections
135
 
 
140
  return "DOCUMENT"
141
 
142
  def analyze(self, text: str) -> List[Dict[str, Any]]:
 
143
  results = self.analyzer.analyze(text=text, language="en")
 
144
  detections = []
145
  for r in results:
146
  detections.append({
 
155
  """
156
  Replace spans with tags safely (right-to-left to maintain indices).
157
  """
 
158
  sections = self._detect_sections(text)
159
 
 
160
  spans: List[PHISpan] = []
161
  for d in detections:
162
  entity = d["entity_type"]
 
166
  section = self._find_section_for_span(sections, start)
167
  spans.append(PHISpan(entity_type=entity, start=start, end=end, text=original, section=section))
168
 
 
169
  masked = text
170
  for d in sorted(detections, key=lambda x: x["start"], reverse=True):
171
  entity = d["entity_type"]
 
189
  detections = self.analyze(text)
190
  masked, spans = self.mask(text, detections)
191
 
 
192
  token = self.encrypt_span_map(
193
  spans=spans,
194
  meta={"note_id": note_id}
 
200
  }
201
 
202
  def _read_text_with_fallback(path: str) -> str:
 
203
  try:
204
  with open(path, "r", encoding="utf-8") as f:
205
  return f.read()
206
  except UnicodeDecodeError:
207
  pass
 
208
  try:
209
  with open(path, "r", encoding="cp1252") as f:
210
  return f.read()
211
  except UnicodeDecodeError:
212
  pass
 
213
  with open(path, "r", encoding="utf-8", errors="replace") as f:
214
  return f.read()
215
 
 
220
  note_id = os.path.splitext(os.path.basename(input_path))[0]
221
  text = _read_text_with_fallback(input_path)
222
 
223
+ pipeline = DeidPipeline(secure_dir)
224
  result = pipeline.run_on_text(text, note_id=note_id)
225
 
 
226
  out_txt = os.path.join(outputs_dir, f"{note_id}.deid.txt")
227
  with open(out_txt, "w", encoding="utf-8", newline="\n") as f:
228
  f.write(result["masked_text"])
229
 
 
230
  out_bin = os.path.join(secure_dir, f"{note_id}.spanmap.enc")
231
  with open(out_bin, "wb") as f:
232
  f.write(result["encrypted_span_map"])