Samuel Oberhofer
committed on
Commit
·
baa6dd8
1
Parent(s):
761de06
feat: Add input guardrail to block SVNR queries
Browse files - rails/input.py +16 -1
rails/input.py
CHANGED
|
@@ -7,6 +7,7 @@ from langdetect import detect, detect_langs
|
|
| 7 |
from collections import Counter
|
| 8 |
import nltk
|
| 9 |
from nltk.corpus import words
|
|
|
|
| 10 |
nltk.download('words')
|
| 11 |
english_vocab = set(words.words())
|
| 12 |
|
|
@@ -135,6 +136,11 @@ class InputGuardRails:
|
|
| 135 |
print("WARNING: Query appears to be gibberish.")
|
| 136 |
return CheckedInput(False, self.get_output("Input appears to be non-sensical."))
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
return CheckedInput(True, None)
|
| 139 |
|
| 140 |
def query_contains_sql_injection(self, query: str) -> bool:
|
|
@@ -201,6 +207,16 @@ class InputGuardRails:
|
|
| 201 |
else:
|
| 202 |
return False
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
def get_output(self, reason: str) -> str:
|
| 205 |
return "\n".join([reason, AUTO_ANSWERS.REPHRASE_SENTENCE.value])
|
| 206 |
|
|
@@ -220,4 +236,3 @@ if __name__ == '__main__':
|
|
| 220 |
for query in test_queries:
|
| 221 |
result = input_guards.is_valid(query)
|
| 222 |
print(f"'{query[:50]}{'...' if len(query) > 50 else ''}' is valid: {result}")
|
| 223 |
-
|
|
|
|
| 7 |
from collections import Counter
|
| 8 |
import nltk
|
| 9 |
from nltk.corpus import words
|
| 10 |
+
from guards.svnr import is_valid_svnr
|
| 11 |
nltk.download('words')
|
| 12 |
english_vocab = set(words.words())
|
| 13 |
|
|
|
|
| 136 |
print("WARNING: Query appears to be gibberish.")
|
| 137 |
return CheckedInput(False, self.get_output("Input appears to be non-sensical."))
|
| 138 |
|
| 139 |
+
# Block queries containing valid SVNRs
|
| 140 |
+
if self.query_contains_valid_svnr(query):
|
| 141 |
+
print("WARNING: Query appears to contain a valid SVNR.")
|
| 142 |
+
return CheckedInput(False, self.get_output("Queries containing Austrian social security numbers are not permitted."))
|
| 143 |
+
|
| 144 |
return CheckedInput(True, None)
|
| 145 |
|
| 146 |
def query_contains_sql_injection(self, query: str) -> bool:
|
|
|
|
| 207 |
else:
|
| 208 |
return False
|
| 209 |
|
| 210 |
+
def query_contains_valid_svnr(self, query: str) -> bool:
|
| 211 |
+
"""
|
| 212 |
+
Checks if the query contains a valid Austrian social security number (SVNR).
|
| 213 |
+
"""
|
| 214 |
+
potential_svnrs = re.findall(r'\b\d{10}\b', query)
|
| 215 |
+
for svnr in potential_svnrs:
|
| 216 |
+
if is_valid_svnr(svnr):
|
| 217 |
+
return True
|
| 218 |
+
return False
|
| 219 |
+
|
| 220 |
def get_output(self, reason: str) -> str:
|
| 221 |
return "\n".join([reason, AUTO_ANSWERS.REPHRASE_SENTENCE.value])
|
| 222 |
|
|
|
|
| 236 |
for query in test_queries:
|
| 237 |
result = input_guards.is_valid(query)
|
| 238 |
print(f"'{query[:50]}{'...' if len(query) > 50 else ''}' is valid: {result}")
|
|
|