Spaces:
Runtime error
Runtime error
Update video_mask_app.py
Browse files- video_mask_app.py +15 -2
video_mask_app.py
CHANGED
|
@@ -93,6 +93,7 @@ def process_video(input_video_path):
|
|
| 93 |
|
| 94 |
# Process each frame
|
| 95 |
frames = []
|
|
|
|
| 96 |
for frame in tqdm(video_clip.iter_frames()):
|
| 97 |
frame_pil = Image.fromarray(frame)
|
| 98 |
frame_no_bg, mask_resized = remove_background(frame_pil)
|
|
@@ -109,12 +110,23 @@ def process_video(input_video_path):
|
|
| 109 |
output_np = np.array(output)
|
| 110 |
frames.append(output_np)
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
# Save the processed frames as a new video
|
| 113 |
output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
|
| 114 |
processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
|
| 115 |
processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# Gradio components
|
| 120 |
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
|
|
@@ -125,6 +137,7 @@ text = gr.Textbox(label="Paste an image URL")
|
|
| 125 |
png_file = gr.File(label="output png file")
|
| 126 |
video_input = gr.Video(label="Upload a video")
|
| 127 |
video_output = gr.Video(label="Processed video")
|
|
|
|
| 128 |
|
| 129 |
# Example videos
|
| 130 |
example_videos = [
|
|
@@ -140,7 +153,7 @@ tab1 = gr.Interface(
|
|
| 140 |
|
| 141 |
tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
|
| 142 |
#tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
|
| 143 |
-
tab4 = gr.Interface(process_video, inputs=video_input, outputs=video_output, examples=example_videos, api_name="video", cache_examples = False)
|
| 144 |
|
| 145 |
# Gradio tabbed interface
|
| 146 |
demo = gr.TabbedInterface(
|
|
|
|
| 93 |
|
| 94 |
# Process each frame
|
| 95 |
frames = []
|
| 96 |
+
mask_frames = []
|
| 97 |
for frame in tqdm(video_clip.iter_frames()):
|
| 98 |
frame_pil = Image.fromarray(frame)
|
| 99 |
frame_no_bg, mask_resized = remove_background(frame_pil)
|
|
|
|
| 110 |
output_np = np.array(output)
|
| 111 |
frames.append(output_np)
|
| 112 |
|
| 113 |
+
# Create a mask frame with white foreground and black background
|
| 114 |
+
mask_frame = np.array(mask_resized)
|
| 115 |
+
mask_frame = np.stack([mask_frame, mask_frame, mask_frame], axis=-1) # Convert to 3 channels
|
| 116 |
+
mask_frame[mask_frame > 0] = 255 # Set foreground to white
|
| 117 |
+
mask_frames.append(mask_frame)
|
| 118 |
+
|
| 119 |
# Save the processed frames as a new video
|
| 120 |
output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
|
| 121 |
processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
|
| 122 |
processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
|
| 123 |
|
| 124 |
+
# Save the mask frames as a new video
|
| 125 |
+
mask_video_path = os.path.join(output_folder, "mask_video.mp4")
|
| 126 |
+
mask_clip = ImageSequenceClip(mask_frames, fps=video_clip.fps)
|
| 127 |
+
mask_clip.write_videofile(mask_video_path, codec='libx264')
|
| 128 |
+
|
| 129 |
+
return output_video_path, mask_video_path
|
| 130 |
|
| 131 |
# Gradio components
|
| 132 |
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
|
|
|
|
| 137 |
png_file = gr.File(label="output png file")
|
| 138 |
video_input = gr.Video(label="Upload a video")
|
| 139 |
video_output = gr.Video(label="Processed video")
|
| 140 |
+
mask_video_output = gr.Video(label="Mask video")
|
| 141 |
|
| 142 |
# Example videos
|
| 143 |
example_videos = [
|
|
|
|
| 153 |
|
| 154 |
tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
|
| 155 |
#tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
|
| 156 |
+
tab4 = gr.Interface(process_video, inputs=video_input, outputs=[video_output, mask_video_output], examples=example_videos, api_name="video", cache_examples = False)
|
| 157 |
|
| 158 |
# Gradio tabbed interface
|
| 159 |
demo = gr.TabbedInterface(
|