Samuel Oberhofer committed on
Commit
baa6dd8
·
1 Parent(s): 761de06

feat: Add input guardrail to block SVNR queries

Browse files
Files changed (1) hide show
  1. rails/input.py +16 -1
rails/input.py CHANGED
@@ -7,6 +7,7 @@ from langdetect import detect, detect_langs
7
  from collections import Counter
8
  import nltk
9
  from nltk.corpus import words
 
10
  nltk.download('words')
11
  english_vocab = set(words.words())
12
 
@@ -135,6 +136,11 @@ class InputGuardRails:
135
  print("WARNING: Query appears to be gibberish.")
136
  return CheckedInput(False, self.get_output("Input appears to be non-sensical."))
137
 
 
 
 
 
 
138
  return CheckedInput(True, None)
139
 
140
  def query_contains_sql_injection(self, query: str) -> bool:
@@ -201,6 +207,16 @@ class InputGuardRails:
201
  else:
202
  return False
203
 
 
 
 
 
 
 
 
 
 
 
204
  def get_output(self, reason: str) -> str:
205
  return "\n".join([reason, AUTO_ANSWERS.REPHRASE_SENTENCE.value])
206
 
@@ -220,4 +236,3 @@ if __name__ == '__main__':
220
  for query in test_queries:
221
  result = input_guards.is_valid(query)
222
  print(f"'{query[:50]}{'...' if len(query) > 50 else ''}' is valid: {result}")
223
-
 
7
  from collections import Counter
8
  import nltk
9
  from nltk.corpus import words
10
+ from guards.svnr import is_valid_svnr
11
  nltk.download('words')
12
  english_vocab = set(words.words())
13
 
 
136
  print("WARNING: Query appears to be gibberish.")
137
  return CheckedInput(False, self.get_output("Input appears to be non-sensical."))
138
 
139
+ # Block queries containing valid SVNRs
140
+ if self.query_contains_valid_svnr(query):
141
+ print("WARNING: Query appears to contain a valid SVNR.")
142
+ return CheckedInput(False, self.get_output("Queries containing Austrian social security numbers are not permitted."))
143
+
144
  return CheckedInput(True, None)
145
 
146
  def query_contains_sql_injection(self, query: str) -> bool:
 
207
  else:
208
  return False
209
 
210
def query_contains_valid_svnr(self, query: str) -> bool:
    """
    Checks if the query contains a valid Austrian social security number (SVNR).

    Candidates are 10-digit numbers delimited by word boundaries. Besides the
    contiguous form ("1234010180"), a single space, hyphen or slash between
    the first four and the last six digits is tolerated ("1234 010180",
    "1234-010180"), because a grouped number would otherwise slip past the
    guardrail. NOTE(review): the 4+6 grouping is assumed to be the common
    written form of an SVNR -- extend the separator set if other spellings
    need to be blocked.

    Args:
        query: Raw user query text to scan.

    Returns:
        True if at least one candidate passes ``is_valid_svnr``, else False.
    """
    # Every match of the former pattern r'\b\d{10}\b' is still matched here
    # (empty separator), so the check is a strict superset of the old one.
    candidates = re.finditer(r'\b(\d{4})[ \-/]?(\d{6})\b', query)
    # Hand the validator digits only, so it sees the same 10-character
    # string regardless of how the number was written in the query.
    return any(
        is_valid_svnr(match.group(1) + match.group(2))
        for match in candidates
    )
219
+
220
  def get_output(self, reason: str) -> str:
221
  return "\n".join([reason, AUTO_ANSWERS.REPHRASE_SENTENCE.value])
222
 
 
236
  for query in test_queries:
237
  result = input_guards.is_valid(query)
238
  print(f"'{query[:50]}{'...' if len(query) > 50 else ''}' is valid: {result}")