"""LLM-based malware detector backed by the Hugging Face Inference API."""

import os

from huggingface_hub import InferenceClient

# Closed set of labels the model is asked to choose between.
CATEGORIES = ['harmless', 'malware']

# The category list is baked in up front; {source_code} / {dependencies} are
# filled per request (the doubled braces survive this first .format call).
PROMPT_TEMPLATE = """Classify the entered code and dependencies as exactly one of the following categories: '{labels}' Message: "code: {{source_code}}\n\ndependencies: {{dependencies}}" Answer:""".format(labels="', '".join(CATEGORIES))


class GenerativeMalwareDetector:
    """Classify source code as 'harmless' or 'malware' via a hosted chat model."""

    def __init__(self):
        self.prompt_template = PROMPT_TEMPLATE
        self.categories = CATEGORIES
        # Raises KeyError if HF_INFERENCE_TOKEN is unset — fail fast at
        # construction rather than on the first request.
        self.client = InferenceClient(
            model='google/gemma-2-2b-it',
            token=os.environ['HF_INFERENCE_TOKEN'],
        )

    def detect(self, source_code: str, dependencies: str) -> dict:
        """Classify *source_code* (with its *dependencies*) using the model.

        Args:
            source_code: The code snippet to classify.
            dependencies: A textual description of the code's dependencies.

        Returns:
            A dict of shape ``{'label': str | None, 'confidence': None}``.
            ``label`` is the detected category, or ``None`` when the model
            response contained no known category or the API call failed.
        """
        try:
            prompt = self.prompt_template.format(
                source_code=source_code, dependencies=dependencies
            )
            print(f"Prompt: {prompt}")
            response = self.client.chat_completion(
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=50,  # Max *new* tokens
                temperature=1.0,
            )
            generated_text = response.choices[0].message.content.lower().strip()
            print(f"Generated Text: {generated_text}")
            # Hacky response parsing: substring-match each known label
            # against the raw model output.
            predicted_categories = [
                a_cat for a_cat in self.categories if a_cat in generated_text
            ]
            print(f"Predicted Categories: {predicted_categories}")
            if predicted_categories:
                # Prefer the last label the model mentioned.
                return {'label': predicted_categories[-1], 'confidence': None}
            print("Warning: Model response did not contain a known category.")
            return {'label': None, 'confidence': None}
        except Exception as e:
            # BUG FIX: the original handler only printed and then fell
            # through, implicitly returning None from a method typed
            # ``-> dict``. Return the same shape as the no-match path so
            # callers can rely on a dict result.
            print(f"An error occurred: {e}")
            return {'label': None, 'confidence': None}