# sa-clip-v0 / app.py
import io
import os

# Hide GPUs before torch is imported so everything runs on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from huggingface_hub import hf_hub_download

device = 'cpu'
st.markdown("## Finder!")
st.markdown("### Get a segmentation mask by a text describtion or a reference image.")
st.markdown("DISCLAIMER: this is working much longer because of cpu :(")
st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
# ^-- you can show the user text, images, and a limited subset of HTML, just like in Jupyter
uploaded_file = st.file_uploader(
"Provide an image where you want to find something",
type=['png', 'jpg', 'jpeg']
)
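# st.cache with hash_funcs that skip hashing torch Parameters (so Streamlit does not try to
# hash model weights on every rerun) and allow_output_mutation so the cached models can be reused.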
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_sam():
print("LOADING SAM")
sam = sam_model_registry["default"](checkpoint=hf_hub_download(repo_id="n0r9st/segment-anything", filename="sam_vit_h_4b8939.pth"))
sam.to(device=device)
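    # SamAutomaticMaskGenerator proposes class-agnostic masks over the whole image, no prompts needed.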
mask_generator = SamAutomaticMaskGenerator(sam)
# predictor = SamPredictor(sam)
return mask_generator
mask_generator = get_sam()
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_masks(image_bytes):
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')  # force 3 channels so masking works for PNG/grayscale inputs
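    # Propose masks with SAM, then zero out everything outside each mask so CLIP scores only that region.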
masks = mask_generator.generate(np.array(image))
masked_imgs = [Image.fromarray(np.array(image) * mask['segmentation'][..., None]) for mask in masks]
return masked_imgs, [mask['segmentation'] for mask in masks]
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_clip_model_and_processor():
print("LOADING CLIP")
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
return model, processor
model, processor = get_clip_model_and_processor()
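# Embed a batch of PIL images with CLIP's vision tower (roughly what CLIPModel.get_image_features does, plus L2 normalization).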
def embed_pil_images(clip_model: CLIPModel, clip_processor: CLIPProcessor, images):
output_attentions = clip_model.config.output_attentions
    output_hidden_states = clip_model.config.output_hidden_states
return_dict = clip_model.config.use_return_dict
pixel_values = clip_processor(images=images, return_tensors="pt", padding=True)['pixel_values']
vision_outputs = clip_model.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
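    # vision_outputs[1] is the pooled output of the vision transformer; project it into the joint image-text space.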
image_embeds = vision_outputs[1]
image_embeds = clip_model.visual_projection(image_embeds)
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
return image_embeds.detach().cpu()
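# Embed a single text query with CLIP's text tower (roughly what CLIPModel.get_text_features does).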
def embed_text(clip_model: CLIPModel, clip_processor: CLIPProcessor, text):
output_attentions = clip_model.config.output_attentions
output_hidden_states = clip_model.config.output_hidden_states
return_dict = clip_model.config.use_return_dict
inputs = clip_processor(text=[text], return_tensors="pt", padding=True)
text_outputs = clip_model.text_model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
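    # text_outputs[1] is the pooled embedding taken at the end-of-sequence token; project it into the joint space.
    # It is not L2-normalized here, which only rescales the logits uniformly and does not change which mask ranks highest.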
text_embeds = text_outputs[1]
text_embeds = clip_model.text_projection(text_embeds)
return text_embeds.detach().cpu()
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_masks_and_embeds(image_bytes):
with torch.no_grad():
masked_imgs, masks = get_masks(image_bytes)
embeds = embed_pil_images(model, processor, masked_imgs)
return masked_imgs, embeds, masks
def get_probs(img_embeds, query_embeds):
    # Scale query/mask similarities by CLIP's learned temperature and softmax over the
    # candidate masks; the query is either a text or a reference image embedding.
    logit_scale = model.logit_scale.exp()
    logits_per_query = torch.matmul(query_embeds, img_embeds.t()) * logit_scale
    probs = logits_per_query.softmax(dim=-1)
    return probs.detach().cpu().numpy()
def get_top_mask(image_bytes, text=None, ref_image=None):
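    # Segment the image with SAM, embed every masked crop with CLIP, embed the query
    # (text or reference image), and return the best-matching masked image and its binary mask.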
masked_imgs, img_embeds, masks = get_masks_and_embeds(image_bytes)
if text is not None:
q_embeds = embed_text(model, processor, text)
elif ref_image is not None:
q_embeds = embed_pil_images(model, processor, [ref_image])
probs = get_probs(img_embeds, q_embeds)
    best = int(np.argmax(probs))
    return masked_imgs[best], masks[best]
if uploaded_file is not None:
image_bytes = uploaded_file.getvalue()
    option = st.selectbox('What kind of reference do you want to search with?', ('Text', 'Image'))
    to_show = Image.open(io.BytesIO(image_bytes)).convert('RGB')
text = ""
ref_image_file = None
    if option == 'Text':
        text = st.text_input('Textual description of an object of interest')
        if text:
            _, mask = get_top_mask(image_bytes, text=text)
elif option == 'Image':
ref_image_file = st.file_uploader("Reference image of an object of interest", type=['png', 'jpg', 'jpeg'])
if ref_image_file:
ref_image_bytes = ref_image_file.getvalue()
ref_image = Image.open(io.BytesIO(ref_image_bytes))
st.image(ref_image, caption='Your reference image')
_, mask = get_top_mask(image_bytes, ref_image=ref_image)
    if (option == 'Text' and text != "") or (option == 'Image' and ref_image_file is not None):
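        # Highlight the chosen mask by adding a red tint (value 100 on the red channel), clipped to the valid [0, 255] range.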
to_show = Image.fromarray(
np.clip(np.array([1, 0, 0]) * mask[:, :, None] * 100 + np.array(to_show, dtype=int), a_min = 0, a_max = 255).astype(np.uint8)
)
st.image(to_show, caption='Input image with highlighted mask')