import gc
import os
import shutil
import time
from datetime import datetime
import io
import sys
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
register_heif_opener()
from src.utils.inference_utils import load_and_preprocess_images
from src.utils.geometry import (
depth_edge,
normals_edge
)
from src.utils.visual_util import (
convert_predictions_to_glb_scene,
segment_sky,
download_file_from_url
)
from src.utils.save_utils import save_camera_params, save_gs_ply, process_ply_to_splat, convert_gs_to_ply
from src.utils.render_utils import render_interpolated_video
import onnxruntime
# Initialize model - this will be done on GPU when needed
model = None
# Global variable to store current terminal output
current_terminal_output = ""
# Helper class to capture terminal output
class TeeOutput:
"""Capture output while still printing to console"""
def __init__(self, max_chars=10000):
self.terminal = sys.stdout
self.log = io.StringIO()
        self.max_chars = max_chars  # cap on how many characters are retained
def write(self, message):
global current_terminal_output
self.terminal.write(message)
self.log.write(message)
        # Trim the captured log so it never grows beyond max_chars
        content = self.log.getvalue()
        if len(content) > self.max_chars:
            # Keep only the trailing max_chars characters
            content = "...(earlier output truncated)...\n" + content[-self.max_chars:]
            self.log = io.StringIO()
            self.log.write(content)
current_terminal_output = self.log.getvalue()
def flush(self):
self.terminal.flush()
def getvalue(self):
return self.log.getvalue()
def clear(self):
global current_terminal_output
self.log = io.StringIO()
current_terminal_output = ""
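# Usage sketch (illustrative only, not executed here): route stdout through
# TeeOutput so the UI can poll `current_terminal_output` while messages still
# reach the console.
#
#   tee = TeeOutput()
#   old_stdout, sys.stdout = sys.stdout, tee
#   try:
#       print("progress...")        # mirrored to both console and buffer
#   finally:
#       sys.stdout = old_stdout     # always restore the real stdout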
# -------------------------------------------------------------------------
# Model inference
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
target_dir,
confidence_percentile: float = 10,
edge_normal_threshold: float = 5.0,
edge_depth_threshold: float = 0.03,
apply_confidence_mask: bool = True,
apply_edge_mask: bool = True,
):
"""
Run the WorldMirror model on images in the 'target_dir/images' folder and return predictions.
"""
global model
import torch # Ensure torch is available in function scope
from src.models.models.worldmirror import WorldMirror
from src.models.utils.geometry import depth_to_world_coords_points
print(f"Processing images from {target_dir}")
# Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
# Initialize model if not already done
if model is None:
model = WorldMirror.from_pretrained("tencent/HunyuanWorld-Mirror").to(device)
else:
model.to(device)
model.eval()
# Load images using WorldMirror's load_images function
print("Loading images...")
image_folder_path = os.path.join(target_dir, "images")
    image_file_paths = sorted(
        os.path.join(image_folder_path, path) for path in os.listdir(image_folder_path)
    )  # sorted for a deterministic frame order (matches the UI frame selector)
img = load_and_preprocess_images(image_file_paths).to(device)
print(f"Loaded {img.shape[1]} images")
if img.shape[1] == 0:
raise ValueError("No images found. Check your upload.")
# Run model inference
print("Running inference...")
inputs = {}
inputs['img'] = img
use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
if use_amp:
amp_dtype = torch.bfloat16
else:
amp_dtype = torch.float32
with torch.amp.autocast('cuda', enabled=bool(use_amp), dtype=amp_dtype):
predictions = model(inputs)
# img
imgs = inputs["img"].permute(0, 1, 3, 4, 2)
imgs = imgs[0].detach().cpu().numpy() # S H W 3
# depth output
depth_preds = predictions["depth"]
depth_conf = predictions["depth_conf"]
depth_preds = depth_preds[0].detach().cpu().numpy() # S H W 1
depth_conf = depth_conf[0].detach().cpu().numpy() # S H W
# normal output
normal_preds = predictions["normals"] # S H W 3
normal_preds = normal_preds[0].detach().cpu().numpy() # S H W 3
# camera parameters
camera_poses = predictions["camera_poses"][0].detach().cpu().numpy() # [S,4,4]
camera_intrs = predictions["camera_intrs"][0].detach().cpu().numpy() # [S,3,3]
# points output
pts3d_preds = depth_to_world_coords_points(predictions["depth"][0, ..., 0], predictions["camera_poses"][0], predictions["camera_intrs"][0])[0]
pts3d_preds = pts3d_preds.detach().cpu().numpy() # S H W 3
pts3d_conf = depth_conf # S H W
# sky mask segmentation
if not os.path.exists("skyseg.onnx"):
print("Downloading skyseg.onnx...")
download_file_from_url(
"https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
)
skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
sky_mask_list = []
    for img_path in image_file_paths:  # reuse the sorted list so masks align with frames
sky_mask = segment_sky(img_path, skyseg_session)
# Resize mask to match H×W if needed
if sky_mask.shape[0] != imgs.shape[1] or sky_mask.shape[1] != imgs.shape[2]:
sky_mask = cv2.resize(sky_mask, (imgs.shape[2], imgs.shape[1]))
sky_mask_list.append(sky_mask)
sky_mask = np.stack(sky_mask_list, axis=0) # [S, H, W]
    sky_mask = sky_mask > 0
# mask computation
final_mask_list = []
for i in range(inputs["img"].shape[1]):
final_mask = None
if apply_confidence_mask:
# compute confidence mask based on the pointmap confidence
confidences = pts3d_conf[i, :, :] # [H, W]
percentile_threshold = np.quantile(confidences, confidence_percentile / 100.0)
conf_mask = confidences >= percentile_threshold
if final_mask is None:
final_mask = conf_mask
else:
final_mask = final_mask & conf_mask
if apply_edge_mask:
# compute edge mask based on the normalmap
normal_pred = normal_preds[i] # [H, W, 3]
normal_edges = normals_edge(
normal_pred, tol=edge_normal_threshold, mask=final_mask
)
# compute depth mask based on the depthmap
depth_pred = depth_preds[i, :, :, 0] # [H, W]
depth_edges = depth_edge(
depth_pred, rtol=edge_depth_threshold, mask=final_mask
)
edge_mask = ~(depth_edges & normal_edges)
if final_mask is None:
final_mask = edge_mask
else:
final_mask = final_mask & edge_mask
final_mask_list.append(final_mask)
if final_mask_list[0] is not None:
final_mask = np.stack(final_mask_list, axis=0) # [S, H, W]
else:
final_mask = np.ones(pts3d_conf.shape[:3], dtype=bool) # [S, H, W]
# gaussian splatting output
if "splats" in predictions:
splats_dict = {}
splats_dict['means'] = predictions["splats"]["means"]
splats_dict['scales'] = predictions["splats"]["scales"]
splats_dict['quats'] = predictions["splats"]["quats"]
splats_dict['opacities'] = predictions["splats"]["opacities"]
if "sh" in predictions["splats"]:
splats_dict['sh'] = predictions["splats"]["sh"]
if "colors" in predictions["splats"]:
splats_dict['colors'] = predictions["splats"]["colors"]
# output lists
outputs = {}
outputs['images'] = imgs
outputs['world_points'] = pts3d_preds
outputs['depth'] = depth_preds
outputs['normal'] = normal_preds
outputs['final_mask'] = final_mask
outputs['sky_mask'] = sky_mask
outputs['camera_poses'] = camera_poses
outputs['camera_intrs'] = camera_intrs
if "splats" in predictions:
outputs['splats'] = splats_dict
# Process data for visualization tabs (depth, normal)
processed_data = prepare_visualization_data(
outputs, inputs
)
# Clean up
torch.cuda.empty_cache()
return outputs, processed_data
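# Minimal sketch (hypothetical helper, not called by the app) of how the
# confidence-percentile mask above behaves: pixels whose confidence falls
# below the given percentile of the per-view confidence map are dropped.
# The toy values are made up.
def _confidence_mask_sketch(confidence_percentile: float = 10.0) -> np.ndarray:
    conf = np.array([[0.1, 0.9], [0.5, 0.7]])                 # toy [H, W] confidences
    threshold = np.quantile(conf, confidence_percentile / 100.0)
    return conf >= threshold                                   # drops pixels below the 10th percentile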
# -------------------------------------------------------------------------
# Update and navigation function
# -------------------------------------------------------------------------
def update_view_info(current_view, total_views, view_type="Depth"):
"""Update view information display"""
return f"""
{view_type} View Navigation |
Current: View {current_view} / {total_views} views
"""
def update_view_selectors(processed_data):
"""Update view selector sliders and info displays based on available views"""
if processed_data is None or len(processed_data) == 0:
num_views = 1
else:
num_views = len(processed_data)
    # Ensure num_views is at least 1
    num_views = max(1, num_views)
    # Update the sliders' maxima and values via gr.update() rather than creating new components
    depth_slider_update = gr.update(minimum=1, maximum=num_views, value=1, step=1)
    normal_slider_update = gr.update(minimum=1, maximum=num_views, value=1, step=1)
    # Refresh the view-info displays
depth_info_update = update_view_info(1, num_views, "Depth")
normal_info_update = update_view_info(1, num_views, "Normal")
return (
depth_slider_update, # depth_view_slider
normal_slider_update, # normal_view_slider
depth_info_update, # depth_view_info
normal_info_update, # normal_view_info
)
def get_view_data_by_index(processed_data, view_index):
"""Get view data by index, handling bounds"""
if processed_data is None or len(processed_data) == 0:
return None
view_keys = list(processed_data.keys())
if view_index < 0 or view_index >= len(view_keys):
view_index = 0
return processed_data[view_keys[view_index]]
def update_depth_view(processed_data, view_index):
"""Update depth view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["depth"] is None:
return None
return render_depth_visualization(view_data["depth"], mask=view_data.get("mask"))
def update_normal_view(processed_data, view_index):
"""Update normal view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["normal"] is None:
return None
return render_normal_visualization(view_data["normal"], mask=view_data.get("mask"))
def initialize_depth_normal_views(processed_data):
"""Initialize the depth and normal view displays with the first view data"""
if processed_data is None or len(processed_data) == 0:
return None, None
# Use update functions to ensure confidence filtering is applied from the start
depth_vis = update_depth_view(processed_data, 0)
normal_vis = update_normal_view(processed_data, 0)
return depth_vis, normal_vis
# -------------------------------------------------------------------------
# File upload and update preview gallery
# -------------------------------------------------------------------------
def process_uploaded_files(files, time_interval=1.0):
"""
Process uploaded files by extracting video frames or copying images.
Args:
files: List of uploaded file objects (videos or images)
time_interval: Interval in seconds for video frame extraction
Returns:
tuple: (target_dir, image_paths) where target_dir is the output directory
and image_paths is a list of processed image file paths
"""
gc.collect()
torch.cuda.empty_cache()
# Create unique output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
target_dir = f"input_images_{timestamp}"
images_dir = os.path.join(target_dir, "images")
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(images_dir)
image_paths = []
if files is None:
return target_dir, image_paths
video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
for file_data in files:
# Get file path
if isinstance(file_data, dict) and "name" in file_data:
src_path = file_data["name"]
else:
src_path = str(file_data)
ext = os.path.splitext(src_path)[1].lower()
base_name = os.path.splitext(os.path.basename(src_path))[0]
# Process video: extract frames
if ext in video_exts:
            cap = cv2.VideoCapture(src_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            # Guard against zero/invalid FPS and sub-frame intervals (avoids modulo-by-zero below)
            interval = max(1, int(fps * time_interval)) if fps and fps > 0 else 1
frame_count = 0
saved_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
if frame_count % interval == 0:
dst_path = os.path.join(images_dir, f"{base_name}_{saved_count:06}.png")
cv2.imwrite(dst_path, frame)
image_paths.append(dst_path)
saved_count += 1
cap.release()
print(f"Extracted {saved_count} frames from: {os.path.basename(src_path)}")
# Process HEIC/HEIF: convert to JPEG
elif ext in [".heic", ".heif"]:
try:
with Image.open(src_path) as img:
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
dst_path = os.path.join(images_dir, f"{base_name}.jpg")
img.save(dst_path, "JPEG", quality=95)
image_paths.append(dst_path)
print(f"Converted HEIC: {os.path.basename(src_path)} -> {os.path.basename(dst_path)}")
except Exception as e:
print(f"HEIC conversion failed for {src_path}: {e}")
dst_path = os.path.join(images_dir, os.path.basename(src_path))
shutil.copy(src_path, dst_path)
image_paths.append(dst_path)
# Process regular images: copy directly
else:
dst_path = os.path.join(images_dir, os.path.basename(src_path))
shutil.copy(src_path, dst_path)
image_paths.append(dst_path)
image_paths = sorted(image_paths)
print(f"Processed files to {images_dir}")
return target_dir, image_paths
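# Back-of-envelope sketch (hypothetical helper, unused by the app): how many
# frames the sampling loop above keeps. With fps = 30 and time_interval = 1.0 s,
# interval = 30, so a 10-second clip (300 frames) yields 300 // 30 = 10 frames.
def _expected_saved_frames(total_frames: int, fps: float, time_interval: float) -> int:
    interval = max(1, int(fps * time_interval))
    return total_frames // interval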
# Handle file upload and update preview gallery
def update_gallery_on_upload(input_video, input_images, time_interval=1.0):
    """
    Process uploaded files immediately when the user uploads or changes files,
    and display them in the gallery.
    Returns (cleared_viewer, target_dir, image_paths, status_message);
    all None if nothing is uploaded.
    """
    if not input_video and not input_images:
        return None, None, None, None
    # Combine both upload sources into the single file list process_uploaded_files expects
    files = list(input_video or []) + list(input_images or [])
    target_dir, image_paths = process_uploaded_files(files, time_interval)
return (
None,
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
)
# -------------------------------------------------------------------------
# Init function
# -------------------------------------------------------------------------
def prepare_visualization_data(
model_outputs, input_views
):
"""Transform model predictions into structured format for display components"""
visualization_dict = {}
# Iterate through each input view
nviews = input_views["img"].shape[1]
for idx in range(nviews):
# Extract RGB image data
rgb_image = input_views["img"][0, idx].detach().cpu().numpy()
# Retrieve 3D coordinate predictions
world_coordinates = model_outputs["world_points"][idx]
# Build view-specific data structure
current_view_info = {
"image": rgb_image,
"points3d": world_coordinates,
"depth": None,
"normal": None,
"mask": None,
}
# Apply final segmentation mask from model
segmentation_mask = model_outputs["final_mask"][idx].copy()
current_view_info["mask"] = segmentation_mask
current_view_info["depth"] = model_outputs["depth"][idx].squeeze()
surface_normals = model_outputs["normal"][idx]
current_view_info["normal"] = surface_normals
visualization_dict[idx] = current_view_info
return visualization_dict
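# Shape reference for the structure returned above (one entry per view index):
#   {idx: {"image":    [C, H, W] float array (as stored in inputs["img"]),
#          "points3d": [H, W, 3] world coordinates,
#          "depth":    [H, W] depth map,
#          "normal":   [H, W, 3] surface normals,
#          "mask":     [H, W] boolean validity mask}}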
@spaces.GPU(duration=120)
def gradio_demo(
target_dir,
frame_selector="All",
show_camera=False,
filter_sky_bg=False,
show_mesh=False,
filter_ambiguous=False,
):
"""
Perform reconstruction using the already-created target_dir/images.
"""
# Capture terminal output
tee = TeeOutput()
old_stdout = sys.stdout
sys.stdout = tee
try:
        if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
terminal_log = tee.getvalue()
sys.stdout = old_stdout
return None, "No valid target directory found. Please upload first.", None, None, None, None, None, None, None, None, None, None, None, None, terminal_log
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Prepare frame_selector dropdown
target_dir_images = os.path.join(target_dir, "images")
all_files = (
sorted(os.listdir(target_dir_images))
if os.path.isdir(target_dir_images)
else []
)
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
frame_selector_choices = ["All"] + all_files
print("Running WorldMirror model...")
with torch.no_grad():
predictions, processed_data = run_model(target_dir)
# Save predictions
prediction_save_path = os.path.join(target_dir, "predictions.npz")
np.savez(prediction_save_path, **predictions)
# Save camera parameters as JSON
camera_params_file = save_camera_params(
predictions['camera_poses'],
predictions['camera_intrs'],
target_dir
)
# Handle None frame_selector
if frame_selector is None:
frame_selector = "All"
# Build a GLB file name
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_camera}_mesh{show_mesh}.glb",
)
# Convert predictions to GLB
glbscene = convert_predictions_to_glb_scene(
predictions,
filter_by_frames=frame_selector,
show_camera=show_camera,
mask_sky_bg=filter_sky_bg,
as_mesh=show_mesh, # Use the show_mesh parameter
mask_ambiguous=filter_ambiguous
)
glbscene.export(file_obj=glbfile)
end_time = time.time()
print(f"Total time: {end_time - start_time:.2f} seconds")
log_msg = (
f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
)
# Convert predictions to 3dgs ply
gs_file = None
splat_mode = 'ply'
if "splats" in predictions:
# Get Gaussian parameters (already filtered by GaussianSplatRenderer)
means = predictions["splats"]["means"][0].reshape(-1, 3)
scales = predictions["splats"]["scales"][0].reshape(-1, 3)
quats = predictions["splats"]["quats"][0].reshape(-1, 4)
colors = (predictions["splats"]["sh"][0] if "sh" in predictions["splats"] else predictions["splats"]["colors"][0]).reshape(-1, 3)
opacities = predictions["splats"]["opacities"][0].reshape(-1)
# Convert to torch tensors if needed
if not isinstance(means, torch.Tensor):
means = torch.from_numpy(means)
if not isinstance(scales, torch.Tensor):
scales = torch.from_numpy(scales)
if not isinstance(quats, torch.Tensor):
quats = torch.from_numpy(quats)
if not isinstance(colors, torch.Tensor):
colors = torch.from_numpy(colors)
if not isinstance(opacities, torch.Tensor):
opacities = torch.from_numpy(opacities)
if splat_mode == 'ply':
gs_file = os.path.join(target_dir, "gaussians.ply")
save_gs_ply(
gs_file,
means,
scales,
quats,
colors,
opacities
)
print(f"Saved Gaussian Splatting PLY to: {gs_file}")
print(f"File exists: {os.path.exists(gs_file)}")
if os.path.exists(gs_file):
print(f"File size: {os.path.getsize(gs_file)} bytes")
elif splat_mode == 'splat':
# Save Gaussian splat
plydata = convert_gs_to_ply(
means,
scales,
quats,
colors,
opacities
)
gs_file = os.path.join(target_dir, "gaussians.splat")
gs_file = process_ply_to_splat(plydata, gs_file)
# Initialize depth and normal view displays with processed data
depth_vis, normal_vis = initialize_depth_normal_views(
processed_data
)
# Update view selectors and info displays based on available views
depth_slider, normal_slider, depth_info, normal_info = update_view_selectors(
processed_data
)
        # Render an interpolated fly-through video when splats are available
        rgb_video_path = None
        depth_video_path = None
        if "splats" in predictions:
from pathlib import Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Get camera parameters and image dimensions
camera_poses = torch.tensor(predictions['camera_poses']).unsqueeze(0).to(device)
camera_intrs = torch.tensor(predictions['camera_intrs']).unsqueeze(0).to(device)
H, W = predictions['images'].shape[1], predictions['images'].shape[2]
# Render video
out_path = Path(target_dir) / "rendered_video"
render_interpolated_video(
model.gs_renderer,
predictions["splats"],
camera_poses,
camera_intrs,
(H, W),
out_path,
interp_per_pair=15,
loop_reverse=True,
save_mode="split"
)
# Check output files
rgb_video_path = str(out_path) + "_rgb.mp4"
depth_video_path = str(out_path) + "_depth.mp4"
            if not os.path.exists(rgb_video_path):
                rgb_video_path = None
            if not os.path.exists(depth_video_path):
                depth_video_path = None
# Cleanup
del predictions
gc.collect()
torch.cuda.empty_cache()
# Get terminal output and restore stdout
terminal_log = tee.getvalue()
sys.stdout = old_stdout
return (
glbfile,
log_msg,
gr.Dropdown(choices=frame_selector_choices, value=frame_selector, interactive=True),
processed_data,
depth_vis,
normal_vis,
depth_slider,
normal_slider,
depth_info,
normal_info,
camera_params_file,
gs_file,
rgb_video_path,
depth_video_path,
terminal_log,
)
except Exception as e:
# In case of error, still restore stdout
terminal_log = tee.getvalue()
sys.stdout = old_stdout
print(f"Error occurred: {e}")
raise
# -------------------------------------------------------------------------
# Helper functions for visualization
# -------------------------------------------------------------------------
def render_depth_visualization(depth_map, mask=None):
"""Generate a color-coded depth visualization image with masking capabilities"""
if depth_map is None:
return None
# Create working copy and identify positive depth values
depth_copy = depth_map.copy()
positive_depth_mask = depth_copy > 0
# Combine with user-provided mask for filtering
if mask is not None:
positive_depth_mask = positive_depth_mask & mask
# Perform percentile-based normalization on valid regions
if positive_depth_mask.sum() > 0:
valid_depth_values = depth_copy[positive_depth_mask]
lower_bound = np.percentile(valid_depth_values, 5)
upper_bound = np.percentile(valid_depth_values, 95)
depth_copy[positive_depth_mask] = (depth_copy[positive_depth_mask] - lower_bound) / (upper_bound - lower_bound)
# Convert to RGB using matplotlib colormap
import matplotlib.pyplot as plt
color_mapper = plt.cm.turbo_r
rgb_result = color_mapper(depth_copy)
rgb_result = (rgb_result[:, :, :3] * 255).astype(np.uint8)
# Mark invalid regions with white color
rgb_result[~positive_depth_mask] = [255, 255, 255]
return rgb_result
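# Illustrative sketch (hypothetical helper, unused): the 5th/95th-percentile
# normalization above on a toy depth vector. Values outside the band land
# outside [0, 1] here; in the real function the colormap clips them.
def _percentile_normalize_sketch() -> np.ndarray:
    depths = np.array([1.0, 2.0, 3.0, 4.0, 100.0])   # one outlier
    lo, hi = np.percentile(depths, 5), np.percentile(depths, 95)
    return (depths - lo) / (hi - lo)                 # the outlier maps well above 1.0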
def render_normal_visualization(normal_map, mask=None):
"""Convert surface normal vectors to RGB color representation for display"""
if normal_map is None:
return None
# Make a working copy to avoid modifying original data
normal_display = normal_map.copy()
# Handle masking by zeroing out invalid regions
if mask is not None:
masked_regions = ~mask
normal_display[masked_regions] = [0, 0, 0] # Zero out masked pixels
# Transform from [-1, 1] to [0, 1] range for RGB display
normal_display = (normal_display + 1.0) / 2.0
normal_display = (normal_display * 255).astype(np.uint8)
return normal_display
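# Quick reference (hypothetical helper, unused): the [-1, 1] -> [0, 255]
# mapping used above sends a camera-facing normal (0, 0, 1) to roughly
# RGB (127, 127, 255), the familiar lilac of normal maps.
def _normal_to_rgb_sketch(normal: np.ndarray = np.array([0.0, 0.0, 1.0])) -> np.ndarray:
    return (((normal + 1.0) / 2.0) * 255).astype(np.uint8)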
def clear_fields():
    """
    Hook run before reconstruction; returns None (no outputs are wired to it).
    """
    return None
def update_log():
"""
Display a quick log message while waiting.
"""
return "Loading and Reconstructing..."
def get_terminal_output():
"""
Get current terminal output for real-time display
"""
global current_terminal_output
return current_terminal_output
# -------------------------------------------------------------------------
# Example scene metadata extraction
# -------------------------------------------------------------------------
def extract_example_scenes_metadata(base_directory):
"""
Extract comprehensive metadata for all scene directories containing valid images.
Args:
base_directory: Root path where example scene directories are located
Returns:
Collection of dictionaries with scene details (title, location, preview, etc.)
"""
from glob import glob
# Return empty list if base directory is missing
if not os.path.exists(base_directory):
return []
# Define supported image format extensions
VALID_IMAGE_FORMATS = ['jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif']
scenes_data = []
# Process each subdirectory in the base directory
for directory_name in sorted(os.listdir(base_directory)):
current_directory = os.path.join(base_directory, directory_name)
# Filter out non-directory items
if not os.path.isdir(current_directory):
continue
# Gather all valid image files within the current directory
discovered_images = []
for file_format in VALID_IMAGE_FORMATS:
# Include both lowercase and uppercase format variations
discovered_images.extend(glob(os.path.join(current_directory, f'*.{file_format}')))
discovered_images.extend(glob(os.path.join(current_directory, f'*.{file_format.upper()}')))
# Skip directories without any valid images
if not discovered_images:
continue
# Ensure consistent image ordering
discovered_images.sort()
# Construct scene metadata record
scene_record = {
'name': directory_name,
'path': current_directory,
'thumbnail': discovered_images[0],
'num_images': len(discovered_images),
'image_files': discovered_images,
}
scenes_data.append(scene_record)
return scenes_data
def load_example_scenes(scene_name, scenes):
"""
Initialize and prepare an example scene for 3D reconstruction processing.
Args:
scene_name: Identifier of the target scene to load
scenes: List containing all available scene configurations
Returns:
Tuple containing processed scene data and status information
"""
# Locate the target scene configuration by matching names
target_scene_config = None
for scene_config in scenes:
if scene_config["name"] == scene_name:
target_scene_config = scene_config
break
# Handle case where requested scene doesn't exist
if target_scene_config is None:
return None, None, None, "Scene not found"
    # Collect the scene's image file paths for the processing pipeline
    image_file_paths = list(target_scene_config["image_files"])
# Process the scene images through the standard upload pipeline
processed_target_dir, processed_image_list = process_uploaded_files(image_file_paths, 1.0)
# Return structured response with scene data and user feedback
status_message = f"Successfully loaded scene '{scene_name}' containing {target_scene_config['num_images']} images. Click 'Reconstruct' to begin 3D processing."
return (
None, # Reset reconstruction visualization
None, # Reset gaussian splatting output
processed_target_dir, # Provide working directory path
processed_image_list, # Update image gallery display
status_message,
)
# -------------------------------------------------------------------------
# UI and event handling
# -------------------------------------------------------------------------
theme = gr.themes.Base()
with gr.Blocks(
theme=theme,
css="""
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #a9b8f8 0%, #7081e8 60%, #4254c5 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.normal-weight-btn button,
.normal-weight-btn button span,
.normal-weight-btn button *,
.normal-weight-btn * {
font-weight: 400 !important;
}
.terminal-output {
max-height: 400px !important;
overflow-y: auto !important;
}
.terminal-output textarea {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
font-size: 13px !important;
line-height: 1.5 !important;
color: #333 !important;
background-color: #f8f9fa !important;
max-height: 400px !important;
}
.example-gallery {
width: 100% !important;
}
.example-gallery img {
width: 100% !important;
height: 280px !important;
object-fit: contain !important;
aspect-ratio: 16 / 9 !important;
}
.example-gallery .grid-wrap {
width: 100% !important;
}
    /* Slider navigation styles */
.depth-tab-improved .gradio-slider input[type="range"] {
height: 8px !important;
border-radius: 4px !important;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
}
.depth-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb {
height: 20px !important;
width: 20px !important;
border-radius: 50% !important;
background: #fff !important;
box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important;
}
.depth-tab-improved button {
transition: all 0.3s ease !important;
border-radius: 6px !important;
font-weight: 500 !important;
}
.depth-tab-improved button:hover {
transform: translateY(-1px) !important;
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}
.normal-tab-improved .gradio-slider input[type="range"] {
height: 8px !important;
border-radius: 4px !important;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
}
.normal-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb {
height: 20px !important;
width: 20px !important;
border-radius: 50% !important;
background: #fff !important;
box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important;
}
.normal-tab-improved button {
transition: all 0.3s ease !important;
border-radius: 6px !important;
font-weight: 500 !important;
}
.normal-tab-improved button:hover {
transform: translateY(-1px) !important;
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
}
#depth-view-info, #normal-view-info {
animation: fadeIn 0.5s ease-in-out;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(-10px); }
to { opacity: 1; transform: translateY(0); }
}
"""
) as demo:
# State variables for the tabbed interface
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
processed_data_state = gr.State(value=None)
current_view_index = gr.State(value=0) # Track current view index for navigation
# Header and description
gr.HTML(
"""
WorldMirror supports any combination of inputs (images, intrinsics, poses, and depth) and multiple outputs including point clouds, camera parameters, depth maps, normal maps, and 3D Gaussian Splatting (3DGS).
How to Use:
    - Upload Your Data: Click the "Upload Video or Images" button to add your files. Videos are automatically split into frames (one-second intervals by default; adjustable with the sampling slider).
- Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction.
- Visualize: Explore multiple reconstruction results across different tabs:
- 3D View: Interactive point cloud/mesh visualization with camera poses (downloadable as GLB)
- 3D Gaussian Splatting: Interactive 3D Gaussian Splatting visualization with RGB and depth videos (downloadable as PLY)
- Depth Maps: Per-view depth estimation results (downloadable as PNG)
- Normal Maps: Per-view surface orientation visualization (downloadable as PNG)
- Camera Parameters: Estimated camera poses and intrinsics (downloadable as JSON)
Please note: Loading data and displaying 3D effects may take a moment. For faster performance, we recommend downloading the code from our GitHub and running it locally.
""")
output_path_state = gr.Textbox(label="Output Path", visible=False, value="None")
# Main UI components
with gr.Row(equal_height=False):
with gr.Column(scale=1):
file_upload = gr.File(
file_count="multiple",
label="Upload Video or Images",
interactive=True,
file_types=["image", "video"],
height="200px",
)
time_interval = gr.Slider(
minimum=0.1,
maximum=10.0,
value=1.0,
step=0.1,
label="Video Sample interval",
interactive=True,
visible=True,
scale=4,
)
resample_btn = gr.Button(
"Resample",
visible=True,
scale=1,
elem_classes=["normal-weight-btn"],
)
image_gallery = gr.Gallery(
label="Image Preview",
columns=4,
height="200px",
show_download_button=True,
object_fit="contain",
preview=True
)
terminal_output = gr.Textbox(
label="Terminal Output",
lines=6,
max_lines=6,
interactive=False,
show_copy_button=True,
container=True,
elem_classes=["terminal-output"],
autoscroll=True
)
with gr.Column(scale=3):
log_output = gr.Markdown(
"Upload video or images first, then click Reconstruct to start processing",
elem_classes=["custom-log"],
)
with gr.Tabs() as tabs:
with gr.Tab("3D Gaussian Splatting", id=1) as gs_tab:
with gr.Row():
with gr.Column(scale=3):
gs_output = gr.Model3D(
label="Gaussian Splatting",
height=500,
)
with gr.Column(scale=1):
gs_rgb_video = gr.Video(
label="Rendered RGB Video",
height=250,
autoplay=False,
loop=False,
interactive=False,
)
gs_depth_video = gr.Video(
label="Rendered Depth Video",
height=250,
autoplay=False,
loop=False,
interactive=False,
)
with gr.Tab("Point Cloud/Mesh", id=0):
reconstruction_output = gr.Model3D(
label="3D Pointmap/Mesh",
height=500,
zoom_speed=0.4,
pan_speed=0.4,
)
with gr.Tab("Depth", elem_classes=["depth-tab-improved"]):
                    depth_view_info = gr.HTML(
                        value="Depth View Navigation | Current: View 1 / 1 views",
                        elem_id="depth-view-info"
                    )
depth_view_slider = gr.Slider(
minimum=1,
maximum=1,
step=1,
value=1,
label="View Selection Slider",
interactive=True,
elem_id="depth-view-slider"
)
depth_map = gr.Image(
type="numpy",
label="Depth Map",
format="png",
interactive=False,
height=340
)
with gr.Tab("Normal", elem_classes=["normal-tab-improved"]):
                    normal_view_info = gr.HTML(
                        value="Normal View Navigation | Current: View 1 / 1 views",
                        elem_id="normal-view-info"
                    )
normal_view_slider = gr.Slider(
minimum=1,
maximum=1,
step=1,
value=1,
label="View Selection Slider",
interactive=True,
elem_id="normal-view-slider"
)
normal_map = gr.Image(
type="numpy",
label="Normal Map",
format="png",
interactive=False,
height=340
)
with gr.Tab("Camera Parameters", elem_classes=["camera-tab"]):
with gr.Row():
gr.HTML("")
camera_params = gr.DownloadButton(
label="Download Camera Parameters",
scale=1,
variant="primary",
)
gr.HTML("")
with gr.Row():
reconstruct_btn = gr.Button(
"Reconstruct",
scale=1,
variant="primary"
)
clear_btn = gr.ClearButton(
[
file_upload,
reconstruction_output,
log_output,
output_path_state,
image_gallery,
depth_map,
normal_map,
depth_view_slider,
normal_view_slider,
depth_view_info,
normal_view_info,
camera_params,
gs_output,
gs_rgb_video,
gs_depth_video,
],
scale=1,
)
with gr.Row():
frame_selector = gr.Dropdown(
choices=["All"], value="All", label="Show Points of a Specific Frame"
)
gr.Markdown("### Reconstruction Options: (not applied to 3DGS)")
with gr.Row():
show_camera = gr.Checkbox(label="Show Camera", value=True)
show_mesh = gr.Checkbox(label="Show Mesh", value=True)
filter_ambiguous = gr.Checkbox(label="Filter low confidence & depth/normal edges", value=True)
filter_sky_bg = gr.Checkbox(label="Filter Sky Background", value=False)
with gr.Column(scale=1):
gr.Markdown("### Click to load example scenes")
realworld_scenes = extract_example_scenes_metadata("examples/realistic") if os.path.exists("examples/realistic") else extract_example_scenes_metadata("examples")
generated_scenes = extract_example_scenes_metadata("examples/stylistic") if os.path.exists("examples/stylistic") else []
# If no subdirectories exist, fall back to single gallery
if not os.path.exists("examples/realistic") and not os.path.exists("examples/stylistic"):
# Fallback: use all scenes from examples directory
all_scenes = extract_example_scenes_metadata("examples")
if all_scenes:
gallery_items = [
(scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
for scene in all_scenes
]
example_gallery = gr.Gallery(
value=gallery_items,
label="Example Scenes",
columns=1,
rows=None,
height=800,
object_fit="contain",
show_label=False,
interactive=True,
preview=False,
allow_preview=False,
elem_classes=["example-gallery"]
)
def handle_example_selection(evt: gr.SelectData):
if evt:
result = load_example_scenes(all_scenes[evt.index]["name"], all_scenes)
return result
return (None, None, None, None, "No scene selected")
example_gallery.select(
fn=handle_example_selection,
outputs=[
reconstruction_output,
gs_output,
output_path_state,
image_gallery,
log_output,
],
)
else:
# Tabbed interface for categorized examples
with gr.Tabs():
with gr.Tab("🌍 Realistic Cases"):
if realworld_scenes:
realworld_items = [
(scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
for scene in realworld_scenes
]
realworld_gallery = gr.Gallery(
value=realworld_items,
label="Real-world Examples",
columns=1,
rows=None,
height=750,
object_fit="contain",
show_label=False,
interactive=True,
preview=False,
allow_preview=False,
elem_classes=["example-gallery"]
)
def handle_realworld_selection(evt: gr.SelectData):
if evt:
result = load_example_scenes(realworld_scenes[evt.index]["name"], realworld_scenes)
return result
return (None, None, None, None, "No scene selected")
realworld_gallery.select(
fn=handle_realworld_selection,
outputs=[
reconstruction_output,
gs_output,
output_path_state,
image_gallery,
log_output,
],
)
else:
gr.Markdown("No real-world examples available")
with gr.Tab("🎨 Stylistic Cases"):
if generated_scenes:
generated_items = [
(scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
for scene in generated_scenes
]
generated_gallery = gr.Gallery(
value=generated_items,
label="Generated Examples",
columns=1,
rows=None,
height=750,
object_fit="contain",
show_label=False,
interactive=True,
preview=False,
allow_preview=False,
elem_classes=["example-gallery"]
)
def handle_generated_selection(evt: gr.SelectData):
if evt:
result = load_example_scenes(generated_scenes[evt.index]["name"], generated_scenes)
return result
return (None, None, None, None, "No scene selected")
generated_gallery.select(
fn=handle_generated_selection,
outputs=[
reconstruction_output,
gs_output,
output_path_state,
image_gallery,
log_output,
],
)
else:
gr.Markdown("No generated examples available")
# -------------------------------------------------------------------------
# Click logic
# -------------------------------------------------------------------------
reconstruct_btn.click(fn=clear_fields, inputs=[], outputs=[]).then(
fn=update_log, inputs=[], outputs=[log_output]
).then(
fn=gradio_demo,
inputs=[
output_path_state,
frame_selector,
show_camera,
filter_sky_bg,
show_mesh,
filter_ambiguous
],
outputs=[
reconstruction_output,
log_output,
frame_selector,
processed_data_state,
depth_map,
normal_map,
depth_view_slider,
normal_view_slider,
depth_view_info,
normal_view_info,
camera_params,
gs_output,
gs_rgb_video,
gs_depth_video,
terminal_output,
],
).then(
fn=lambda: "False",
inputs=[],
outputs=[is_example], # set is_example to "False"
)
# -------------------------------------------------------------------------
# Live update logic
# -------------------------------------------------------------------------
def refresh_3d_scene(
workspace_path,
frame_selector,
show_camera,
is_example,
filter_sky_bg=False,
show_mesh=False,
filter_ambiguous=False
):
"""
Refresh 3D scene visualization
Load prediction data from workspace, generate or reuse GLB scene files based on current parameters,
and return file paths needed for the 3D viewer.
Args:
workspace_path: Workspace directory path for reconstruction results
frame_selector: Frame selector value for filtering points from specific frames
show_camera: Whether to display camera positions
is_example: Whether this is an example scene
filter_sky_bg: Whether to filter sky background
show_mesh: Whether to display as mesh mode
filter_ambiguous: Whether to filter low-confidence ambiguous areas
Returns:
tuple: (GLB scene file path, Gaussian point cloud file path, status message)
"""
# If example scene is clicked, skip processing directly
if is_example == "True":
return (
gr.update(),
gr.update(),
"No reconstruction results available. Please click the Reconstruct button first.",
)
# Validate workspace directory path
if not workspace_path or workspace_path == "None" or not os.path.isdir(workspace_path):
return (
gr.update(),
gr.update(),
"No reconstruction results available. Please click the Reconstruct button first.",
)
# Check if prediction data file exists
prediction_file_path = os.path.join(workspace_path, "predictions.npz")
if not os.path.exists(prediction_file_path):
return (
gr.update(),
gr.update(),
f"Prediction file does not exist: {prediction_file_path}. Please run reconstruction first.",
)
# Load prediction data
prediction_data = np.load(prediction_file_path, allow_pickle=True)
predictions = {key: prediction_data[key] for key in prediction_data.keys() if key != 'splats'}
# Generate GLB scene file path (named based on parameter combination)
safe_frame_name = frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')
scene_filename = f"scene_{safe_frame_name}_cam{show_camera}_mesh{show_mesh}_edges{filter_ambiguous}_sky{filter_sky_bg}.glb"
scene_glb_path = os.path.join(workspace_path, scene_filename)
# If GLB file doesn't exist, generate new scene file
if not os.path.exists(scene_glb_path):
scene_model = convert_predictions_to_glb_scene(
predictions,
filter_by_frames=frame_selector,
show_camera=show_camera,
mask_sky_bg=filter_sky_bg,
as_mesh=show_mesh,
mask_ambiguous=filter_ambiguous
)
scene_model.export(file_obj=scene_glb_path)
# Find Gaussian point cloud file
gaussian_file_path = os.path.join(workspace_path, "gaussians.ply")
if not os.path.exists(gaussian_file_path):
gaussian_file_path = None
return (
scene_glb_path,
gaussian_file_path,
"3D scene updated.",
)
def refresh_view_displays_on_filter_update(
workspace_dir,
sky_background_filter,
current_processed_data,
depth_slider_position,
normal_slider_position,
):
"""
Refresh depth and normal view displays when filter settings change
When the background filter checkbox state changes, regenerate processed data and update all view displays.
This ensures that filter effects are reflected in real-time in the depth map and normal map visualizations.
Args:
workspace_dir: Workspace directory path containing prediction data and images
sky_background_filter: Sky background filter enable status
current_processed_data: Currently processed visualization data
depth_slider_position: Current position of the depth view slider
normal_slider_position: Current position of the normal view slider
Returns:
tuple: (updated processed data, depth visualization result, normal visualization result)
"""
# Validate workspace directory validity
if not workspace_dir or workspace_dir == "None" or not os.path.isdir(workspace_dir):
return current_processed_data, None, None
# Build and check prediction data file path
prediction_data_path = os.path.join(workspace_dir, "predictions.npz")
if not os.path.exists(prediction_data_path):
return current_processed_data, None, None
try:
# Load raw prediction data
raw_prediction_data = np.load(prediction_data_path, allow_pickle=True)
predictions_dict = {key: raw_prediction_data[key] for key in raw_prediction_data.keys()}
# Load image data using WorldMirror's load_images function
images_directory = os.path.join(workspace_dir, "images")
image_file_paths = [os.path.join(images_directory, path) for path in os.listdir(images_directory)]
img = load_and_preprocess_images(image_file_paths)
img = img.detach().cpu().numpy()
# Regenerate processed data with new filter settings
refreshed_data = {}
for view_idx in range(img.shape[1]):
view_data = {
"image": img[0, view_idx],
"points3d": predictions_dict["world_points"][view_idx],
"depth": None,
"normal": None,
"mask": None,
}
mask = predictions_dict["final_mask"][view_idx].copy()
if sky_background_filter:
sky_mask = predictions_dict["sky_mask"][view_idx]
mask = mask & sky_mask
view_data["mask"] = mask
view_data["depth"] = predictions_dict["depth"][view_idx].squeeze()
view_data["normal"] = predictions_dict["normal"][view_idx]
refreshed_data[view_idx] = view_data
# Get current view indices from slider positions (convert to 0-based indices)
current_depth_index = int(depth_slider_position) - 1 if depth_slider_position else 0
current_normal_index = int(normal_slider_position) - 1 if normal_slider_position else 0
# Update depth and normal views with new filter data
updated_depth_visualization = update_depth_view(refreshed_data, current_depth_index)
updated_normal_visualization = update_normal_view(refreshed_data, current_normal_index)
return refreshed_data, updated_depth_visualization, updated_normal_visualization
except Exception as error:
print(f"Error occurred while refreshing view displays: {error}")
return current_processed_data, None, None
frame_selector.change(
refresh_3d_scene,
[
output_path_state,
frame_selector,
show_camera,
is_example,
filter_sky_bg,
show_mesh,
filter_ambiguous
],
[reconstruction_output, gs_output, log_output],
)
show_camera.change(
refresh_3d_scene,
[
output_path_state,
frame_selector,
show_camera,
is_example,
filter_sky_bg,
show_mesh,
filter_ambiguous
],
[reconstruction_output, gs_output, log_output],
)
show_mesh.change(
refresh_3d_scene,
[
output_path_state,
frame_selector,
show_camera,
is_example,
filter_sky_bg,
show_mesh,
filter_ambiguous
],
[reconstruction_output, gs_output, log_output],
)
filter_sky_bg.change(
refresh_3d_scene,
[
output_path_state,
frame_selector,
show_camera,
is_example,
filter_sky_bg,
show_mesh,
filter_ambiguous
],
[reconstruction_output, gs_output, log_output],
).then(
fn=refresh_view_displays_on_filter_update,
inputs=[
output_path_state,
filter_sky_bg,
processed_data_state,
depth_view_slider,
normal_view_slider,
],
outputs=[
processed_data_state,
depth_map,
normal_map,
],
)
filter_ambiguous.change(
refresh_3d_scene,
[
output_path_state,
frame_selector,
show_camera,
is_example,
filter_sky_bg,
show_mesh,
filter_ambiguous
],
[reconstruction_output, gs_output, log_output],
).then(
fn=refresh_view_displays_on_filter_update,
inputs=[
output_path_state,
filter_sky_bg,
processed_data_state,
depth_view_slider,
normal_view_slider,
],
outputs=[
processed_data_state,
depth_map,
normal_map,
],
)
# -------------------------------------------------------------------------
# Auto update gallery when user uploads or changes files
# -------------------------------------------------------------------------
def update_gallery_on_file_upload(files, interval):
if not files:
return None, None, None, ""
# Capture terminal output
tee = TeeOutput()
old_stdout = sys.stdout
sys.stdout = tee
try:
target_dir, image_paths = process_uploaded_files(files, interval)
terminal_log = tee.getvalue()
sys.stdout = old_stdout
return (
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
terminal_log,
)
except Exception as e:
terminal_log = tee.getvalue()
sys.stdout = old_stdout
print(f"Error occurred: {e}")
raise
def resample_video_with_new_interval(files, new_interval, current_target_dir):
"""Resample video with new slider value"""
if not files:
return (
current_target_dir,
None,
"No files to resample.",
"",
)
# Check if we have videos to resample
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
has_video = any(
os.path.splitext(
str(file_data["name"] if isinstance(file_data, dict) else file_data)
)[1].lower()
in video_extensions
for file_data in files
)
if not has_video:
return (
current_target_dir,
None,
"No videos found to resample.",
"",
)
# Capture terminal output
tee = TeeOutput()
old_stdout = sys.stdout
sys.stdout = tee
try:
# Clean up old target directory if it exists
if (
current_target_dir
and current_target_dir != "None"
and os.path.exists(current_target_dir)
):
shutil.rmtree(current_target_dir)
# Process files with new interval
target_dir, image_paths = process_uploaded_files(files, new_interval)
terminal_log = tee.getvalue()
sys.stdout = old_stdout
return (
target_dir,
image_paths,
f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
terminal_log,
)
except Exception as e:
terminal_log = tee.getvalue()
sys.stdout = old_stdout
print(f"Error occurred: {e}")
raise
file_upload.change(
fn=update_gallery_on_file_upload,
inputs=[file_upload, time_interval],
outputs=[output_path_state, image_gallery, log_output, terminal_output],
)
resample_btn.click(
fn=resample_video_with_new_interval,
inputs=[file_upload, time_interval, output_path_state],
outputs=[output_path_state, image_gallery, log_output, terminal_output],
)
# -------------------------------------------------------------------------
# Navigation for Depth, Normal tabs
# -------------------------------------------------------------------------
def navigate_with_slider(processed_data, target_view):
"""Navigate to specified view using slider"""
if processed_data is None or len(processed_data) == 0:
return None, update_view_info(1, 1)
# Check if target_view is None or invalid value, and safely convert to int
try:
if target_view is None:
target_view = 1
else:
target_view = int(float(target_view)) # Convert to float first then int, handle decimal input
except (ValueError, TypeError):
target_view = 1
total_views = len(processed_data)
# Ensure view index is within valid range
view_index = max(1, min(target_view, total_views)) - 1
# Update depth map
depth_vis = update_depth_view(processed_data, view_index)
# Update view information
info_html = update_view_info(view_index + 1, total_views)
return depth_vis, info_html
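    # Clamp reference (illustrative): with 5 views the 1-based slider maps to
    # 0-based indices as 1 -> 0, 3 -> 2, 99 -> 4, and an invalid 0 clamps to 0.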
def navigate_with_slider_normal(processed_data, target_view):
"""Navigate to specified normal view using slider"""
if processed_data is None or len(processed_data) == 0:
return None, update_view_info(1, 1, "Normal")
# Check if target_view is None or invalid value, and safely convert to int
try:
if target_view is None:
target_view = 1
else:
target_view = int(float(target_view)) # Convert to float first then int, handle decimal input
except (ValueError, TypeError):
target_view = 1
total_views = len(processed_data)
# Ensure view index is within valid range
view_index = max(1, min(target_view, total_views)) - 1
# Update normal map
normal_vis = update_normal_view(processed_data, view_index)
# Update view information
info_html = update_view_info(view_index + 1, total_views, "Normal")
return normal_vis, info_html
def handle_depth_slider_change(processed_data, target_view):
return navigate_with_slider(processed_data, target_view)
def handle_normal_slider_change(processed_data, target_view):
return navigate_with_slider_normal(processed_data, target_view)
depth_view_slider.change(
fn=handle_depth_slider_change,
inputs=[processed_data_state, depth_view_slider],
outputs=[depth_map, depth_view_info]
)
normal_view_slider.change(
fn=handle_normal_slider_change,
inputs=[processed_data_state, normal_view_slider],
outputs=[normal_map, normal_view_info]
)
# -------------------------------------------------------------------------
# Real-time terminal output update
# -------------------------------------------------------------------------
# Use a timer to periodically update terminal output
timer = gr.Timer(value=0.5) # Update every 0.5 seconds
timer.tick(
fn=get_terminal_output,
inputs=[],
outputs=[terminal_output]
)
gr.HTML("""
""")
demo.queue().launch(
show_error=True,
share=True,
ssr_mode=False,
)