import argparse
from functools import partial
import cv2
import requests
import os
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import warnings

import torch

# prepare the environment
os.system("python setup.py build develop --user")
os.system("pip install packaging==21.3")
os.system("pip install gradio")

warnings.filterwarnings("ignore")

import gradio as gr

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T

from huggingface_hub import hf_hub_download, login

# Authenticate with Hugging Face Hub
login(token=os.getenv("HUGGINGFACE_HUB_TOKEN"))

# Use this command for evaluating the Grounding DINO model
config_file = "cfg_odvg.py"
ckpt_repo_id = "Hasanmog/Peft-GroundingDINO"
ckpt_filename = "checkpoint.pth"


def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    _ = model.eval()
    return model


def image_transform_grounding(init_image):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image, _ = transform(init_image, None)  # 3, h, w
    return init_image, image


def image_transform_grounding_for_vis(init_image):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
    ])
    image, _ = transform(init_image, None)  # 3, h, w
    return image


model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename)


def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
    init_image = input_image.convert("RGB")
    original_size = init_image.size

    _, image_tensor = image_transform_grounding(init_image)
    image_pil: Image = image_transform_grounding_for_vis(init_image)

    # run grounding
    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
    annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
    image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))

    return image_with_box


if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()

    block = gr.Blocks().queue()
    with block:
        gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
        gr.Markdown("### Open-World Detection with Grounding DINO")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type="pil")
                grounding_caption = gr.Textbox(label="Detection Prompt")
                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )

            with gr.Column():
                gallery = gr.outputs.Image(
                    type="pil",
                    # label="grounding results"
                ).style(full_width=True, full_height=True)
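                # Note: gr.outputs.Image, .style(), and gr.Image(source=...) are the legacy
                # Gradio 3.x API, so this demo assumes a Gradio 3.x install; on Gradio 4.x
                # the rough equivalent would be gr.Image(type="pil") (an assumption here,
                # not something this script pins).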
                # gallery = gr.Gallery(label="Generated images", show_label=False).style(
                #     grid=[1], height="auto", container=True, full_width=True, full_height=True)

        run_button.click(fn=run_grounding, inputs=[
            input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])

    block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)


# The code below is a standalone command-line demo with its own imports and __main__ guard.

import os
import argparse

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

# please make sure https://github.com/IDEA-Research/GroundingDINO is installed correctly.
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.vl_utils import create_positive_map_from_span


def plot_boxes_to_image(image_pil, tgt):
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # draw boxes and masks
    for box, label in zip(boxes, labels):
        # from 0..1 to 0..W, 0..H
        box = box * torch.Tensor([W, H, W, H])
        # from xywh to xyxy
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # random color
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        # draw
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # draw.text((x0, y0), str(label), fill=color)

        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)
        # bbox = draw.textbbox((x0, y0), str(label))
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), str(label), fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask


def load_image(image_path):
    # load image
    image_pil = Image.open(image_path).convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image_pil, image


def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    args = SLConfig.fromfile(model_config_path)
    args.device = "cuda" if not cpu_only else "cpu"
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold=None,
                         with_logits=True, cpu_only=False, token_spans=None):
    assert text_threshold is not None or token_spans is not None, \
        "text_threshold and token_spans should not be None at the same time!"
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."
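    # Grounding DINO uses "." as the phrase separator in its captions, and the text
    # branch expects the prompt to be "."-terminated, hence the trailing period
    # appended above.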
device = "cuda" if not cpu_only else "cpu" model = model.to(device) image = image.to(device) with torch.no_grad(): outputs = model(image[None], captions=[caption]) logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256) boxes = outputs["pred_boxes"][0] # (nq, 4) # filter output if token_spans is None: logits_filt = logits.cpu().clone() boxes_filt = boxes.cpu().clone() filt_mask = logits_filt.max(dim=1)[0] > box_threshold logits_filt = logits_filt[filt_mask] # num_filt, 256 boxes_filt = boxes_filt[filt_mask] # num_filt, 4 # get phrase tokenlizer = model.tokenizer tokenized = tokenlizer(caption) # build pred pred_phrases = [] for logit, box in zip(logits_filt, boxes_filt): pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) if with_logits: pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") else: pred_phrases.append(pred_phrase) else: # given-phrase mode positive_maps = create_positive_map_from_span( model.tokenizer(text_prompt), token_span=token_spans ).to(image.device) # n_phrase, 256 logits_for_phrases = positive_maps @ logits.T # n_phrase, nq all_logits = [] all_phrases = [] all_boxes = [] for (token_span, logit_phr) in zip(token_spans, logits_for_phrases): # get phrase phrase = ' '.join([caption[_s:_e] for (_s, _e) in token_span]) # get mask filt_mask = logit_phr > box_threshold # filt box all_boxes.append(boxes[filt_mask]) # filt logits all_logits.append(logit_phr[filt_mask]) if with_logits: logit_phr_num = logit_phr[filt_mask] all_phrases.extend([phrase + f"({str(logit.item())[:4]})" for logit in logit_phr_num]) else: all_phrases.extend([phrase for _ in range(len(filt_mask))]) boxes_filt = torch.cat(all_boxes, dim=0).cpu() pred_phrases = all_phrases return boxes_filt, pred_phrases if __name__ == "__main__": parser = argparse.ArgumentParser("Grounding DINO example", add_help=True) parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file") parser.add_argument( "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file" ) parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file") parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt") parser.add_argument( "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory" ) parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold") parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold") parser.add_argument("--token_spans", type=str, default=None, help= "The positions of start and end positions of phrases of interest. \ For example, a caption is 'a cat and a dog', \ if you would like to detect 'cat', the token_spans should be '[[[2, 5]], ]', since 'a cat and a dog'[2:5] is 'cat'. \ if you would like to detect 'a cat', the token_spans should be '[[[0, 1], [2, 5]], ]', since 'a cat and a dog'[0:1] is 'a', and 'a cat and a dog'[2:5] is 'cat'. 
                        ")

    parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
    args = parser.parse_args()

    # cfg
    config_file = args.config_file  # change the path of the model config file
    checkpoint_path = args.checkpoint_path  # change the path of the model
    image_path = args.image_path
    text_prompt = args.text_prompt
    output_dir = args.output_dir
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    token_spans = args.token_spans

    # make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # set the text_threshold to None if token_spans is set.
    if token_spans is not None:
        text_threshold = None
        print("Using token_spans. Set the text_threshold to None.")

    # run model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold,
        cpu_only=args.cpu_only,
        token_spans=eval(f"{token_spans}"),  # the CLI passes token_spans as a string; turn it into the expected nested list
    )

    # visualize pred
    size = image_pil.size
    pred_dict = {
        "boxes": boxes_filt,
        "size": [size[1], size[0]],  # H, W
        "labels": pred_phrases,
    }
    image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
    save_path = os.path.join(output_dir, "pred.jpg")
    image_with_box.save(save_path)
    print(f"\n======================\n{save_path} saved.\nThe program runs successfully!")
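
# Example invocations (the script name, image path, and prompt below are placeholders,
# not values taken from this repository):
#
#   python inference_demo.py -c cfg_odvg.py -p checkpoint.pth -i image.jpg \
#       -t "a cat and a dog" -o outputs/ --box_threshold 0.3 --cpu-only
#
# Phrase-span mode, detecting only "cat" via its character span in the prompt
# (text_threshold is ignored in this mode):
#
#   python inference_demo.py -c cfg_odvg.py -p checkpoint.pth -i image.jpg \
#       -t "a cat and a dog" -o outputs/ --token_spans "[[[2, 5]], ]" --cpu-only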