Spaces:
				
			
			
	
			
			
		Running on Zero
	
	
	
			
			
	
	
	
	
		
		
		Running on Zero
	Commit · cd41f5f
	1 Parent(s): 4cd032d
								
Manage cache using Gradio
Browse files
    	
        app.py
    CHANGED
    
    | @@ -3,6 +3,7 @@ import spaces | |
| 3 | 
             
            from gradio_litmodel3d import LitModel3D
         | 
| 4 |  | 
| 5 | 
             
            import os
         | 
|  | |
| 6 | 
             
            os.environ['SPCONV_ALGO'] = 'native'
         | 
| 7 | 
             
            from typing import *
         | 
| 8 | 
             
            import torch
         | 
| @@ -17,11 +18,22 @@ from trellis.utils import render_utils, postprocessing_utils | |
| 17 |  | 
| 18 |  | 
| 19 | 
             
            MAX_SEED = np.iinfo(np.int32).max
         | 
| 20 | 
            -
            TMP_DIR =  | 
| 21 | 
            -
             | 
| 22 | 
             
            os.makedirs(TMP_DIR, exist_ok=True)
         | 
| 23 |  | 
| 24 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 25 | 
             
            def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
         | 
| 26 | 
             
                """
         | 
| 27 | 
             
                Preprocess the input image.
         | 
| @@ -33,10 +45,8 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]: | |
| 33 | 
             
                    str: uuid of the trial.
         | 
| 34 | 
             
                    Image.Image: The preprocessed image.
         | 
| 35 | 
             
                """
         | 
| 36 | 
            -
                trial_id = str(uuid.uuid4())
         | 
| 37 | 
             
                processed_image = pipeline.preprocess_image(image)
         | 
| 38 | 
            -
                processed_image | 
| 39 | 
            -
                return trial_id, processed_image
         | 
| 40 |  | 
| 41 |  | 
| 42 | 
             
            def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
         | 
| @@ -80,15 +90,29 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]: | |
| 80 | 
             
                return gs, mesh, state['trial_id']
         | 
| 81 |  | 
| 82 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 83 | 
             
            @spaces.GPU
         | 
| 84 | 
            -
            def image_to_3d( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 85 | 
             
                """
         | 
| 86 | 
             
                Convert an image to a 3D model.
         | 
| 87 |  | 
| 88 | 
             
                Args:
         | 
| 89 | 
            -
                     | 
| 90 | 
             
                    seed (int): The random seed.
         | 
| 91 | 
            -
                    randomize_seed (bool): Whether to randomize the seed.
         | 
| 92 | 
             
                    ss_guidance_strength (float): The guidance strength for sparse structure generation.
         | 
| 93 | 
             
                    ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
         | 
| 94 | 
             
                    slat_guidance_strength (float): The guidance strength for structured latent generation.
         | 
| @@ -98,10 +122,9 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre | |
| 98 | 
             
                    dict: The information of the generated 3D model.
         | 
| 99 | 
             
                    str: The path to the video of the 3D model.
         | 
| 100 | 
             
                """
         | 
| 101 | 
            -
                 | 
| 102 | 
            -
                    seed = np.random.randint(0, MAX_SEED)
         | 
| 103 | 
             
                outputs = pipeline.run(
         | 
| 104 | 
            -
                     | 
| 105 | 
             
                    seed=seed,
         | 
| 106 | 
             
                    formats=["gaussian", "mesh"],
         | 
| 107 | 
             
                    preprocess_image=False,
         | 
| @@ -118,15 +141,20 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre | |
| 118 | 
             
                video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
         | 
| 119 | 
             
                video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
         | 
| 120 | 
             
                trial_id = uuid.uuid4()
         | 
| 121 | 
            -
                video_path = f"{ | 
| 122 | 
            -
                os.makedirs(os.path.dirname(video_path), exist_ok=True)
         | 
| 123 | 
             
                imageio.mimsave(video_path, video, fps=15)
         | 
| 124 | 
             
                state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
         | 
|  | |
| 125 | 
             
                return state, video_path
         | 
| 126 |  | 
| 127 |  | 
| 128 | 
             
            @spaces.GPU
         | 
| 129 | 
            -
            def extract_glb( | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 130 | 
             
                """
         | 
| 131 | 
             
                Extract a GLB file from the 3D model.
         | 
| 132 |  | 
| @@ -138,22 +166,16 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s | |
| 138 | 
             
                Returns:
         | 
| 139 | 
             
                    str: The path to the extracted GLB file.
         | 
| 140 | 
             
                """
         | 
|  | |
| 141 | 
             
                gs, mesh, trial_id = unpack_state(state)
         | 
| 142 | 
             
                glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
         | 
| 143 | 
            -
                glb_path = f"{ | 
| 144 | 
             
                glb.export(glb_path)
         | 
|  | |
| 145 | 
             
                return glb_path, glb_path
         | 
| 146 |  | 
| 147 |  | 
| 148 | 
            -
             | 
| 149 | 
            -
                return gr.Button(interactive=True)
         | 
| 150 | 
            -
             | 
| 151 | 
            -
             | 
| 152 | 
            -
            def deactivate_button() -> gr.Button:
         | 
| 153 | 
            -
                return gr.Button(interactive=False)
         | 
| 154 | 
            -
             | 
| 155 | 
            -
             | 
| 156 | 
            -
            with gr.Blocks() as demo:
         | 
| 157 | 
             
                gr.Markdown("""
         | 
| 158 | 
             
                ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
         | 
| 159 | 
             
                * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
         | 
| @@ -162,7 +184,7 @@ with gr.Blocks() as demo: | |
| 162 |  | 
| 163 | 
             
                with gr.Row():
         | 
| 164 | 
             
                    with gr.Column():
         | 
| 165 | 
            -
                        image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
         | 
| 166 |  | 
| 167 | 
             
                        with gr.Accordion(label="Generation Settings", open=False):
         | 
| 168 | 
             
                            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
         | 
| @@ -189,7 +211,6 @@ with gr.Blocks() as demo: | |
| 189 | 
             
                        model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
         | 
| 190 | 
             
                        download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
         | 
| 191 |  | 
| 192 | 
            -
                trial_id = gr.Textbox(visible=False)
         | 
| 193 | 
             
                output_buf = gr.State()
         | 
| 194 |  | 
| 195 | 
             
                # Example images at the bottom of the page
         | 
| @@ -201,33 +222,36 @@ with gr.Blocks() as demo: | |
| 201 | 
             
                        ],
         | 
| 202 | 
             
                        inputs=[image_prompt],
         | 
| 203 | 
             
                        fn=preprocess_image,
         | 
| 204 | 
            -
                        outputs=[ | 
| 205 | 
             
                        run_on_click=True,
         | 
| 206 | 
             
                        examples_per_page=64,
         | 
| 207 | 
             
                    )
         | 
| 208 |  | 
| 209 | 
             
                # Handlers
         | 
|  | |
|  | |
|  | |
| 210 | 
             
                image_prompt.upload(
         | 
| 211 | 
             
                    preprocess_image,
         | 
| 212 | 
             
                    inputs=[image_prompt],
         | 
| 213 | 
            -
                    outputs=[ | 
| 214 | 
            -
                )
         | 
| 215 | 
            -
                image_prompt.clear(
         | 
| 216 | 
            -
                    lambda: '',
         | 
| 217 | 
            -
                    outputs=[trial_id],
         | 
| 218 | 
             
                )
         | 
| 219 |  | 
| 220 | 
             
                generate_btn.click(
         | 
|  | |
|  | |
|  | |
|  | |
| 221 | 
             
                    image_to_3d,
         | 
| 222 | 
            -
                    inputs=[ | 
| 223 | 
             
                    outputs=[output_buf, video_output],
         | 
| 224 | 
             
                ).then(
         | 
| 225 | 
            -
                     | 
| 226 | 
             
                    outputs=[extract_glb_btn],
         | 
| 227 | 
             
                )
         | 
| 228 |  | 
| 229 | 
             
                video_output.clear(
         | 
| 230 | 
            -
                     | 
| 231 | 
             
                    outputs=[extract_glb_btn],
         | 
| 232 | 
             
                )
         | 
| 233 |  | 
| @@ -236,33 +260,16 @@ with gr.Blocks() as demo: | |
| 236 | 
             
                    inputs=[output_buf, mesh_simplify, texture_size],
         | 
| 237 | 
             
                    outputs=[model_output, download_glb],
         | 
| 238 | 
             
                ).then(
         | 
| 239 | 
            -
                     | 
| 240 | 
             
                    outputs=[download_glb],
         | 
| 241 | 
             
                )
         | 
| 242 |  | 
| 243 | 
             
                model_output.clear(
         | 
| 244 | 
            -
                     | 
| 245 | 
             
                    outputs=[download_glb],
         | 
| 246 | 
             
                )
         | 
| 247 |  | 
| 248 |  | 
| 249 | 
            -
            # Cleans up the temporary directory every 10 minutes
         | 
| 250 | 
            -
            import threading
         | 
| 251 | 
            -
            import time
         | 
| 252 | 
            -
             | 
| 253 | 
            -
            def cleanup_tmp_dir():
         | 
| 254 | 
            -
                while True:
         | 
| 255 | 
            -
                    if os.path.exists(TMP_DIR):
         | 
| 256 | 
            -
                        for file in os.listdir(TMP_DIR):
         | 
| 257 | 
            -
                            # remove files older than 10 minutes
         | 
| 258 | 
            -
                            if time.time() - os.path.getmtime(os.path.join(TMP_DIR, file)) > 600:
         | 
| 259 | 
            -
                                os.remove(os.path.join(TMP_DIR, file))
         | 
| 260 | 
            -
                    time.sleep(600)
         | 
| 261 | 
            -
                            
         | 
| 262 | 
            -
            cleanup_thread = threading.Thread(target=cleanup_tmp_dir)
         | 
| 263 | 
            -
            cleanup_thread.start()
         | 
| 264 | 
            -
                
         | 
| 265 | 
            -
             | 
| 266 | 
             
            # Launch the Gradio app
         | 
| 267 | 
             
            if __name__ == "__main__":
         | 
| 268 | 
             
                pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
         | 
|  | |
| 3 | 
             
            from gradio_litmodel3d import LitModel3D
         | 
| 4 |  | 
| 5 | 
             
            import os
         | 
| 6 | 
            +
            import shutil
         | 
| 7 | 
             
            os.environ['SPCONV_ALGO'] = 'native'
         | 
| 8 | 
             
            from typing import *
         | 
| 9 | 
             
            import torch
         | 
|  | |
| 18 |  | 
| 19 |  | 
| 20 | 
             
            MAX_SEED = np.iinfo(np.int32).max
         | 
| 21 | 
            +
            TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
         | 
|  | |
| 22 | 
             
            os.makedirs(TMP_DIR, exist_ok=True)
         | 
| 23 |  | 
| 24 |  | 
| 25 | 
            +
            def start_session(req: gr.Request):
         | 
| 26 | 
            +
                user_dir = os.path.join(TMP_DIR, str(req.session_hash))
         | 
| 27 | 
            +
                print(f'Creating user directory: {user_dir}')
         | 
| 28 | 
            +
                os.makedirs(user_dir, exist_ok=True)
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
            def end_session(req: gr.Request):
         | 
| 32 | 
            +
                user_dir = os.path.join(TMP_DIR, str(req.session_hash))
         | 
| 33 | 
            +
                print(f'Removing user directory: {user_dir}')
         | 
| 34 | 
            +
                shutil.rmtree(user_dir)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
             
            def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
         | 
| 38 | 
             
                """
         | 
| 39 | 
             
                Preprocess the input image.
         | 
|  | |
| 45 | 
             
                    str: uuid of the trial.
         | 
| 46 | 
             
                    Image.Image: The preprocessed image.
         | 
| 47 | 
             
                """
         | 
|  | |
| 48 | 
             
                processed_image = pipeline.preprocess_image(image)
         | 
| 49 | 
            +
                return processed_image
         | 
|  | |
| 50 |  | 
| 51 |  | 
| 52 | 
             
            def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
         | 
|  | |
| 90 | 
             
                return gs, mesh, state['trial_id']
         | 
| 91 |  | 
| 92 |  | 
| 93 | 
            +
            def get_seed(randomize_seed: bool, seed: int) -> int:
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
                Get the random seed.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                return np.random.randint(0, MAX_SEED) if randomize_seed else seed
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
             
            @spaces.GPU
         | 
| 101 | 
            +
            def image_to_3d(
         | 
| 102 | 
            +
                image: Image.Image,
         | 
| 103 | 
            +
                seed: int,
         | 
| 104 | 
            +
                ss_guidance_strength: float,
         | 
| 105 | 
            +
                ss_sampling_steps: int,
         | 
| 106 | 
            +
                slat_guidance_strength: float,
         | 
| 107 | 
            +
                slat_sampling_steps: int,
         | 
| 108 | 
            +
                req: gr.Request,
         | 
| 109 | 
            +
            ) -> Tuple[dict, str]:
         | 
| 110 | 
             
                """
         | 
| 111 | 
             
                Convert an image to a 3D model.
         | 
| 112 |  | 
| 113 | 
             
                Args:
         | 
| 114 | 
            +
                    image (Image.Image): The input image.
         | 
| 115 | 
             
                    seed (int): The random seed.
         | 
|  | |
| 116 | 
             
                    ss_guidance_strength (float): The guidance strength for sparse structure generation.
         | 
| 117 | 
             
                    ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
         | 
| 118 | 
             
                    slat_guidance_strength (float): The guidance strength for structured latent generation.
         | 
|  | |
| 122 | 
             
                    dict: The information of the generated 3D model.
         | 
| 123 | 
             
                    str: The path to the video of the 3D model.
         | 
| 124 | 
             
                """
         | 
| 125 | 
            +
                user_dir = os.path.join(TMP_DIR, str(req.session_hash))
         | 
|  | |
| 126 | 
             
                outputs = pipeline.run(
         | 
| 127 | 
            +
                    image,
         | 
| 128 | 
             
                    seed=seed,
         | 
| 129 | 
             
                    formats=["gaussian", "mesh"],
         | 
| 130 | 
             
                    preprocess_image=False,
         | 
|  | |
| 141 | 
             
                video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
         | 
| 142 | 
             
                video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
         | 
| 143 | 
             
                trial_id = uuid.uuid4()
         | 
| 144 | 
            +
                video_path = os.path.join(user_dir, f"{trial_id}.mp4")
         | 
|  | |
| 145 | 
             
                imageio.mimsave(video_path, video, fps=15)
         | 
| 146 | 
             
                state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
         | 
| 147 | 
            +
                torch.cuda.empty_cache()
         | 
| 148 | 
             
                return state, video_path
         | 
| 149 |  | 
| 150 |  | 
| 151 | 
             
            @spaces.GPU
         | 
| 152 | 
            +
            def extract_glb(
         | 
| 153 | 
            +
                state: dict,
         | 
| 154 | 
            +
                mesh_simplify: float,
         | 
| 155 | 
            +
                texture_size: int,
         | 
| 156 | 
            +
                req: gr.Request,
         | 
| 157 | 
            +
            ) -> Tuple[str, str]:
         | 
| 158 | 
             
                """
         | 
| 159 | 
             
                Extract a GLB file from the 3D model.
         | 
| 160 |  | 
|  | |
| 166 | 
             
                Returns:
         | 
| 167 | 
             
                    str: The path to the extracted GLB file.
         | 
| 168 | 
             
                """
         | 
| 169 | 
            +
                user_dir = os.path.join(TMP_DIR, str(req.session_hash))
         | 
| 170 | 
             
                gs, mesh, trial_id = unpack_state(state)
         | 
| 171 | 
             
                glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
         | 
| 172 | 
            +
                glb_path = os.path.join(user_dir, f"{trial_id}.glb")
         | 
| 173 | 
             
                glb.export(glb_path)
         | 
| 174 | 
            +
                torch.cuda.empty_cache()
         | 
| 175 | 
             
                return glb_path, glb_path
         | 
| 176 |  | 
| 177 |  | 
| 178 | 
            +
            with gr.Blocks(delete_cache=(600, 600)) as demo:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 179 | 
             
                gr.Markdown("""
         | 
| 180 | 
             
                ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
         | 
| 181 | 
             
                * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
         | 
|  | |
| 184 |  | 
| 185 | 
             
                with gr.Row():
         | 
| 186 | 
             
                    with gr.Column():
         | 
| 187 | 
            +
                        image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
         | 
| 188 |  | 
| 189 | 
             
                        with gr.Accordion(label="Generation Settings", open=False):
         | 
| 190 | 
             
                            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
         | 
|  | |
| 211 | 
             
                        model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
         | 
| 212 | 
             
                        download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
         | 
| 213 |  | 
|  | |
| 214 | 
             
                output_buf = gr.State()
         | 
| 215 |  | 
| 216 | 
             
                # Example images at the bottom of the page
         | 
|  | |
| 222 | 
             
                        ],
         | 
| 223 | 
             
                        inputs=[image_prompt],
         | 
| 224 | 
             
                        fn=preprocess_image,
         | 
| 225 | 
            +
                        outputs=[image_prompt],
         | 
| 226 | 
             
                        run_on_click=True,
         | 
| 227 | 
             
                        examples_per_page=64,
         | 
| 228 | 
             
                    )
         | 
| 229 |  | 
| 230 | 
             
                # Handlers
         | 
| 231 | 
            +
                demo.load(start_session)
         | 
| 232 | 
            +
                demo.unload(end_session)
         | 
| 233 | 
            +
                
         | 
| 234 | 
             
                image_prompt.upload(
         | 
| 235 | 
             
                    preprocess_image,
         | 
| 236 | 
             
                    inputs=[image_prompt],
         | 
| 237 | 
            +
                    outputs=[image_prompt],
         | 
|  | |
|  | |
|  | |
|  | |
| 238 | 
             
                )
         | 
| 239 |  | 
| 240 | 
             
                generate_btn.click(
         | 
| 241 | 
            +
                    get_seed,
         | 
| 242 | 
            +
                    inputs=[randomize_seed, seed],
         | 
| 243 | 
            +
                    outputs=[seed],
         | 
| 244 | 
            +
                ).then(
         | 
| 245 | 
             
                    image_to_3d,
         | 
| 246 | 
            +
                    inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
         | 
| 247 | 
             
                    outputs=[output_buf, video_output],
         | 
| 248 | 
             
                ).then(
         | 
| 249 | 
            +
                    lambda: gr.Button(interactive=True),
         | 
| 250 | 
             
                    outputs=[extract_glb_btn],
         | 
| 251 | 
             
                )
         | 
| 252 |  | 
| 253 | 
             
                video_output.clear(
         | 
| 254 | 
            +
                    lambda: gr.Button(interactive=False),
         | 
| 255 | 
             
                    outputs=[extract_glb_btn],
         | 
| 256 | 
             
                )
         | 
| 257 |  | 
|  | |
| 260 | 
             
                    inputs=[output_buf, mesh_simplify, texture_size],
         | 
| 261 | 
             
                    outputs=[model_output, download_glb],
         | 
| 262 | 
             
                ).then(
         | 
| 263 | 
            +
                    lambda: gr.Button(interactive=True),
         | 
| 264 | 
             
                    outputs=[download_glb],
         | 
| 265 | 
             
                )
         | 
| 266 |  | 
| 267 | 
             
                model_output.clear(
         | 
| 268 | 
            +
                    lambda: gr.Button(interactive=False),
         | 
| 269 | 
             
                    outputs=[download_glb],
         | 
| 270 | 
             
                )
         | 
| 271 |  | 
| 272 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 273 | 
             
            # Launch the Gradio app
         | 
| 274 | 
             
            if __name__ == "__main__":
         | 
| 275 | 
             
                pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
         | 
 
			
