JeffreyXiang commited on
Commit
98b3116
·
1 Parent(s): a1e3f5f
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -17,6 +17,7 @@ import numpy as np
17
  from PIL import Image
18
  import base64
19
  import io
 
20
  from trellis2.modules.sparse import SparseTensor
21
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
22
  from trellis2.renderers import EnvMap
@@ -279,20 +280,19 @@ def end_session(req: gr.Request):
279
  shutil.rmtree(user_dir)
280
 
281
 
282
- def remove_background(input: Image.Image, user_dir: str) -> Image.Image:
283
- input = input.convert('RGB')
284
- os.makedirs(user_dir, exist_ok=True)
285
- input.save(os.path.join(user_dir, 'input.png'))
286
- output = rmbg_client.predict(handle_file(os.path.join(user_dir, 'input.png')), api_name="/image")[0][0]
287
- output = Image.open(output)
288
- return output
289
 
290
 
291
- def preprocess_image(input: Image.Image, req: gr.Request,) -> Image.Image:
292
  """
293
  Preprocess the input image.
294
  """
295
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
296
  # if has alpha channel, use it directly; otherwise, remove background
297
  has_alpha = False
298
  if input.mode == 'RGBA':
@@ -306,7 +306,7 @@ def preprocess_image(input: Image.Image, req: gr.Request,) -> Image.Image:
306
  if has_alpha:
307
  output = input
308
  else:
309
- output = remove_background(input, user_dir)
310
  output_np = np.array(output)
311
  alpha = output_np[:, :, 3]
312
  bbox = np.argwhere(alpha > 0.8 * 255)
 
17
  from PIL import Image
18
  import base64
19
  import io
20
+ import tempfile
21
  from trellis2.modules.sparse import SparseTensor
22
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
23
  from trellis2.renderers import EnvMap
 
280
  shutil.rmtree(user_dir)
281
 
282
 
283
+ def remove_background(input: Image.Image) -> Image.Image:
284
+ with tempfile.NamedTemporaryFile(suffix='.png') as f:
285
+ input = input.convert('RGB')
286
+ input.save(f.name)
287
+ output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
288
+ output = Image.open(output)
289
+ return output
290
 
291
 
292
+ def preprocess_image(input: Image.Image) -> Image.Image:
293
  """
294
  Preprocess the input image.
295
  """
 
296
  # if has alpha channel, use it directly; otherwise, remove background
297
  has_alpha = False
298
  if input.mode == 'RGBA':
 
306
  if has_alpha:
307
  output = input
308
  else:
309
+ output = remove_background(input)
310
  output_np = np.array(output)
311
  alpha = output_np[:, :, 3]
312
  bbox = np.argwhere(alpha > 0.8 * 255)