integrate with transformers (#21)

- integrate with transformers (a845831f8f4651b4b6b3676ad744a0a3c368417f)
- Update README (52584706b45157beda5e5f42670d75dfc29bfed7)
- changing custom pipeline and pinning requirements (2fb3b2ec826f02e10b74e334a6a8678273a57dfc)

Co-authored-by: LAin <[email protected]>

Files changed:
- MyConfig.py +13 -0
- MyPipe.py +73 -0
- README.md +13 -34
- briarmbg.py +8 -7
- config.json +24 -3
- requirements.txt +2 -1
MyConfig.py
ADDED
```python
from transformers import PretrainedConfig
from typing import List

class RMBGConfig(PretrainedConfig):
    model_type = "SegformerForSemanticSegmentation"
    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        **kwargs):
        self.in_ch = in_ch
        self.out_ch = out_ch
        super().__init__(**kwargs)
```
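Once the `auto_map` entry in config.json (below) points `AutoConfig` at `MyConfig.RMBGConfig`, this config loads without importing the file by hand. A minimal sketch of that behavior, not part of this commit:

```python
from transformers import AutoConfig

# trust_remote_code lets transformers import MyConfig.py from the model repo
config = AutoConfig.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
print(config.in_ch, config.out_ch)  # 3 1
```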
MyPipe.py
ADDED
```python
import torch, os
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import numpy as np
from transformers import Pipeline
from skimage import io
from PIL import Image

class RMBGPipe(Pipeline):
    def __init__(self, **kwargs):
        Pipeline.__init__(self, **kwargs)
        # move the model to GPU if available and switch to inference mode
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        # route call-time kwargs to the preprocess/postprocess stages
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "model_input_size" in kwargs:
            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
        if "return_mask" in kwargs:
            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, im_path: str, model_input_size: list = [1024, 1024]):
        # read the image and resize/normalize it for the model
        orig_im = io.imread(im_path)
        orig_im_size = orig_im.shape[0:2]
        image = self.preprocess_image(orig_im, model_input_size).to(self.device)
        inputs = {
            "image": image,
            "orig_im_size": orig_im_size,
            "im_path": im_path
        }
        return inputs

    def _forward(self, inputs):
        result = self.model(inputs.pop("image"))
        inputs["result"] = result
        return inputs

    def postprocess(self, inputs, return_mask: bool = False):
        result = inputs.pop("result")
        orig_im_size = inputs.pop("orig_im_size")
        im_path = inputs.pop("im_path")
        result_image = self.postprocess_image(result[0][0], orig_im_size)
        pil_im = Image.fromarray(result_image)
        if return_mask:
            return pil_im
        # composite the original image over a transparent canvas using the mask
        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        orig_image = Image.open(im_path)
        no_bg_image.paste(orig_image, mask=pil_im)
        return no_bg_image

    # utility functions
    def preprocess_image(self, im: np.ndarray, model_input_size: list = [1024, 1024]) -> torch.Tensor:
        # same as utilities.py with minor modifications
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear").type(torch.uint8)
        image = torch.divide(im_tensor, 255.0)
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
        # resize the prediction back to the original size and rescale to [0, 255]
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array
```
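`_sanitize_parameters` is what routes call-time keyword arguments to the right stage, so callers can override the resize target or ask for the raw mask in a single call. A short usage sketch, not part of this commit (the image path is hypothetical):

```python
from transformers import pipeline

pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)

# model_input_size is forwarded to preprocess(), return_mask to postprocess()
mask = pipe("example_input.jpg", model_input_size=[1024, 1024], return_mask=True)
mask.save("mask.png")  # single-channel PIL image
```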
README.md
CHANGED
````diff
@@ -2,7 +2,7 @@
 license: other
 license_name: bria-rmbg-1.4
 license_link: https://bria.ai/bria-huggingface-model-license-agreement/
-pipeline_tag: image-
+pipeline_tag: image-segmentation
 tags:
 - remove background
 - background
@@ -10,6 +10,7 @@ tags:
 - Pytorch
 - vision
 - legal liability
+- transformers
 
 extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
 extra_gated_fields:
@@ -94,43 +95,21 @@ These modifications significantly improve the model’s accuracy and effectivene
 
 ## Installation
 ```bash
-git clone https://huggingface.co/briaai/RMBG-1.4
-cd RMBG-1.4/
-pip install -r requirements.txt
+wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
 ```
 
 ## Usage
 
+either load the model
 ```python
-from skimage import io
-import torch, os
-from PIL import Image
-from briarmbg import BriaRMBG
-from utilities import preprocess_image, postprocess_image
-
-im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
-
-net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-net.to(device)
-
-# prepare input
-model_input_size = [1024,1024]
-orig_im = io.imread(im_path)
-orig_im_size = orig_im.shape[0:2]
-image = preprocess_image(orig_im, model_input_size).to(device)
-
-# inference
-result=net(image)
-
-# post process
-result_image = postprocess_image(result[0][0], orig_im_size)
+from transformers import AutoModelForImageSegmentation
+model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
+```
 
-# save result
-pil_im = Image.fromarray(result_image)
-no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
-orig_image = Image.open(im_path)
-no_bg_image.paste(orig_image, mask=pil_im)
-no_bg_image.save("example_image_no_bg.png")
+or load the pipeline
+```python
+from transformers import pipeline
+pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
+pillow_mask = pipe("img_path", return_mask=True)  # outputs a pillow mask
+pillow_image = pipe("image_path")  # applies mask on input and returns a pillow image
 ```
````
briarmbg.py
CHANGED
```diff
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from huggingface_hub import PyTorchModelHubMixin
+from transformers import PreTrainedModel
+from .MyConfig import RMBGConfig
 
 class REBNCONV(nn.Module):
     def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
@@ -345,12 +346,12 @@ class myrebnconv(nn.Module):
         return self.rl(self.bn(self.conv(x)))
 
 
-class BriaRMBG(nn.Module, PyTorchModelHubMixin):
-
-    def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
-        super(BriaRMBG,self).__init__()
-        in_ch=config["in_ch"]
-        out_ch=config["out_ch"]
+class BriaRMBG(PreTrainedModel):
+    config_class = RMBGConfig
+    def __init__(self,config):
+        super().__init__(config)
+        in_ch = config.in_ch   # 3
+        out_ch = config.out_ch # 1
         self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
         self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
 
```
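Since `BriaRMBG` now subclasses `PreTrainedModel` with `config_class = RMBGConfig`, `from_pretrained` parses config.json into an `RMBGConfig` and the `in_ch`/`out_ch` values flow into the constructor. A quick sanity check, not part of this commit:

```python
from transformers import AutoModelForImageSegmentation

net = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)

# in_ch/out_ch come from config.json via RMBGConfig
print(net.config.in_ch, net.config.out_ch)  # 3 1
print(net.conv_in)  # Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
```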
config.json
CHANGED
```diff
@@ -1,4 +1,25 @@
 {
-  "in_ch": 3,
-  "out_ch": 1
-}
+  "_name_or_path": "briaai/RMBG-1.4",
+  "architectures": [
+    "BriaRMBG"
+  ],
+  "auto_map": {
+    "AutoConfig": "MyConfig.RMBGConfig",
+    "AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
+  },
+  "custom_pipelines": {
+    "image-segmentation": {
+      "impl": "MyPipe.RMBGPipe",
+      "pt": [
+        "AutoModelForImageSegmentation"
+      ],
+      "tf": [],
+      "type": "image"
+    }
+  },
+  "in_ch": 3,
+  "model_type": "SegformerForSemanticSegmentation",
+  "out_ch": 1,
+  "torch_dtype": "float32",
+  "transformers_version": "4.38.0.dev0"
+}
```
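The `auto_map` block wires the Auto classes to the repo's own modules, and `custom_pipelines` re-points the `image-segmentation` task at `MyPipe.RMBGPipe`; both take effect only when the caller passes `trust_remote_code=True`. A sketch of the resulting behavior, not part of this commit:

```python
from transformers import pipeline

# "impl": "MyPipe.RMBGPipe" makes pipeline() return the custom class
# instead of the stock image-segmentation pipeline
pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
print(type(pipe).__name__)  # RMBGPipe
```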
requirements.txt
CHANGED
```diff
@@ -4,4 +4,5 @@ pillow
 numpy
 typing
 scikit-image
-huggingface_hub
+huggingface_hub
+transformers==4.39.1
```