Add the option to use two networks (ResNet-50 or ResNet-101)
- app.py +10 -6
- resnet101/checkpoint0024.pth +3 -0
- checkpoint0024.pth → resnet50/checkpoint0024.pth +0 -0
- test.py +3 -3
app.py
CHANGED
@@ -13,8 +13,8 @@ from models.preprocessing import *
 from models.misc import nested_tensor_from_tensor_list
 
 
-model = create_letr()
-
+model = create_letr('resnet50/checkpoint0024.pth')
+model101 = create_letr('resnet101/checkpoint0024.pth')
 # PREPARE PREPROCESSING
 # transform_test = transforms.Compose([
 #     transforms.Resize((test_size)),
@@ -38,7 +38,7 @@ normalize_1100 = Compose([
 ])
 
 
-def predict(inp, size):
+def predict(inp, size, model_name):
     image = Image.fromarray(inp.astype('uint8'), 'RGB')
     h, w = image.height, image.width
     orig_size = torch.as_tensor([int(h), int(w)])
@@ -52,7 +52,10 @@ def predict(inp, size)
     inputs = nested_tensor_from_tensor_list([img])
 
     with torch.no_grad():
-        outputs = model(inputs)[0]
+        if model_name == 'resnet101':
+            outputs = model101(inputs)[0]
+        else:
+            outputs = model(inputs)[0]
 
     draw_fig(image, outputs, orig_size)
 
@@ -62,6 +65,7 @@ def predict(inp, size)
 inputs = [
     gr.inputs.Image(),
     gr.inputs.Radio(["256", "512", "1100"]),
+    gr.inputs.Radio(["resnet50", "resnet101"]),
 ]
 outputs = gr.outputs.Image()
 gr.Interface(
@@ -69,8 +73,8 @@ gr.Interface(
     inputs=inputs,
     outputs=outputs,
     examples=[
-        ["demo.png", '256'],
-        ["tappeto-per-calibrazione.jpg", '256']
+        ["demo.png", '256', "resnet50"],
+        ["tappeto-per-calibrazione.jpg", '256', "resnet50"]
     ],
     title="LETR",
     description="Model for line detection..."
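The new predict() chooses between the two preloaded models with an if/else on model_name, matching the third Radio input. For reference, here is the same dispatch written as a small standalone helper. create_letr, the checkpoint paths, and nested_tensor_from_tensor_list come from this commit; the MODELS dict and run_letr wrapper are only an illustrative sketch, not part of app.py.

# Illustrative sketch only: the backbone dispatch from predict(), factored
# into a standalone helper. Assumes the repo layout introduced by this
# commit (test.py defining create_letr, checkpoints under resnet50/ and
# resnet101/); MODELS and run_letr are not part of app.py.
import torch

from test import create_letr
from models.misc import nested_tensor_from_tensor_list

MODELS = {
    'resnet50': create_letr('resnet50/checkpoint0024.pth'),
    'resnet101': create_letr('resnet101/checkpoint0024.pth'),
}

def run_letr(img_tensor, model_name='resnet50'):
    """Run one preprocessed image tensor through the selected backbone."""
    inputs = nested_tensor_from_tensor_list([img_tensor])
    with torch.no_grad():
        # Same fallback behaviour as app.py: anything other than
        # 'resnet101' falls through to the resnet50 model.
        outputs = MODELS.get(model_name, MODELS['resnet50'])(inputs)[0]
    return outputs

A dict lookup like this keeps the Gradio radio choices and the loaded models in one place if more backbones are added later.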
resnet101/checkpoint0024.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab07a2ddf5e088540941c755e2cccc66e019dcf94b3ef488bab25f5a76490bb9
+size 457215616
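The new checkpoint is stored through Git LFS, so the repository only tracks this pointer (a sha256 oid and a byte size of roughly 457 MB) while the weights themselves live in LFS storage. A minimal sketch for checking that a locally downloaded file matches the pointer; the path is the one added by this commit, everything else is standard library.

# Verify the downloaded resnet101 checkpoint against the LFS pointer above.
# The oid is a sha256 of the file contents; size is in bytes.
import hashlib
import os

path = 'resnet101/checkpoint0024.pth'
expected_oid = 'ab07a2ddf5e088540941c755e2cccc66e019dcf94b3ef488bab25f5a76490bb9'
expected_size = 457215616

sha = hashlib.sha256()
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha.update(chunk)

assert os.path.getsize(path) == expected_size, 'size mismatch'
assert sha.hexdigest() == expected_oid, 'sha256 mismatch'
print('checkpoint matches the LFS pointer')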
checkpoint0024.pth → resnet50/checkpoint0024.pth
RENAMED
File without changes
test.py
CHANGED
@@ -7,9 +7,9 @@ from models.letr import build
 from models.misc import nested_tensor_from_tensor_list
 from models.preprocessing import Compose, ToTensor, Resize, Normalize
 
-def create_letr():
+def create_letr(path):
     # obtain checkpoints
-    checkpoint = torch.load('checkpoint0024.pth', map_location='cpu')
+    checkpoint = torch.load(path, map_location='cpu')
 
     # load model
     args = checkpoint['args']
@@ -44,7 +44,7 @@ def draw_fig(image, outputs, orig_size):
         draw.line((x1, y1, x2, y2), fill=500)
 
 if __name__ == '__main__':
-    model = create_letr()
+    model = create_letr('resnet50/checkpoint0024.pth')
 
     test_size = 256
     normalize = Compose([
|