Update app.py
Browse files
app.py
CHANGED
|
@@ -533,7 +533,7 @@ def load_llm_model():
|
|
| 533 |
)
|
| 534 |
|
| 535 |
# Load the adapter
|
| 536 |
-
adapter_id = "saakshigupta/deepfake-explainer-
|
| 537 |
model = PeftModel.from_pretrained(model, adapter_id)
|
| 538 |
|
| 539 |
# Set to inference mode
|
|
@@ -552,50 +552,77 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
|
|
| 552 |
else:
|
| 553 |
full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
|
| 554 |
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
{"
|
| 558 |
-
{"type": "image", "image": image}, # Original image
|
| 559 |
-
{"type": "image", "image": gradcam_overlay}, # GradCAM overlay
|
| 560 |
-
{"type": "text", "text": full_prompt}
|
| 561 |
-
]}
|
| 562 |
-
]
|
| 563 |
-
|
| 564 |
-
# Apply chat template
|
| 565 |
-
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
| 566 |
-
|
| 567 |
-
# Process with image
|
| 568 |
-
inputs = tokenizer(
|
| 569 |
-
[image, gradcam_overlay], # Send both images
|
| 570 |
-
input_text,
|
| 571 |
-
add_special_tokens=False,
|
| 572 |
-
return_tensors="pt",
|
| 573 |
-
).to(model.device)
|
| 574 |
-
|
| 575 |
-
# Fix cross-attention mask if needed
|
| 576 |
-
inputs = fix_cross_attention_mask(inputs)
|
| 577 |
-
|
| 578 |
-
# Generate response
|
| 579 |
-
with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
|
| 580 |
-
with torch.no_grad():
|
| 581 |
-
output_ids = model.generate(
|
| 582 |
-
**inputs,
|
| 583 |
-
max_new_tokens=max_tokens,
|
| 584 |
-
use_cache=True,
|
| 585 |
-
temperature=temperature,
|
| 586 |
-
top_p=0.9
|
| 587 |
-
)
|
| 588 |
|
| 589 |
-
#
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
#
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
|
| 600 |
# Main app
|
| 601 |
def main():
|
|
@@ -818,7 +845,7 @@ def main():
|
|
| 818 |
caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"
|
| 819 |
|
| 820 |
# Default question with option to customize
|
| 821 |
-
default_question = f"This image has been classified as {
|
| 822 |
|
| 823 |
# User input for new question
|
| 824 |
new_question = st.text_area("Ask a question about the image:", value=default_question if not st.session_state.chat_history else "", height=100)
|
|
@@ -902,5 +929,8 @@ def main():
|
|
| 902 |
# Footer
|
| 903 |
st.markdown("---")
|
| 904 |
|
|
|
|
|
|
|
|
|
|
| 905 |
if __name__ == "__main__":
|
| 906 |
main()
|
|
|
|
| 533 |
)
|
| 534 |
|
| 535 |
# Load the adapter
|
| 536 |
+
adapter_id = "saakshigupta/deepfake-explainer-2"
|
| 537 |
model = PeftModel.from_pretrained(model, adapter_id)
|
| 538 |
|
| 539 |
# Set to inference mode
|
|
|
|
| 552 |
else:
|
| 553 |
full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
|
| 554 |
|
| 555 |
+
try:
|
| 556 |
+
# Format the message to include all available images
|
| 557 |
+
message_content = [{"type": "text", "text": full_prompt}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
|
| 559 |
+
# Add original image
|
| 560 |
+
message_content.insert(0, {"type": "image", "image": image})
|
| 561 |
+
|
| 562 |
+
# Add GradCAM overlay
|
| 563 |
+
message_content.insert(1, {"type": "image", "image": gradcam_overlay})
|
| 564 |
+
|
| 565 |
+
# Add comparison image if available
|
| 566 |
+
if hasattr(st.session_state, 'comparison_image'):
|
| 567 |
+
message_content.insert(2, {"type": "image", "image": st.session_state.comparison_image})
|
| 568 |
+
|
| 569 |
+
messages = [{"role": "user", "content": message_content}]
|
| 570 |
+
|
| 571 |
+
# Apply chat template
|
| 572 |
+
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
| 573 |
+
|
| 574 |
+
# Create list of images to process
|
| 575 |
+
image_list = [image, gradcam_overlay]
|
| 576 |
+
if hasattr(st.session_state, 'comparison_image'):
|
| 577 |
+
image_list.append(st.session_state.comparison_image)
|
| 578 |
+
|
| 579 |
+
try:
|
| 580 |
+
# Try with multiple images first
|
| 581 |
+
inputs = tokenizer(
|
| 582 |
+
image_list,
|
| 583 |
+
input_text,
|
| 584 |
+
add_special_tokens=False,
|
| 585 |
+
return_tensors="pt",
|
| 586 |
+
).to(model.device)
|
| 587 |
+
except Exception as e:
|
| 588 |
+
st.warning(f"Multiple image analysis encountered an issue: {str(e)}")
|
| 589 |
+
st.info("Falling back to single image analysis")
|
| 590 |
+
# Fallback to single image
|
| 591 |
+
inputs = tokenizer(
|
| 592 |
+
image,
|
| 593 |
+
input_text,
|
| 594 |
+
add_special_tokens=False,
|
| 595 |
+
return_tensors="pt",
|
| 596 |
+
).to(model.device)
|
| 597 |
+
|
| 598 |
+
# Fix cross-attention mask if needed
|
| 599 |
+
inputs = fix_cross_attention_mask(inputs)
|
| 600 |
+
|
| 601 |
+
# Generate response
|
| 602 |
+
with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
|
| 603 |
+
with torch.no_grad():
|
| 604 |
+
output_ids = model.generate(
|
| 605 |
+
**inputs,
|
| 606 |
+
max_new_tokens=max_tokens,
|
| 607 |
+
use_cache=True,
|
| 608 |
+
temperature=temperature,
|
| 609 |
+
top_p=0.9
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# Decode the output
|
| 613 |
+
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 614 |
+
|
| 615 |
+
# Try to extract just the model's response (after the prompt)
|
| 616 |
+
if full_prompt in response:
|
| 617 |
+
result = response.split(full_prompt)[-1].strip()
|
| 618 |
+
else:
|
| 619 |
+
result = response
|
| 620 |
+
|
| 621 |
+
return result
|
| 622 |
+
|
| 623 |
+
except Exception as e:
|
| 624 |
+
st.error(f"Error during LLM analysis: {str(e)}")
|
| 625 |
+
return f"Error analyzing image: {str(e)}"
|
| 626 |
|
| 627 |
# Main app
|
| 628 |
def main():
|
|
|
|
| 845 |
caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"
|
| 846 |
|
| 847 |
# Default question with option to customize
|
| 848 |
+
default_question = f"This image has been classified as {{pred_label}}. Analyze all the provided images (original, GradCAM visualization, and comparison) to determine if this is a deepfake. Focus on highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
|
| 849 |
|
| 850 |
# User input for new question
|
| 851 |
new_question = st.text_area("Ask a question about the image:", value=default_question if not st.session_state.chat_history else "", height=100)
|
|
|
|
| 929 |
# Footer
|
| 930 |
st.markdown("---")
|
| 931 |
|
| 932 |
+
# Add model version indicator in sidebar
|
| 933 |
+
st.sidebar.info("Using deepfake-explainer-2 model")
|
| 934 |
+
|
| 935 |
if __name__ == "__main__":
|
| 936 |
main()
|