JeffreyXiang committed
Commit 1d2bd93 · 1 Parent(s): 2d8d8e7
app.py CHANGED
@@ -34,20 +34,46 @@ def start_session(req: gr.Request):
 def end_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     shutil.rmtree(user_dir)
+
+
+@spaces.GPU()
+def remove_background(input: Image.Image) -> Image.Image:
+    input = input.convert('RGB')
+    output = pipeline.rembg_model(input)
+    return output
 
 
-def preprocess_image(image: Image.Image) -> Image.Image:
+def preprocess_image(input: Image.Image) -> Image.Image:
     """
     Preprocess the input image.
-
-    Args:
-        image (Image.Image): The input image.
-
-    Returns:
-        Image.Image: The preprocessed image.
     """
-    processed_image = pipeline.preprocess_image(image)
-    return processed_image
+    # if has alpha channel, use it directly; otherwise, remove background
+    has_alpha = False
+    if input.mode == 'RGBA':
+        alpha = np.array(input)[:, :, 3]
+        if not np.all(alpha == 255):
+            has_alpha = True
+    max_size = max(input.size)
+    scale = min(1, 1024 / max_size)
+    if scale < 1:
+        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+    if has_alpha:
+        output = input
+    else:
+        output = remove_background(input)
+    output_np = np.array(output)
+    alpha = output_np[:, :, 3]
+    bbox = np.argwhere(alpha > 0.8 * 255)
+    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+    size = int(size * 1)
+    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+    output = output.crop(bbox)  # type: ignore
+    output = np.array(output).astype(np.float32) / 255
+    output = output[:, :, :3] * output[:, :, 3:4]
+    output = Image.fromarray((output * 255).astype(np.uint8))
+    return output
 
 
 def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
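For reference, the new preprocess_image recenters the subject by thresholding the alpha channel at 80% opacity and cropping to a square around the surviving pixels. Below is a minimal, self-contained sketch of just that cropping step; the helper name crop_to_foreground and the synthetic test image are illustrative and not part of the commit.

import numpy as np
from PIL import Image

def crop_to_foreground(rgba: Image.Image, pad: float = 1.0) -> Image.Image:
    # Bounding box of pixels whose alpha exceeds 80% opacity, as in the committed code.
    alpha = np.array(rgba)[:, :, 3]
    ys, xs = np.nonzero(alpha > 0.8 * 255)
    left, top, right, bottom = xs.min(), ys.min(), xs.max(), ys.max()
    cx, cy = (left + right) / 2, (top + bottom) / 2
    size = int(max(right - left, bottom - top) * pad)
    box = (int(cx - size // 2), int(cy - size // 2), int(cx + size // 2), int(cy + size // 2))
    return rgba.crop(box)

# Synthetic check: a 256x256 transparent canvas with an opaque 65x65 red square.
canvas = np.zeros((256, 256, 4), dtype=np.uint8)
canvas[96:161, 96:161] = [255, 0, 0, 255]
cropped = crop_to_foreground(Image.fromarray(canvas))
print(cropped.size)  # (64, 64): the crop is centered on the opaque square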
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -11,13 +11,6 @@ from ..modules import image_feature_extractor
 from ..representations import Mesh, MeshWithVoxel
 
 
-@spaces.GPU()
-def remove_background(rembg_model, input: Image.Image) -> Image.Image:
-    input = input.convert('RGB')
-    output = rembg_model(input)
-    return output
-
-
 class Trellis2ImageTo3DPipeline(Pipeline):
     """
     Pipeline for inferring Trellis2 image-to-3D models.
@@ -139,9 +132,10 @@ class Trellis2ImageTo3DPipeline(Pipeline):
         if has_alpha:
             output = input
         else:
+            input = input.convert('RGB')
             if self.low_vram:
                 self.rembg_model.to(self.device)
-            output = remove_background(self.rembg_model, input)
+            output = self.rembg_model(input)
             if self.low_vram:
                 self.rembg_model.cpu()
         output_np = np.array(output)
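This refactor drops the module-level @spaces.GPU() helper from the pipeline (Spaces-specific GPU scheduling now lives in app.py) and calls self.rembg_model directly, while keeping the existing low_vram behavior of moving the model onto the device only for the call. Below is a generic sketch of that move-in/move-out pattern in PyTorch; the OnDemandModule wrapper is illustrative and not the pipeline's actual API.

import torch

class OnDemandModule:
    # Illustrative wrapper: keep a model on CPU and move it to the GPU only for a call,
    # mirroring the low_vram branch in Trellis2ImageTo3DPipeline.preprocess_image.
    def __init__(self, model: torch.nn.Module, device: str = 'cuda', low_vram: bool = True):
        self.model = model
        self.device = device
        self.low_vram = low_vram

    @torch.no_grad()
    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        if self.low_vram:
            self.model.to(self.device)  # page weights in just for this inference
        try:
            return self.model(batch.to(self.device)).cpu()
        finally:
            if self.low_vram:
                self.model.cpu()  # release GPU memory for the heavier 3D stages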