Is a 3D-Tokenized LLM the Key to Reliable Autonomous Driving?
Paper: arXiv:2405.18361
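A multimodal autonomous-driving large language model implementation based on the Atlas paper. 3D visual tokens extracted by StreamPETR (3D object detection) and TopoMLP (lane detection) are injected into the Vicuna-7B LLM, so that detection, lane detection, planning, and other tasks are all produced through unified text generation.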
3dtokenizer-atlas/
├── train_atlas.py                              # Atlas LLM training entry point
├── eval_atlas.py                               # Atlas evaluation entry point
├── extract_streampetr_tokens.py                # Pre-extract StreamPETR detection tokens
├── extract_topomlp_tokens.py                   # Pre-extract TopoMLP raw outputs (lane samples)
├── train_streampetr.sh                         # StreamPETR pretraining launch script
├── train_topomlp.sh                            # TopoMLP pretraining launch script
│
├── configs/
│   ├── streampetr_atlas_aligned.py             # StreamPETR config (EVA-02 ViT-L, 800x1600)
│   ├── topomlp_atlas_aligned.py                # TopoMLP config (EVA-02 ViT-L, 800x1600)
│   ├── ds_zero2.json                           # DeepSpeed ZeRO-2 config
│   └── REPRODUCTION.md                         # Reproduction notes
│
├── src/
│   ├── model/
│   │   ├── modeling_atlas.py                   # AtlasForCausalLM main model
│   │   ├── streampetr_adapter.py               # StreamPETR → detection token adapter
│   │   ├── topomlp_adapter.py                  # TopoMLP → map token adapter (Perceiver resampler)
│   │   └── token_resampler.py                  # CrossAttentionTokenResampler
│   ├── dataset/
│   │   ├── atlas_dataset.py                    # AtlasDataset + Collate
│   │   └── scene_sampler.py                    # SceneSequentialSampler (temporal sampling)
│   ├── eval/
│   │   └── metrics.py                          # Evaluation metrics (F1/Chamfer/L2/Collision)
│   └── prompting.py                            # Multi-task prompt templates
│
├── scripts/
│   ├── gen_atlas_full_data.py                  # nuScenes → detection QA JSON
│   ├── gen_atlas_openlane_subsetB_lane_qa.py   # OpenLane-V2 → lane QA JSON
│   └── gen_atlas_planning_qa.py                # nuScenes → planning QA JSON
│
├── data/                                       # Training/validation data (JSON)
│   ├── atlas_nuscenes_train.json               # Detection (28,130 samples)
│   ├── atlas_nuscenes_val.json                 # Detection validation (6,019 samples)
│   ├── openlane_subsetB_lane_train_4pt.json    # Lanes (27,968 samples, 4 points/lane)
│   ├── openlane_subsetB_lane_val_4pt.json      # Lane validation (6,019 samples)
│   ├── atlas_planning_train.json               # Planning (23,541 samples)
│   └── atlas_planning_val.json                 # Planning validation (5,037 samples)
│
├── pretrained/                                 # Pretrained weights
│   ├── vicuna-7b-v1.5/                         # Vicuna-7B-v1.5 LLM
│   ├── eva02_L_coco_det_sys_o365_remapped_fixed.pth
│   └── streampetr/
│       └── streampetr_eva02_ep24.pth
│
├── work_dirs/
│   ├── atlas_full_repro/                       # Current training output
│   ├── precomputed_det_tokens/                 # Pre-extracted StreamPETR tokens
│   │   └── train/                              # 56,098 .pt files (nuScenes + OpenLane)
│   ├── precomputed_map_tokens/                 # Pre-extracted TopoMLP raw outputs
│   │   └── train/                              # 27,968 .pt files (OpenLane lane samples)
│   └── topomlp_atlas_aligned/                  # TopoMLP pretrained weights
│       └── epoch_24.pth
│
└── external/                                   # External dependencies
    ├── StreamPETR/
    ├── TopoMLP_Repo/
    └── nuscenes-devkit/
                                 ┌─────────────────────────────────────
6x surround-view camera images → │ StreamPETR (frozen, EVA-02 ViT-L) ─→ det tokens [B, 256, 256]
                                 │ TopoMLP (frozen, EVA-02 ViT-L) ─→ lane queries → Resampler → map tokens [B, 256, 256]
                                 └─────────────────────────────────────
                                                  ↓
                                        AtlasUnifiedProjector
                                 ┌─────────────────────────────────┐
                                 │ projector_det: Linear(256→4096) │ ← single-layer linear projection
                                 │ projector_map: Linear(256→4096) │
                                 │ projector_rp: Linear(3→256)     │ ← Reference Point, zero-init
                                 │ features += projector_rp(ref)   │
                                 └─────────────────────────────────┘
                                                  ↓
                       Injected at <query> token positions (256 det + 256 map)
                                                  ↓
                                 ┌───────────────────────────────────────────────────┐
                                 │ Vicuna-7B (full-parameter fine-tuning, DeepSpeed) │
                                 │ Causal Language Modeling Loss                     │
                                 └───────────────────────────────────────────────────┘
                                                  ↓
                                        Multi-task text output
                           (3D detection / lane lines / planned trajectory)
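
The data flow in the diagram can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the code in `src/model/modeling_atlas.py`: the class name `UnifiedProjectorSketch`, the helper `inject_query_tokens`, the sharing of a single `projector_rp` between detection and map tokens, and the boolean-mask injection are all hypothetical. Only the layer names and sizes are taken from the diagram.

```python
import torch
import torch.nn as nn


class UnifiedProjectorSketch(nn.Module):
    """Hypothetical sketch: project 256-d 3D tokens into the LLM embedding space."""

    def __init__(self, token_dim: int = 256, llm_dim: int = 4096):
        super().__init__()
        self.projector_det = nn.Linear(token_dim, llm_dim)
        self.projector_map = nn.Linear(token_dim, llm_dim)
        # Reference-point embedding, zero-initialised so it contributes
        # nothing at the start of training.
        self.projector_rp = nn.Linear(3, token_dim)
        nn.init.zeros_(self.projector_rp.weight)
        nn.init.zeros_(self.projector_rp.bias)

    def forward(self, det, det_ref, map_tok, map_ref):
        # det, map_tok: [B, N, token_dim]; det_ref, map_ref: [B, N, 3]
        det = det + self.projector_rp(det_ref)          # features += projector_rp(ref)
        map_tok = map_tok + self.projector_rp(map_ref)  # assumption: same projector_rp
        return torch.cat(
            [self.projector_det(det), self.projector_map(map_tok)], dim=1
        )  # [B, N_det + N_map, llm_dim]


def inject_query_tokens(inputs_embeds, query_mask, visual_tokens):
    """Overwrite the embeddings at <query> positions with visual tokens.

    inputs_embeds: [B, T, D]; query_mask: [B, T] bool with exactly N True
    entries per row; visual_tokens: [B, N, D].
    """
    out = inputs_embeds.clone()
    flat = visual_tokens.reshape(-1, visual_tokens.shape[-1]).to(out.dtype)
    out[query_mask] = flat  # row-major order matches the per-sample token order
    return out
```

With the sizes from the diagram, `forward` returns a `[B, 512, 4096]` tensor (256 detection plus 256 map tokens), which is then written into the 512 `<query>` placeholder positions of the prompt embeddings before the sequence is passed to the LLM.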
Aligned with Appendix B.2 of the paper (arXiv:2405.18361).
| Parameter | Value |
|---|---|
| LLM | Vicuna-7B-v1.5 |
| Fine-tuning | Full-parameter fine-tuning (no LoRA) |
| Trainable parameters | 6,740,530,176 |
| Learning Rate | 2e-5 |
| Optimizer | AdamW (weight_decay=1e-4, torch_adam, adam_w_mode) |
| LR Schedule | Cosine with warmup (3% of steps) |
| Epochs | 8 |
| Batch Size | 1 per GPU |
| Gradient Accumulation | 1 |
| Effective Batch Size | 8 (8 GPUs x 1 x 1 accum) |
| Total Steps | 79,639 |
| Warmup Steps | 2,389 |
| Max Sequence Length | 4096 tokens |
| Distributed training | DeepSpeed ZeRO-2 (optimizer sharding) |
| GPU | 8x NVIDIA A100 80GB |
| Precision | BF16 (model + gradients, via DeepSpeed bf16), optimizer states sharded |
| Memory Queue | StreamPETR temporal modeling (3 frames, top-256, FIFO) |
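
The step counts in the table can be checked with simple arithmetic. The sketch below assumes that total steps are computed as samples x epochs / effective batch size, using the 79,639 training samples summed over the three tasks, and that the schedule is a linear warmup followed by cosine decay to zero. The function name `lr_at` is illustrative and is not taken from `train_atlas.py`.

```python
import math

NUM_SAMPLES = 79_639          # detection + lane + planning training samples
EPOCHS = 8
EFFECTIVE_BATCH = 8 * 1 * 1   # GPUs x per-GPU batch size x gradient accumulation
BASE_LR = 2e-5
WARMUP_RATIO = 0.03

total_steps = NUM_SAMPLES * EPOCHS // EFFECTIVE_BATCH
warmup_steps = int(total_steps * WARMUP_RATIO)


def lr_at(step: int) -> float:
    """Learning rate at a given optimizer step (assumed schedule shape)."""
    if step < warmup_steps:
        # Linear warmup from 0 to BASE_LR.
        return BASE_LR * step / max(1, warmup_steps)
    # Cosine decay from BASE_LR to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


print(total_steps, warmup_steps)  # 79639 2389
```

Under these assumptions the computed values match the "Total Steps" and "Warmup Steps" rows above.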

| Task | Data file | Samples |
|---|---|---|
| 3D object detection | atlas_nuscenes_train.json | 28,130 |
| 3D lane detection | openlane_subsetB_lane_train_4pt.json | 27,968 |
| Trajectory planning | atlas_planning_train.json | 23,541 |
| Total | | 79,639 |
Lane data uses 4 uniformly sampled points per lane (consistent with Appendix A.2 of the paper). All coordinates are discretized into 1000 bins over a BEV range of [-50 m, +50 m].
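
As an illustration of the 1000-bin discretization, here is a minimal sketch. It assumes uniform bins of 0.1 m, clamping of out-of-range values, and decoding to the bin centre; the function names `coord_to_bin` and `bin_to_coord` are hypothetical, and the repository's data-generation scripts may round differently.

```python
NUM_BINS = 1000
BEV_MIN, BEV_MAX = -50.0, 50.0  # metres


def coord_to_bin(value_m: float) -> int:
    """Map a metric coordinate to an integer bin index in [0, NUM_BINS - 1]."""
    clamped = min(max(value_m, BEV_MIN), BEV_MAX)
    ratio = (clamped - BEV_MIN) / (BEV_MAX - BEV_MIN)
    # The upper edge (+50 m) would land on index 1000, so cap it at 999.
    return min(int(ratio * NUM_BINS), NUM_BINS - 1)


def bin_to_coord(index: int) -> float:
    """Return the centre of a bin, in metres."""
    width = (BEV_MAX - BEV_MIN) / NUM_BINS
    return BEV_MIN + (index + 0.5) * width
```

With 1000 bins over 100 m, each bin is 0.1 m wide, so a round trip through `coord_to_bin` and `bin_to_coord` changes an in-range coordinate by at most about 0.05 m.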
| Parameter | StreamPETR | TopoMLP |
|---|---|---|
| Backbone | EVA-02 ViT-L (embed_dim=1024) | EVA-02 ViT-L (embed_dim=1024) |
| Resolution | 800x1600 | 800x1600 |
| Queries | 256 (detection) | 256 (map, resampled from 1800) |
| Control Points | - | 4 per lane |
| Epochs | 24 | 24 |
| Dataset | nuScenes trainval | OpenLane-V2 subset-B |
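
The "resampled from 1800" entry refers to compressing the raw TopoMLP outputs into 256 map tokens. A Perceiver-style resampler typically does this with a fixed set of learnable latent queries that cross-attend to the input tokens. The sketch below shows that general technique only; the class name, the single-layer depth, the head count, and the feed-forward width are assumptions and do not describe the actual `CrossAttentionTokenResampler` in `src/model/token_resampler.py`.

```python
import torch
import torch.nn as nn


class CrossAttentionResamplerSketch(nn.Module):
    """Hypothetical sketch: compress N_in input tokens into a fixed number of latents."""

    def __init__(self, dim: int = 256, num_latents: int = 256, num_heads: int = 8):
        super().__init__()
        # Learnable queries; one per output token.
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, N_in, dim], e.g. N_in = 1800 raw lane outputs
        q = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)
        kv = self.norm_kv(x)
        attended, _ = self.attn(self.norm_q(q), kv, kv, need_weights=False)
        h = q + attended          # residual around cross-attention
        return h + self.ffn(h)    # [B, num_latents, dim]
```

Because the number of output tokens is set by the latent queries, the output shape stays `[B, 256, 256]` regardless of how many input tokens are provided.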
conda activate streampetr
# Main dependencies: PyTorch 2.0+, transformers, peft, flash-attn, mmcv 1.7, mmdet3d 1.0
# DeepSpeed (ZeRO-2): pip install deepspeed
# nuScenes data root (contains v1.0-trainval/ and samples/)
export DATA_ROOT=/path/to/nuscenes
# OpenLane-V2 subset-B
export OPENLANE_ROOT=/path/to/OpenLane-V2/subset_B
# Generate lane QA data (4 points/lane, consistent with the paper)
python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
--openlane_root $OPENLANE_ROOT \
--split train --out_json data/openlane_subsetB_lane_train_4pt.json
python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
--openlane_root $OPENLANE_ROOT \
--split val --out_json data/openlane_subsetB_lane_val_4pt.json
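
For reference, reducing a lane centerline to 4 uniformly sampled points can be sketched as below. This assumes "uniform" means equal spacing by arc length along the polyline; the function `resample_polyline` is hypothetical and is not the code in `gen_atlas_openlane_subsetB_lane_qa.py`, which may sample by index instead.

```python
import math


def resample_polyline(points, num_points=4):
    """Resample a polyline to `num_points` points equally spaced by arc length."""
    if len(points) < 2 or num_points < 2:
        raise ValueError("need at least 2 input points and 2 output points")
    seg = [math.dist(a, b) for a, b in zip(points, points[1:])]
    total = sum(seg)
    if total == 0:
        return [tuple(float(c) for c in points[0])] * num_points

    out = []
    idx, walked = 0, 0.0  # walked = arc length at the start of segment idx
    for k in range(num_points):
        target = total * k / (num_points - 1)
        # Advance to the segment that contains the target arc length.
        while idx < len(seg) - 1 and walked + seg[idx] < target:
            walked += seg[idx]
            idx += 1
        t = (target - walked) / seg[idx] if seg[idx] > 0 else 0.0
        t = min(max(t, 0.0), 1.0)
        a, b = points[idx], points[idx + 1]
        out.append(tuple(pa + t * (pb - pa) for pa, pb in zip(a, b)))
    return out


lane = [(0, 0, 0), (3, 0, 0), (9, 0, 0)]
print(resample_polyline(lane))
# [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (6.0, 0.0, 0.0), (9.0, 0.0, 0.0)]
```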
# Full-parameter fine-tuning on 8x A100 with DeepSpeed ZeRO-2
# Run extract_streampetr_tokens.py and extract_topomlp_tokens.py first to pre-extract the tokens
# Effective batch size = 8 GPUs × 1 × 1 accum = 8 (consistent with the paper)
torchrun --nproc_per_node=8 train_atlas.py \
--llm_model pretrained/vicuna-7b-v1.5 \
--streampetr_config configs/streampetr_atlas_aligned.py \
--streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
--topomlp_config configs/topomlp_atlas_aligned.py \
--topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
--precomputed_det_tokens work_dirs/precomputed_det_tokens/train \
--precomputed_map_tokens work_dirs/precomputed_map_tokens/train \
--data_json data/atlas_nuscenes_train.json,data/atlas_planning_train.json,data/openlane_subsetB_lane_train_4pt.json \
--data_root /mnt/data/nuscenes \
--image_path_remap /home/guoyuanbo/autodl-tmp/OpenLane-V2=/mnt/OpenLane-V2 \
--output_dir work_dirs/atlas_full_repro \
--lr 2e-5 --weight_decay 1e-4 \
--batch_size 1 --epochs 8 --gradient_accumulation_steps 1 \
--warmup_ratio 0.03 --max_grad_norm 1.0 \
--save_epochs 2 --log_steps 100 \
--seed 42 --num_workers 2 \
--deepspeed configs/ds_zero2.json
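
The contents of `configs/ds_zero2.json` are not reproduced in this README. As a rough guide, a ZeRO-2 configuration consistent with the hyperparameter table might look like the sketch below, written as a Python dict and serialized to JSON. The key names follow the DeepSpeed configuration schema, but the values are assumptions inferred from the table and the training command, not a copy of the actual file.

```python
import json

# Assumed values, inferred from the hyperparameter table and CLI flags.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,            # mirrors --max_grad_norm 1.0
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},   # ZeRO stage 2
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 2e-5,
            "weight_decay": 1e-4,
            "torch_adam": True,
            "adam_w_mode": True,
        },
    },
}

config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

If the training script constructs its own optimizer from `--lr` and `--weight_decay`, the `optimizer` section would be omitted from the JSON file.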
python eval_atlas.py \
--checkpoint work_dirs/atlas_full_repro/final/checkpoint.pt \
--llm_model pretrained/vicuna-7b-v1.5 \
--topomlp_config configs/topomlp_atlas_aligned.py \
--topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
--data_json data/openlane_subsetB_lane_val_4pt.json \
--data_root $DATA_ROOT \
--batch_size 1 --max_new_tokens 512 --no_flash_attn
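
Lane evaluation relies on a Chamfer distance between predicted and ground-truth lane points (see `src/eval/metrics.py`). Definitions of Chamfer distance vary (sum or mean, squared or unsquared), so the following is only a sketch of one common variant, the symmetric mean nearest-neighbour distance; the function `chamfer_distance` is hypothetical and may not match the repository's metric.

```python
import math


def chamfer_distance(pred, gt):
    """Symmetric Chamfer distance between two point sets, in the input units.

    Each direction averages the distance from every point in one set to its
    nearest neighbour in the other; the two directions are then averaged.
    """
    if not pred or not gt:
        raise ValueError("both point sets must be non-empty")

    def one_way(src, dst):
        return sum(min(math.dist(p, q) for q in dst) for p in src) / len(src)

    return 0.5 * (one_way(pred, gt) + one_way(gt, pred))


pred_lane = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
gt_lane = [(0.0, 1.0, 0.0), (10.0, 1.0, 0.0)]
print(chamfer_distance(pred_lane, gt_lane))  # 1.0
```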