# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

FlashWorld is a high-quality 3D scene generation system that creates 3D scenes from text or image prompts in ~7 seconds on a single A100/A800 GPU. The project uses diffusion-based transformers with Gaussian Splatting for 3D reconstruction.

**Key capabilities:**
- Fast 3D scene generation (~7 seconds on a single A100/A800)
- Text-to-3D and Image-to-3D generation
- Runs on GPUs with as little as 24 GB of memory
- Outputs 3D Gaussian Splatting (.ply) files

## Running the Application

### Local Demo (Flask + Custom UI)
```bash
python app.py --port 7860 --gpu 0 --cache_dir ./tmpfiles --max_concurrent 1
```

Access the web interface at `http://HOST_IP:7860`.

**Important flags:**
- `--offload_t5`: Offload text encoding to CPU to reduce GPU memory (trades speed for memory)
- `--ckpt`: Path to custom checkpoint (auto-downloads from HuggingFace if not provided)
- `--max_concurrent`: Maximum concurrent generation tasks (default: 1)
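
For example, a lower-memory run that offloads the text encoder and loads a local checkpoint might look like this (the checkpoint path is hypothetical):

```bash
python app.py --port 7860 --gpu 0 --offload_t5 --ckpt ./checkpoints/flashworld.pt
```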

### ZeroGPU Demo (Gradio)
```bash
python app_gradio.py
```

**ZeroGPU Configuration:**
- Uses `@spaces.GPU(duration=15)` decorator with 15-second GPU budget
- Model loading happens **outside** the GPU decorator scope (in global scope)
- Gradio 5.49.1+ required
- Compatible with Hugging Face Spaces ZeroGPU hardware
- Automatically downloads model checkpoint from HuggingFace Hub
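
A minimal sketch of this pattern (`load_generation_system` is a placeholder, not the actual API; the real wiring in app_gradio.py differs in detail):

```python
import spaces   # provided on Hugging Face Spaces ZeroGPU hardware
import gradio as gr

# Heavy setup runs at import time, outside any @spaces.GPU scope, so no
# GPU budget is spent on model loading. `load_generation_system` stands in
# for the real GenerationSystem construction.
system = load_generation_system()

@spaces.GPU(duration=15)   # a GPU is attached only while this call runs
def generate_scene(image, text, cameras_json, resolution,
                   progress=gr.Progress()):
    progress(0.0, desc="Generating...")
    return system.generate(image, text, cameras_json, resolution)
```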

### Installation
Dependencies are in `requirements.txt`. Key packages:
- PyTorch 2.6.0 with CUDA support
- Custom gsplat build, pinned to a specific commit
- Custom diffusers build, pinned to a specific commit

Install with:
```bash
pip install -r requirements.txt
```

## Architecture

### Core Components

**GenerationSystem** (app.py:90-346)
- Main neural network system combining VAE, text encoder, transformer, and 3D reconstruction
- Key submodules:
  - `vae`: AutoencoderKLWan for image encoding/decoding (from Wan2.2-TI2V-5B model)
  - `text_encoder`: UMT5 for text embedding
  - `transformer`: WanTransformer3DModel for diffusion denoising
  - `recon_decoder`: WANDecoderPixelAligned3DGSReconstructionModel for 3D Gaussian Splatting reconstruction
- Uses a flow matching scheduler with 4 denoising steps
- Implements a feedback mechanism in which each step's 3D prediction informs the next denoising step (see the sketch after the pipeline below)

**Key Generation Pipeline:**
1. Text/image prompt → text embeddings + optional image latents
2. Create raymaps from camera parameters (6DOF)
3. Iterative denoising with 3D feedback loop (4 steps at timesteps [0, 250, 500, 750])
4. Final prediction → 3D Gaussian parameters → render to images
5. Export to PLY file format
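
A hedged sketch of the feedback loop in steps 3-4, with toy stand-ins so it runs end to end (the real transformer, decoder, and scheduler live in app.py; channel counts for raymaps and text embeddings are illustrative):

```python
import torch

def transformer(x, t, text_emb):
    return x[:, :48]                     # stub: predict a 48-channel latent

def recon_decoder(latent):
    return latent                        # stub: latent -> 3DGS parameters

def scheduler_step(pred, t, latents):
    return pred                          # stub: flow-matching update

timesteps = [0, 250, 500, 750]           # the 4 denoising steps
latents = torch.randn(1, 48, 6, 30, 44)  # noise at [B, 48, T/4, H/16, W/16]
raymaps = torch.zeros_like(latents)      # camera raymaps from step 2
feedback = torch.zeros_like(latents)     # 3D feedback, empty before step one
text_emb = torch.zeros(1, 512, 1024)     # UMT5 text embeddings

for t in timesteps:
    model_in = torch.cat([latents, raymaps, feedback], dim=1)
    pred = transformer(model_in, t, text_emb)
    feedback = recon_decoder(pred)               # previous 3D prediction...
    latents = scheduler_step(pred, t, latents)   # ...informs the next step
```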

### Model Files

**models/transformer_wan.py**
- 3D transformer for video diffusion (adapted from Wan2.2 model)
- Handles temporal + spatial attention with RoPE (Rotary Position Embeddings)

**models/reconstruction_model.py**
- `WANDecoderPixelAligned3DGSReconstructionModel`: Converts latent features to 3D Gaussian parameters
- `PixelAligned3DGS`: Per-pixel Gaussian parameter prediction
- Outputs: positions (xyz), opacity, scales, rotations, SH features

**models/autoencoder_kl_wan.py**
- VAE for image encoding/decoding (WAN architecture)
- Custom 3D causal convolutions adapted for single-frame processing

**models/render.py**
- Gaussian Splatting rasterization using gsplat library

**utils.py**
- Camera utilities: normalize_cameras, create_rays, create_raymaps
- Quaternion operations: quaternion_to_matrix, matrix_to_quaternion, quaternion_slerp
- Camera interpolation: sample_from_dense_cameras, sample_from_two_pose
- Export: export_ply_for_gaussians
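
For reference, a standalone quaternion slerp matching the real-first convention used here (a generic implementation, not the actual `quaternion_slerp` code):

```python
import torch

def slerp(q0: torch.Tensor, q1: torch.Tensor, t: float) -> torch.Tensor:
    """Spherical interpolation between two unit quaternions [w, x, y, z]."""
    q0, q1 = q0 / q0.norm(), q1 / q1.norm()
    dot = (q0 * q1).sum()
    if dot < 0:                      # take the shorter arc on the 4-sphere
        q1, dot = -q1, -dot
    if dot > 0.9995:                 # nearly parallel: lerp and renormalize
        q = q0 + t * (q1 - q0)
        return q / q.norm()
    theta = torch.arccos(dot.clamp(-1.0, 1.0))
    return (torch.sin((1 - t) * theta) * q0
            + torch.sin(t * theta) * q1) / torch.sin(theta)
```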

### Gradio Interface (app_gradio.py)

**ZeroGPU Integration:**
- Model initialized in global scope (outside @spaces.GPU decorator)
- `generate_scene()` function decorated with `@spaces.GPU(duration=15)`
- Accepts image prompts (PIL), text prompts, camera JSON, and resolution
- Returns PLY file and status message
- Uses Gradio Progress API for user feedback

**Input Format:**
- Image: PIL Image (optional)
- Text: String prompt (optional)
- Camera JSON: Array of camera dictionaries with `quaternion`, `position`, `fx`, `fy`, `cx`, `cy`
- Resolution: String format "NxHxW" (e.g., "24x480x704")
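
A minimal, hypothetical camera JSON with a single entry (identity rotation at the origin; the intrinsic values are illustrative):

```json
[
  {
    "quaternion": [1.0, 0.0, 0.0, 0.0],
    "position": [0.0, 0.0, 0.0],
    "fx": 1.2, "fy": 1.2,
    "cx": 0.5, "cy": 0.5
  }
]
```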

### Flask API (app.py - Local Only)

**Concurrency Management** (concurrency_manager.py)
- Thread-pool based task queue for handling multiple generation requests
- Task states: QUEUED → RUNNING → COMPLETED/FAILED
- Automatic cleanup of old cached files (30-minute TTL)

**API Endpoints:**
- `POST /generate`: Submit generation task (returns task_id immediately)
- `GET /task/<task_id>`: Poll task status and get results
- `GET /download/<file_id>`: Download generated PLY file
- `DELETE /delete/<file_id>`: Clean up generated files
- `GET /status`: Get queue status
- `GET /`: Serve web interface (index.html)

**Request Format:**
```json
{
  "image_prompt": "<base64 or path>",  // optional
  "text_prompt": "...",
  "cameras": [{"quaternion": [...], "position": [...], "fx": ..., "fy": ..., "cx": ..., "cy": ...}],
  "resolution": [n_frames, height, width],
  "image_index": 0  // which frame to condition on
}
```
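
An illustrative request flow with curl (the response bodies shown are assumptions based on the endpoint descriptions above):

```bash
# Submit a task; the server returns a task_id immediately
curl -X POST http://HOST_IP:7860/generate \
     -H "Content-Type: application/json" \
     -d @request.json
# e.g. {"task_id": "abc123"}

# Poll until the task reports COMPLETED, then download the PLY
curl http://HOST_IP:7860/task/abc123
curl -O http://HOST_IP:7860/download/<file_id>
```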

### Camera System

Cameras are represented as 11D vectors: `[qw, qx, qy, qz, tx, ty, tz, fx, fy, cx, cy]`
- First 4: quaternion rotation (real-first convention)
- Next 3: translation
- Last 4: intrinsics (normalized by image dimensions)
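
For example, an identity-rotation camera at the origin might be encoded as follows (intrinsic values are illustrative):

```python
import numpy as np

camera = np.array([
    1.0, 0.0, 0.0, 0.0,  # qw, qx, qy, qz (real-first quaternion, identity)
    0.0, 0.0, 0.0,       # tx, ty, tz (camera at the origin)
    1.2, 1.2,            # fx, fy (normalized by image width/height)
    0.5, 0.5,            # cx, cy (principal point at the image center)
])
```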

**Camera normalization** (utils.py:269-296):
- Centers scene around first camera
- Normalizes translation scale based on max camera distance
- Critical for stable 3D generation
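
A hedged sketch of the positional part of this normalization (the actual utils.py code also handles rotations and the full 11D format):

```python
import numpy as np

def normalize_positions(positions: np.ndarray) -> np.ndarray:
    """positions: (N, 3) camera centers; the first camera becomes the origin."""
    centered = positions - positions[0]              # center on first camera
    scale = np.linalg.norm(centered, axis=1).max()   # max camera distance
    return centered / max(scale, 1e-8)               # unit max distance
```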

## Development Notes

### Memory Management
- The transformer uses FP8 quantization (quant.py) to reduce memory
- The text encoder and VAE can be offloaded to CPU with the `--offload_t5` and `--offload_vae` flags, respectively (sketched below)
- A checkpointing mechanism for the decoder reduces memory during training
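
A hedged sketch of the offload pattern behind `--offload_t5` (assuming the encoder returns a tensor; the actual handling lives in app.py):

```python
import torch

def encode_prompt(text_encoder, prompt_ids, device="cuda"):
    text_encoder.to("cpu")              # weights stay in system RAM
    with torch.no_grad():
        emb = text_encoder(prompt_ids)  # runs on CPU: slower, but saves VRAM
    return emb.to(device)               # only the small embedding hits the GPU
```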

### Key Constants
- Latent dimension: 48 channels
- Temporal downsample: 4x
- Spatial downsample: 16x
- Feature dimension: 1024 channels
- Latent patch size: 2
- Denoising timesteps: [0, 250, 500, 750]
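
Putting these together, a `"24x480x704"` request maps to a latent tensor as follows (assuming the downsample factors apply exactly as stated):

```python
n_frames, height, width = 24, 480, 704  # from the "24x480x704" resolution string
latent_shape = (48, n_frames // 4, height // 16, width // 16)
print(latent_shape)                      # (48, 6, 30, 44)
```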

### Model Weights
- The primary checkpoint auto-downloads from Hugging Face: `imlixinyang/FlashWorld`
- Base diffusion model: `Wan-AI/Wan2.2-TI2V-5B-Diffusers`
- The base model is adapted with additional input/output channels for 3D features
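
To fetch the weights manually (app.py does this automatically; `snapshot_download` is the generic Hugging Face Hub call, not necessarily the one app.py uses):

```python
from huggingface_hub import snapshot_download

local_dir = snapshot_download("imlixinyang/FlashWorld")  # downloads the repo files
```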

### Rendering
- Uses gsplat 1.5.2 for differentiable Gaussian Splatting
- SH degree: 2 (spherical harmonics up to degree 2 for view-dependent color)
- Background modes: 'white', 'black', 'random'
- Output FPS: 15
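
A hedged sketch of a gsplat render call over toy Gaussians (argument names follow gsplat's `rasterization()` API; the exact call in models/render.py may differ, e.g. in how SH features are passed):

```python
import torch
from gsplat import rasterization

N = 1000  # toy scene: random Gaussians
means = torch.randn(N, 3, device="cuda")
quats = torch.nn.functional.normalize(torch.randn(N, 4, device="cuda"), dim=-1)
scales = torch.rand(N, 3, device="cuda") * 0.02
opacities = torch.rand(N, device="cuda")
colors = torch.rand(N, 3, device="cuda")      # plain RGB here, not SH, for simplicity
viewmats = torch.eye(4, device="cuda")[None]  # world-to-camera, [1, 4, 4]
Ks = torch.tensor([[[300.0, 0.0, 352.0],
                    [0.0, 300.0, 240.0],
                    [0.0, 0.0, 1.0]]], device="cuda")

rgb, alpha, _ = rasterization(
    means, quats, scales, opacities, colors,
    viewmats, Ks, width=704, height=480,
    backgrounds=torch.ones(1, 3, device="cuda"),  # the 'white' background mode
)
```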

## License

CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) - Academic research use only.