munichpavel commited on
Commit
99c8950
·
1 Parent(s): 8be8463

Swap around models

Browse files
app.py CHANGED
@@ -5,8 +5,8 @@ from ai_security.chatter_demo import demo as chatter_demo
5
  from ai_security.malware_demo import demo as malware_demo
6
 
7
 
8
- # demo = chatter_demo
9
- demo = malware_demo
10
 
11
  if __name__ == "__main__":
12
 
 
5
  from ai_security.malware_demo import demo as malware_demo
6
 
7
 
8
+ demo = chatter_demo
9
+ # demo = malware_demo
10
 
11
  if __name__ == "__main__":
12
 
src/ai_security/chatter_demo.py CHANGED
@@ -6,13 +6,13 @@ from .rules_chatter_detector import simple_normalized_blacbriar_chatter_detecto
6
  from .discriminative_chatter_detector import DiscriminativeChatterDetector
7
 
8
 
9
- detector_a = DiscriminativeChatterDetector(scope='blackbriar-only')
10
- detector_c = GenerativeChatterDetector(scope='blackbriar-only')
11
 
12
 
13
  def detect_chatter_a(text):
14
  """Model A detection"""
15
- result = detector_a.predict(text)
16
  return f"**Prediction:** {result['label']}"
17
 
18
 
@@ -26,7 +26,7 @@ def detect_chatter_b(text):
26
 
27
  def detect_chatter_c(text):
28
  """Model C detection"""
29
- result = detector_c.detect(text)
30
  return f"**Prediction:** {result['label']}"
31
 
32
 
 
6
  from .discriminative_chatter_detector import DiscriminativeChatterDetector
7
 
8
 
9
+ discriminative_detector = DiscriminativeChatterDetector(scope='blackbriar-only')
10
+ generative_detector = GenerativeChatterDetector(scope='blackbriar-only')
11
 
12
 
13
  def detect_chatter_a(text):
14
  """Model A detection"""
15
+ result = generative_detector.predict(text)
16
  return f"**Prediction:** {result['label']}"
17
 
18
 
 
26
 
27
  def detect_chatter_c(text):
28
  """Model C detection"""
29
+ result = discriminative_detector.predict(text)
30
  return f"**Prediction:** {result['label']}"
31
 
32
 
src/ai_security/generative_chatter_detector.py CHANGED
@@ -26,7 +26,7 @@ class GenerativeChatterDetector:
26
  self.client = InferenceClient(model='google/gemma-2-2b-it', token=os.environ['HF_INFERENCE_TOKEN'])
27
 
28
 
29
- def detect(self, text):
30
  try:
31
  prompt = self.prompt_template.format(text=text)
32
  print(f"Prompt: {prompt}")
 
26
  self.client = InferenceClient(model='google/gemma-2-2b-it', token=os.environ['HF_INFERENCE_TOKEN'])
27
 
28
 
29
+ def predict(self, text):
30
  try:
31
  prompt = self.prompt_template.format(text=text)
32
  print(f"Prompt: {prompt}")