	update
- app.py +1 -1
- assets/instruction.md +3 -3
- gradio_tabs/animation.py +16 -36
- gradio_tabs/img_edit.py +2 -14
- gradio_tabs/vid_edit.py +8 -13
    	
app.py
CHANGED

@@ -17,7 +17,7 @@ ckpt_path = hf_hub_download(repo_id="YaohuiW/LIA-X", filename="lia-x.pt")
 gen.load_state_dict(torch.load(ckpt_path, weights_only=True))
 gen.eval()
 
-chunk_size=
+chunk_size=30
 
 def load_file(path):
 
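The `chunk_size` constant defined here is consumed by the chunked decoding loops in the tabs below, which compute `chunks = t // chunk_size` and iterate over `range(chunks + 1)` so the final, possibly partial, chunk is covered. A minimal sketch of that partitioning (`chunk_bounds` is a hypothetical helper, not part of the commit):

def chunk_bounds(t, chunk_size=30):
    # Mirrors the chunks = t // chunk_size and range(chunks + 1) pattern
    # used in gradio_tabs; yields (start, end) frame indices per chunk.
    chunks = t // chunk_size
    for i in range(chunks + 1):
        start, end = i * chunk_size, min((i + 1) * chunk_size, t)
        if start < end:  # when t % chunk_size == 0 the last slice is empty
            yield start, end

# e.g. list(chunk_bounds(70)) == [(0, 30), (30, 60), (60, 70)]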
    	
assets/instruction.md
CHANGED

@@ -3,18 +3,18 @@
 * **Image Animation**
 
     - Upload `Source Image` and `Driving Video`
-    - Using sliders in the `Control Panel` to edit image
+    - Using `sliders` in the `Control Panel` to edit image
 	- Use `Animate` button to obtain `Animated Video`
 
 * **Image Editing**
 
     - Upload `Source Image`
-	- Using sliders in the `Control Panel` to edit image
+	- Using `sliders` in the `Control Panel` to edit image
 
 * **Video Editing**
 
     - Upload `Video`
-	- Using sliders in the `Control Panel` to edit image
+	- Using `sliders` in the `Control Panel` to edit image
     - Use `Generate` button to obtain `Edited Video`
 
 **NOTE: we recommend to crop both input images and videos using provided [tools](https://github.com/wyhsirius/LIA-X/tree/main) for better results**
    	
gradio_tabs/animation.py
CHANGED

@@ -90,10 +90,6 @@ def vid_preprocessing(vid_path, size):
 	vid = vid_dict[0].permute(0, 3, 1, 2) # tchw
 	fps = vid_dict[2]['video_fps']
 	vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]
-
-	#vid_norm = torch.cat([
-	#	resize(vid_norm[i:i+1, :, :, :], size).unsqueeze(1) for i in range(vid.size(0))
-	#], dim=1)
 	vid_norm = resize(vid_norm, size) # tchw
 
 	return vid_norm, fps

@@ -135,9 +131,7 @@ def vid_postprocessing(video, w, h, fps):
 
 	t,c,_,_ = video.size()
 	vid = resize_back(video, w, h)
-
-	vid = vid.clamp(-1, 1)
-	vid = (vid - vid.min()) / (vid.max() - vid.min())
+	vid = vid_denorm(vid)
 
 	vid = rearrange(vid, "t c h w -> t h w c")	# T H W C
 	vid_np = (vid.cpu().numpy() * 255).astype('uint8')

@@ -215,30 +209,27 @@ def animation(gen, chunk_size, device):
 		vid_target_tensor, fps = vid_preprocessing(video, 512)
 		image_tensor = image_tensor.to(device)
 		video_target_tensor = vid_target_tensor.to(device) #tchw
-
-		#animated_video = gen.animate_batch(image_tensor, video_target_tensor, labels_v, selected_s, chunk_size)
-		#edited_image = animated_video[:,:,0,:,:]
 
 		img_start = video_target_tensor[0:1,:,:,:]
-		#vid_target_tensor_batch = rearrange(video_target_tensor, 'b t c h w -> (b t) c h w')
 
 		res = []
-		t = video_target_tensor.size(0)
+		t, c, h, w = video_target_tensor.size()
+
 		chunks = t // chunk_size
+		if t%chunk_size == 0:
+			vid_target_tensor_batch = torch.zeros(chunk_size * chunks, c, h, w).to(device)
+		else:
+			vid_target_tensor_batch = torch.zeros(chunk_size * (chunks + 1), c, h, w).to(device)
+		vid_target_tensor_batch[:t] = video_target_tensor
+
 		z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(image_tensor, selected_s)
-		#z_s2r, alpha_r2s, feat_rgb = gen.enc_img(image_tensor, labels_v, selected_s)
 		for i in range(chunks+1):
-
-
-
-
-
-
-				img_animated = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target)
-				#img_animated_batch = gen.dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
-
-			res.append(img_animated)
-		animated_video = torch.cat(res, dim=0) # TCHW
+
+			img_target_batch = vid_target_tensor_batch[i * chunk_size:(i + 1) * chunk_size, :, :, :]
+			img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
+
+			res.append(img_animated_batch)
+		animated_video = torch.cat(res, dim=0)[:t] # TCHW
 		edited_image = animated_video[0:1,:,:,:]
 
 		# postprocessing

@@ -308,7 +299,7 @@ def animation(gen, chunk_size, device):
 						#video_output.render()
 						video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)#.render()
 
-				with gr.Accordion("Control Panel
+				with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
 					with gr.Tab("Head"):
 						with gr.Row():
 							for k in labels_k[:3]:

@@ -344,23 +335,12 @@ def animation(gen, chunk_size, device):
 				fn=edit_media,
 				inputs=[image_input] + inputs_s,
 				outputs=[image_output],
-
 				show_progress='hidden',
-
 				trigger_mode='always_last',
-
 				# currently we have a latency around 450ms
 				stream_every=0.5
 			)
 
-
-		#edit_btn.click(
-		#	fn=edit_media,
-		#	inputs=[image_input] + inputs_s,
-		#	outputs=[image_output],
-		#	show_progress=True
-		#)
-
 		animate_btn.click(
 			fn=animate_media,
 			inputs=[image_input, video_input] + inputs_s,
    	
gradio_tabs/img_edit.py
CHANGED

@@ -95,14 +95,10 @@ def img_denorm(img):
 def img_postprocessing(img, w, h):
 
 	img = resize_back(img, w, h)
-	#image = image.permute(0, 2, 3, 1)
 	img = img_denorm(img)
 	img = img.squeeze(0).permute(1, 2, 0).contiguous()	# contiguous() for fast transfer
 	img_output = (img.cpu().numpy() * 255).astype(np.uint8)
 
-	#with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
-	#	imageio.imwrite(temp_file.name, img_output, quality=8)
-	#	return temp_file.name
 	return img_output
 
 

@@ -196,7 +192,7 @@ def img_edit(gen, device):
 						image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
 
 
-				with gr.Accordion("Control Panel
+				with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
 					with gr.Tab("Head"):
 						with gr.Row():
 							for k in labels_k[:3]:

@@ -239,15 +235,7 @@ def img_edit(gen, device):
 
 			# currently we have a latency around 450ms
 			stream_every=0.5
-		)
-
-
-		#edit_btn.click(
-		#	fn=edit_img,
-		#	inputs=[image_input] + inputs_s,
-		#	outputs=[image_output],
-		#	show_progress=True
-		#)
+		)
 
 		clear_btn.click(
 			fn=clear_media,
    	
gradio_tabs/vid_edit.py
CHANGED

@@ -231,21 +231,23 @@ def vid_edit(gen, chunk_size, device):
 		res = []
 		t = video_target_tensor.size(1)
 		chunks = t // chunk_size
+
+
 		z_s2r, alpha_r2s, feat_rgb = compiled_enc_img(img_start, selected_s)
 		for i in range(chunks + 1):
 			if i == chunks:
-				img_target_batch =
-				img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start,
+				img_target_batch = video_target_tensor[i * chunk_size:, :, :, :]
+				img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
 			else:
-				img_target_batch =
-				img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start,
+				img_target_batch = video_target_tensor[i * chunk_size:(i + 1) * chunk_size, :, :, :]
+				img_animated_batch = compiled_dec_vid(z_s2r, alpha_r2s, feat_rgb, img_start, img_target_batch)
 
 			res.append(img_animated_batch)
 		edited_video_tensor = torch.cat(res, dim=0)  # TCHW
 		edited_image_tensor = edited_video_tensor[0:1,:,:,:]
 
 		# de-norm
-		animated_video, animated_all_video = vid_all_save(
+		animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, w, h, fps)
 		edited_image = img_postprocessing(edited_image_tensor, w, h)
 
 		return edited_image, animated_video, animated_all_video

@@ -293,7 +295,7 @@ def vid_edit(gen, chunk_size, device):
 						video_all_output = gr.Video(label="Videos", elem_id="output_vid_all")
 
 			with gr.Column(scale=1):
-				with gr.Accordion("Control Panel
+				with gr.Accordion("Control Panel - Using Sliders to Edit Image", open=True):
 					with gr.Tab("Head"):
 						with gr.Row():
 							for k in labels_k[:3]:

@@ -342,13 +344,6 @@ def vid_edit(gen, chunk_size, device):
 				stream_every=0.5
 			)
 
-		#edit_btn.click(
-		#	fn=edit_img,
-		#	inputs=[video_input] + inputs_s,
-		#	outputs=[image_output],
-		#	show_progress=True
-		#)
-
 		animate_btn.click(
 			fn=edit_vid,
 			inputs=[video_input] + inputs_s,  # [image_input, video_input] + inputs_s,
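Unlike animation.py's zero-padding, vid_edit keeps explicit slicing for the tail chunk (`video_target_tensor[i * chunk_size:, :, :, :]`), so the last batch may simply be shorter than `chunk_size`. A condensed sketch of this variable-length strategy, where `decode_fn` is a hypothetical stand-in for `compiled_dec_vid` with its other arguments already bound:

import torch

def decode_in_chunks(decode_fn, frames: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # frames: TCHW; slicing past the end is safe in PyTorch, so the final
    # chunk just comes out shorter than chunk_size.
    res = []
    for start in range(0, frames.size(0), chunk_size):
        res.append(decode_fn(frames[start:start + chunk_size]))
    return torch.cat(res, dim=0)  # TCHW

The trade-off versus padding is that chunk shapes vary, which a compiled decoder may recompile for; in exchange, no padded frames are decoded and then thrown away.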