import io
import os

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from huggingface_hub import hf_hub_download

# Force CPU execution.
device = 'cpu'
os.environ['CUDA_VISIBLE_DEVICES'] = ''

st.markdown("## Finder!")
st.markdown("### Get a segmentation mask from a text description or a reference image.")
st.markdown("DISCLAIMER: this takes much longer because it runs on the CPU :(")
st.markdown("", unsafe_allow_html=True)
# ^-- you can show the user text, images, a limited subset of HTML - everything works just like in Jupyter

uploaded_file = st.file_uploader(
    "Provide an image where you want to find something",
    type=['png', 'jpg', 'jpeg']
)


@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_sam():
    """Load the SAM checkpoint and build an automatic mask generator."""
    print("LOADING SAM")
    sam = sam_model_registry["default"](
        checkpoint=hf_hub_download(repo_id="n0r9st/segment-anything", filename="sam_vit_h_4b8939.pth")
    )
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    # predictor = SamPredictor(sam)
    return mask_generator


mask_generator = get_sam()


@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_masks(image_bytes):
    """Run SAM on the uploaded image; return the masked-out crops and the raw binary masks."""
    image = Image.open(io.BytesIO(image_bytes))
    masks = mask_generator.generate(np.array(image))
    masked_imgs = [
        Image.fromarray(np.array(image) * mask['segmentation'][..., None])
        for mask in masks
    ]
    return masked_imgs, [mask['segmentation'] for mask in masks]


@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_clip_model_and_processor():
    print("LOADING CLIP")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    return model, processor


model, processor = get_clip_model_and_processor()


def embed_pil_images(clip_model: CLIPModel, clip_processor: CLIPProcessor, images):
    """Embed PIL images with the CLIP vision tower and return L2-normalized features."""
    output_attentions = clip_model.config.output_attentions
    output_hidden_states = clip_model.config.output_hidden_states
    return_dict = clip_model.config.use_return_dict

    pixel_values = clip_processor(images=images, return_tensors="pt", padding=True)['pixel_values']
    vision_outputs = clip_model.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    image_embeds = vision_outputs[1]
    image_embeds = clip_model.visual_projection(image_embeds)
    # normalized features
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    return image_embeds.detach().cpu()


def embed_text(clip_model: CLIPModel, clip_processor: CLIPProcessor, text):
    """Embed a text query with the CLIP text tower."""
    output_attentions = clip_model.config.output_attentions
    output_hidden_states = clip_model.config.output_hidden_states
    return_dict = clip_model.config.use_return_dict

    inputs = clip_processor(text=[text], return_tensors="pt", padding=True)
    text_outputs = clip_model.text_model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    text_embeds = text_outputs[1]
    text_embeds = clip_model.text_projection(text_embeds)
    return text_embeds.detach().cpu()


@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None}, allow_output_mutation=True)
def get_masks_and_embeds(image_bytes):
    with torch.no_grad():
        masked_imgs, masks = get_masks(image_bytes)
        embeds = embed_pil_images(model, processor, masked_imgs)
    return masked_imgs, embeds, masks


def get_probs(img_embeds, text_embeds):
    """Score every masked crop against the query embedding (text or reference image)."""
    logit_scale = model.logit_scale.exp()
    logits_per_image = torch.matmul(text_embeds, img_embeds.t()) * logit_scale
    probs = logits_per_image.softmax(dim=-1)
    return probs.detach().cpu().numpy()


def get_top_mask(image_bytes, text=None, ref_image=None):
    """Return the masked crop and binary mask that best match the query."""
    masked_imgs, img_embeds, masks = get_masks_and_embeds(image_bytes)
    if text is not None:
        q_embeds = embed_text(model, processor, text)
    elif ref_image is not None:
        q_embeds = embed_pil_images(model, processor, [ref_image])
    probs = get_probs(img_embeds, q_embeds)
    return masked_imgs[np.argmax(probs)], masks[np.argmax(probs)]


if uploaded_file is not None:
    image_bytes = uploaded_file.getvalue()
    option = st.selectbox('Find a segmentation mask by what mode of reference?', ('Text', 'Image'))
    to_show = Image.open(io.BytesIO(uploaded_file.getvalue()))
    text = ""
    ref_image_file = None

    if option == 'Text':
        text = st.text_input('Textual description of an object of interest')
        _, mask = get_top_mask(image_bytes, text=text)
    elif option == 'Image':
        ref_image_file = st.file_uploader("Reference image of an object of interest", type=['png', 'jpg', 'jpeg'])
        if ref_image_file:
            ref_image_bytes = ref_image_file.getvalue()
            ref_image = Image.open(io.BytesIO(ref_image_bytes))
            st.image(ref_image, caption='Your reference image')
            _, mask = get_top_mask(image_bytes, ref_image=ref_image)

    if (option == 'Text' and text != "") or (option == 'Image' and ref_image_file is not None):
        # Overlay the winning mask on the original image as a red tint.
        to_show = Image.fromarray(
            np.clip(
                np.array([1, 0, 0]) * mask[:, :, None] * 100 + np.array(to_show, dtype=int),
                a_min=0, a_max=255,
            ).astype(np.uint8)
        )
    st.image(to_show, caption='Input image with highlighted mask')
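# To try this locally (assuming the script is saved as app.py -- the filename is only an
# example -- and that streamlit, torch, transformers, segment_anything, huggingface_hub,
# numpy, and Pillow are installed), launch it with:
#     streamlit run app.py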