import io
import os

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from huggingface_hub import hf_hub_download

os.environ['CUDA_VISIBLE_DEVICES'] = ''
device = 'cpu'
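# Pipeline overview (what the code below does): SAM proposes candidate segmentation
# masks for the uploaded image, CLIP embeds each masked-out crop and the query
# (a text description or a reference image), and the mask whose embedding is most
# similar to the query is highlighted on top of the original image.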
st.markdown("## Finder!")
st.markdown("### Get a segmentation mask from a text description or a reference image.")
st.markdown("DISCLAIMER: this takes much longer to run because it is on CPU :(")
st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
# ^-- you can show the user text, images, and a limited subset of HTML, just like in Jupyter
uploaded_file = st.file_uploader(
    "Provide an image where you want to find something",
    type=['png', 'jpg', 'jpeg']
)
def get_sam():
    print("LOADING SAM")
    sam = sam_model_registry["default"](checkpoint=hf_hub_download(repo_id="n0r9st/segment-anything", filename="sam_vit_h_4b8939.pth"))
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    # predictor = SamPredictor(sam)
    return mask_generator


mask_generator = get_sam()
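# SamAutomaticMaskGenerator.generate() returns a list of dicts; only the boolean
# 'segmentation' array (H x W) is used below. get_masks() zeroes out everything
# outside each mask so that CLIP sees one candidate object per image.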
def get_masks(image_bytes):
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')  # SAM expects an RGB array
    masks = mask_generator.generate(np.array(image))
    masked_imgs = [Image.fromarray(np.array(image) * mask['segmentation'][..., None]) for mask in masks]
    return masked_imgs, [mask['segmentation'] for mask in masks]
def get_clip_model_and_processor():
    print("LOADING CLIP")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    return model, processor


model, processor = get_clip_model_and_processor()
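# NOTE (suggestion, not part of the original app): both loaders above run on every
# Streamlit rerun. Wrapping get_sam() and get_clip_model_and_processor() with
# @st.cache_resource would keep the models in memory across reruns and avoid
# repeated downloads/initialisation on CPU.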
def embed_pil_images(clip_model: CLIPModel, clip_processor: CLIPProcessor, images):
    output_attentions = clip_model.config.output_attentions
    output_hidden_states = clip_model.config.output_hidden_states
    return_dict = clip_model.config.use_return_dict
    pixel_values = clip_processor(images=images, return_tensors="pt", padding=True)['pixel_values']
    vision_outputs = clip_model.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    image_embeds = vision_outputs[1]
    image_embeds = clip_model.visual_projection(image_embeds)
    # normalized features
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    return image_embeds.detach().cpu()
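# embed_pil_images() mirrors what CLIPModel.get_image_features() does internally
# (pooled output of the vision tower projected by visual_projection), plus the L2
# normalisation that CLIPModel.forward() applies before computing similarities.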
def embed_text(clip_model: CLIPModel, clip_processor: CLIPProcessor, text):
    output_attentions = clip_model.config.output_attentions
    output_hidden_states = clip_model.config.output_hidden_states
    return_dict = clip_model.config.use_return_dict
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True)
    text_outputs = clip_model.text_model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    text_embeds = text_outputs[1]
    text_embeds = clip_model.text_projection(text_embeds)
    return text_embeds.detach().cpu()
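# Note: unlike the image embeddings, the text embedding is not L2-normalised here.
# That only rescales all of its logits by the same positive factor, so the argmax
# over masks used downstream is unaffected.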
def get_masks_and_embeds(image_bytes):
    with torch.no_grad():
        masked_imgs, masks = get_masks(image_bytes)
        embeds = embed_pil_images(model, processor, masked_imgs)
    return masked_imgs, embeds, masks
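# get_probs() reproduces the CLIP similarity head: scaled dot products between the
# query embedding and every masked-image embedding, followed by a softmax over the
# candidate masks.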
def get_probs(img_embeds, text_embeds):
    logit_scale = model.logit_scale.exp()
    logits_per_image = torch.matmul(text_embeds, img_embeds.t()) * logit_scale
    probs = logits_per_image.softmax(dim=-1)
    return probs.detach().cpu().numpy()
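# get_top_mask() ties the pieces together: segment the uploaded image, embed the
# query (text or reference image), and return the masked image and boolean mask
# with the highest CLIP similarity.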
def get_top_mask(image_bytes, text=None, ref_image=None):
    masked_imgs, img_embeds, masks = get_masks_and_embeds(image_bytes)
    if text is not None:
        q_embeds = embed_text(model, processor, text)
    elif ref_image is not None:
        q_embeds = embed_pil_images(model, processor, [ref_image])
    probs = get_probs(img_embeds, q_embeds)
    return masked_imgs[np.argmax(probs)], masks[np.argmax(probs)]
if uploaded_file is not None:
    image_bytes = uploaded_file.getvalue()
    option = st.selectbox('Find a segmentation mask by what mode of reference?', ('Text', 'Image'))
    to_show = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    text = ""
    ref_image_file = None
    if option == 'Text':
        text = st.text_input('Textual description of an object of interest')
        if text != "":
            _, mask = get_top_mask(image_bytes, text=text)
    elif option == 'Image':
        ref_image_file = st.file_uploader("Reference image of an object of interest", type=['png', 'jpg', 'jpeg'])
        if ref_image_file:
            ref_image_bytes = ref_image_file.getvalue()
            ref_image = Image.open(io.BytesIO(ref_image_bytes))
            st.image(ref_image, caption='Your reference image')
            _, mask = get_top_mask(image_bytes, ref_image=ref_image)
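    # Rendering: tint the winning mask red by adding a scaled [1, 0, 0] offset to the
    # original pixels and clipping back to the valid [0, 255] range.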
    if (option == 'Text' and text != "") or (option == 'Image' and ref_image_file is not None):
        to_show = Image.fromarray(
            np.clip(np.array([1, 0, 0]) * mask[:, :, None] * 100 + np.array(to_show, dtype=int), a_min=0, a_max=255).astype(np.uint8)
        )
    st.image(to_show, caption='Input image with highlighted mask')
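# To try this locally (assuming the script is saved as app.py):
#   streamlit run app.py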