Atlas — 3D-Tokenized LLM for Autonomous Driving

基于 Atlas 论文 的多模态自动驾驶大语言模型实现。将 StreamPETR(3D 目标检测)和 TopoMLP(车道线检测)提取的 3D visual tokens 注入 Vicuna-7B LLM,实现检测、车道线、规划等多任务统一生成。

项目结构

3dtokenizer-atlas/
├── train_atlas.py                  # Atlas LLM 训练入口
├── eval_atlas.py                   # Atlas 评估入口
├── extract_streampetr_tokens.py    # 预提取 StreamPETR detection tokens
├── extract_topomlp_tokens.py      # 预提取 TopoMLP lane tokens
├── train_streampetr.sh             # StreamPETR 预训练启动脚本
├── train_topomlp.sh                # TopoMLP 预训练启动脚本
│
├── configs/
│   ├── streampetr_atlas_aligned.py # StreamPETR 配置 (EVA-02 ViT-L, 800x1600)
│   ├── topomlp_atlas_aligned.py    # TopoMLP 配置 (EVA-02 ViT-L, 800x1600)
│   ├── ds_zero2.json               # DeepSpeed ZeRO-2 配置
│   └── REPRODUCTION.md             # 复现文档
│
├── src/
│   ├── model/
│   │   ├── modeling_atlas.py       # AtlasForCausalLM 主模型
│   │   ├── streampetr_adapter.py   # StreamPETR → 检测 token 适配器
│   │   └── topomlp_adapter.py      # TopoMLP → 地图 token 适配器 (Top-K selection)
│   ├── dataset/
│   │   ├── atlas_dataset.py        # AtlasDataset + Collate
│   │   └── scene_sampler.py        # SceneSequentialSampler (时序采样)
│   ├── eval/
│   │   └── metrics.py              # 评估指标 (F1/Chamfer/L2/Collision)
│   └── prompting.py                # 多任务 Prompt 模板
│
├── scripts/
│   ├── gen_atlas_full_data.py               # nuScenes → 检测 QA JSON
│   ├── gen_atlas_openlane_subsetB_lane_qa.py # OpenLane-V2 → 车道线 QA JSON
│   ├── gen_atlas_planning_qa.py             # nuScenes → 规划 QA JSON
│   ├── train_no_caption_baseline.sh         # 无 caption 训练脚本
│   └── train_with_caption_balanced.sh       # 含 caption 训练脚本
│
├── data/                                    # 训练/验证数据 (JSON)
│   ├── atlas_nuscenes_train.json            # 检测 (28,130 样本)
│   ├── atlas_nuscenes_val.json              # 检测验证 (6,019 样本)
│   ├── openlane_subsetB_lane_train_4pt.json # 车道线 (27,968 样本, 4 点/lane)
│   ├── openlane_subsetB_lane_val_4pt.json   # 车道线验证 (6,019 样本)
│   ├── atlas_planning_train_uniad_command.json  # 规划 (23,541 样本, UniAD-style command)
│   ├── atlas_planning_val_uniad_command.json    # 规划验证 (5,037 样本, UniAD-style command)
│   ├── atlas_caption_train.json             # 环境描述 caption
│   └── atlas_caption_val.json               # 环境描述 caption 验证
│
├── pretrained/                     # 预训练权重
│   ├── vicuna-7b-v1.5/            # Vicuna-7B-v1.5 LLM
│   ├── eva02_L_coco_det_sys_o365_remapped_fixed.pth
│   └── streampetr/
│       └── streampetr_eva02_ep24.pth
│
├── work_dirs/
│   ├── _quick_eval_cpu.py          # 快速检测评估 (CPU, micro-avg F1)
│   ├── _quick_eval_gpu.py          # 快速检测评估 (GPU)
│   ├── _quick_eval_lane_gpu.py     # 快速车道线评估 (GPU)
│   ├── _quick_eval_plan_gpu.py     # 快速规划评估 (GPU, scene-sequential)
│   ├── precomputed_det_tokens_offline/  # 预提取的 StreamPETR tokens (offline 备选)
│   │   ├── train/                  # 56,099 个 .pt 文件 (det + lane,planning 与 det 共享 ID)
│   │   └── val/                    # 12,039 个 .pt 文件
│   ├── precomputed_map_tokens_offline/  # 预提取的 TopoMLP tokens (offline 备选)
│   │   ├── train/                  # 51,510 个 .pt 文件 (lane + planning)
│   │   └── val/                    # 11,057 个 .pt 文件
│   └── topomlp_atlas_aligned/     # TopoMLP 预训练权重
│       └── epoch_24.pth
│
└── external/                       # 外部依赖
    ├── StreamPETR/
    ├── TopoMLP_Repo/
    └── nuscenes-devkit/

模型架构

                   ┌─────────────────────────────────────┐
  6x 环视相机图片 → │ StreamPETR (frozen, EVA-02 ViT-L)    │→ det tokens [B, 256, 256]
                   │ TopoMLP   (frozen, EVA-02 ViT-L)    │→ one-to-one lane queries (300) → Top-K → map tokens [B, 256, 256]
                   └─────────────────────────────────────┘
                                    ↓
                         AtlasUnifiedProjector
                     ┌────────────────────────────────┐
                     │ projector_det: Linear(256→4096) │  ← 单层线性投影
                     │ projector_map: Linear(256→4096) │
                     │ projector_rp:  Linear(3→256)    │  ← Reference Point, zero-init
                     │ features += projector_rp(ref)   │
                     └────────────────────────────────┘
                                    ↓
                    注入到 <query> token 位置 (256 det + 256 map)
                                    ↓
                   ┌───────────────────────────────────────┐
                   │   Vicuna-7B (当前运行: 全参数微调)      │
                   │   Causal Language Modeling Loss        │
                   └───────────────────────────────────────┘
                                    ↓
                         多任务文本输出
              (3D 检测 / 车道线 / 规划轨迹)

训练配置(来源与优先级)

为避免“论文描述 / 脚本默认 / 实际运行”混淆,本仓库统一按以下优先级解释配置:

  1. 论文描述:以 Atlas 原论文(arXiv:2405.18361)正文 Section 3.2 与 Appendix B.2 为准。
  2. 代码默认:以 train_atlas.py 的 argparse 默认值为准。
  3. 实际运行:以 work_dirs/<exp>/args.json + work_dirs/*train*.log 为准(最高优先级)。

若三者冲突,请按第 3 条解释实验结果,不以 README 中示例命令覆盖真实运行参数。

论文原文(Atlas LLM,Section 3.2 + Appendix B.2)

项目 论文描述
Batch Size 1 per GPU
Learning Rate 2e-5
Optimizer AdamW (weight_decay=1e-4)
LR Schedule 3% linear warmup + cosine decay
Max Length 4096
硬件/时长 8x A100,约 100 小时
LoRA 论文附录未显式给出 LoRA 开关

当前实际运行(示例:work_dirs/atlas_no_caption_v3_linear_warmup

来源:work_dirs/atlas_no_caption_v3_linear_warmup/args.jsonwork_dirs/train_no_caption_v3_linear_warmup.log

参数 实际值
LLM Vicuna-7B-v1.5
微调方式 全参数微调 (use_lora=false)
可训练参数 6,740,530,176
Learning Rate 2e-5
Optimizer AdamW (weight_decay=1e-4, torch_adam, adam_w_mode)
LR Schedule WarmupCosineLR(warmup ratio=3%)
Epochs 10
Batch Size 1 per GPU
Gradient Accumulation 2
Effective Batch Size 8 (4 GPU x 1 x 2 accum)
Total Steps 99,550
Warmup Steps 2,986
Max Sequence Length 4096
分布式 DeepSpeed ZeRO-2
GPU 4x NVIDIA H100 80GB
精度 BF16(由 configs/ds_zero2.json 启用)
Visual Tokens 在线 (live frozen StreamPETR + TopoMLP, temporal memory);离线预提取仅作为 fallback

训练数据

任务 数据文件 样本数
3D 目标检测 atlas_nuscenes_train.json 28,130
3D 车道线检测 openlane_subsetB_lane_train_4pt.json 27,968
轨迹规划 atlas_planning_train_uniad_command.json 23,541
环境描述 (可选) atlas_caption_train.json
总计 (无 caption) 79,639

车道线数据使用 4 个采样点/lane(论文 Appendix A.2 要求 four lane points,本仓库实现为均匀采样),不设 lane 数量上限(论文未指定上限),按 BEV 距离近→远排序。实际平均约 25 条 lane/样本,最多约 80 条。所有坐标使用 1000-bin 离散化。规划数据包含 gt_boxes_3d_per_timestep 字段用于 ST-P3 对齐的 per-timestep 碰撞评测。

三类主任务的 question pool 统一采用“前 3 条按论文 Table 6 / 7 / 9 原话整理,后 2 条为仓库补充的同风格扩展模板”的策略;其中车道线 Table 7 的第 2 条按论文现有文本原样保留。

为避免运行时再依赖 prompt 文本猜任务,四类样本的磁盘 JSON 统一显式写入 task 字段:detection / lane / planning / caption。这是仓库层面的工程化 schema,对论文中的 question-answer 文本格式不做额外解释。

caption 数据按论文 Appendix A.3 的单视角设定生成:Table 8 作为 GPT-4V 标注 prompt,human prompt 采用 Figure 5 风格的单模板,并注入具体 camera_name

当前仓库不再向 prompt 追加 bins-format hint;detection / lane / caption 默认以磁盘 JSON 中的 human 文本作为 prompt 主体语义来源。planning 任务运行时仍会按 planning_table3_mode 对磁盘 human prompt 做轻量重写,只负责插入/剥离 command 句和 ego-state 句,再统一做 <query> 展开、空白归一化和 USER: ... / ASSISTANT: 包装。

当前 detection 的 canonical answer 格式为:category: [x_bin, y_bin, z_bin], [x_bin, y_bin, z_bin]; category: [x_bin, y_bin, z_bin].。当前 lane/map 的 canonical answer 格式为:Lane: [x_bin, y_bin, z_bin], [x_bin, y_bin, z_bin], ...; [x_bin, y_bin, z_bin], ... .。旧的 detection flat 文本和 lane_centerline(id=...) legacy 文本不再作为受支持协议。

planning 的 answer/output protocol 采用 Figure 5 风格表述,但保持论文 Table 9 的二维语义:Ego car speed value:[vx_bin, vy_bin]. Ego car acceleration value:[ax_bin, ay_bin]. Based on the ego car speed and acceleration you predicted, requeset the ego car planning waypoint in 3-seconds: [x_bin, y_bin], ...。当前实现不为 planning 引入第三维,也不使用固定 z=500 占位。

当前 planning JSON 的顶层 route_command 采用 UniAD-style future-GT-derived command:根据 future planning GT / future waypoints 的最后一个有效 timestep 的横向位移离散为 turn left / turn right / go straight。它不是 raw nuScenes 原生字段,也不是独立导航命令;因此 atlas_high_level* 在本仓库中的含义更接近 UniAD 风格条件输入,而不是 Atlas 论文 Table 3 严格意义上的独立 route command。

3D Tokenizer 预训练 (已完成)

参数 StreamPETR TopoMLP
Backbone EVA-02 ViT-L (embed_dim=1024) EVA-02 ViT-L (embed_dim=1024)
Resolution 800x1600 800x1600
Queries 256 (detection) 256 (map, Top-K from 300 one-to-one queries)
Control Points - 4 per lane
Epochs 24 24
数据集 nuScenes trainval OpenLane-V2 subset-B

快速开始

1. 环境

conda activate streampetr
# 主要依赖: PyTorch 2.0+, transformers, peft, flash-attn, mmcv 1.7, mmdet3d 1.0
# DeepSpeed (ZeRO-2): pip install deepspeed

2. 数据准备

# nuScenes 数据根目录 (含 v1.0-trainval/ 和 samples/)
export DATA_ROOT=/path/to/nuscenes

# OpenLane-V2 subset-B
export OPENLANE_ROOT=/path/to/OpenLane-V2/subset_B

# 生成检测 QA 数据 (按类别分组, 论文 Figure 5 格式)
python scripts/gen_atlas_full_data.py \
  --data-root $DATA_ROOT --split train \
  --output data/atlas_nuscenes_train.json
python scripts/gen_atlas_full_data.py \
  --data-root $DATA_ROOT --split val \
  --output data/atlas_nuscenes_val.json

# 生成车道线 QA 数据 (4 点/lane, 无 lane 数量上限, BEV 距离排序)
python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
  --openlane_root $OPENLANE_ROOT \
  --split train --out_json data/openlane_subsetB_lane_train_4pt.json

python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
  --openlane_root $OPENLANE_ROOT \
  --split val --out_json data/openlane_subsetB_lane_val_4pt.json

# 生成规划 QA 数据
# 默认输出:
#   data/atlas_planning_train_uniad_command.json
#   data/atlas_planning_val_uniad_command.json
# 默认写顶层 route_command(UniAD-style future-GT-derived command)
# 默认 materialize atlas_high_level human prompt;运行时仍可通过
# --planning_table3_mode 改写为 atlas_base / atlas_high_level_ego
python scripts/gen_atlas_planning_qa.py \
  --data-root $DATA_ROOT --split train
python scripts/gen_atlas_planning_qa.py \
  --data-root $DATA_ROOT --split val

3. 训练

默认使用 在线模式--visual_token_mode online),训练时 live 前向 frozen StreamPETR(含 temporal memory)和 TopoMLP,无需预提取 token。

# ===== 推荐:在线模式训练(默认)=====
# 无 caption: det + planning + lane
bash scripts/train_no_caption_baseline.sh

# 含 caption: det + planning + lane + caption
bash scripts/train_with_caption_balanced.sh

等效手动命令(以无 caption 为例):

deepspeed --num_gpus 4 train_atlas.py \
  --llm_model pretrained/vicuna-7b-v1.5 \
  --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
  --data_root $DATA_ROOT \
  --visual_token_mode online \
  --streampetr_config configs/streampetr_atlas_aligned.py \
  --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
  --topomlp_config configs/topomlp_atlas_aligned.py \
  --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
  --deepspeed configs/ds_zero2.json \
  --output_dir work_dirs/atlas_no_caption_online \
  --lr 2e-5 --weight_decay 1e-4 \
  --batch_size 1 --epochs 10 --gradient_accumulation_steps 2 \
  --warmup_ratio 0.03 --max_grad_norm 1.0 \
  --save_epochs 1 --log_steps 100 \
  --seed 42 --num_workers 4

离线 fallback:如需使用预提取 token 训练(速度更快,但 det 无 temporal memory), 先运行预提取脚本,再使用 bash scripts/train_no_caption_baseline_offline.sh。 预提取 token 存放于 work_dirs/precomputed_*_tokens_offline/

4. 评估

eval_atlas.py 支持两种模式:在线模式(默认,live frozen encoder + temporal memory)和离线 fallback。

# ===== 推荐:在线模式评估(默认)=====
# 检测
bash scripts/eval_checkpoint.sh <checkpoint> data/atlas_nuscenes_val.json

# 车道线
bash scripts/eval_checkpoint.sh <checkpoint> data/openlane_subsetB_lane_val_4pt.json

# 规划
bash scripts/eval_checkpoint.sh <checkpoint> data/atlas_planning_val_uniad_command.json

等效手动命令:

python eval_atlas.py \
  --checkpoint work_dirs/atlas_no_caption_online/final/checkpoint.pt \
  --llm_model pretrained/vicuna-7b-v1.5 \
  --visual_token_mode online \
  --streampetr_config configs/streampetr_atlas_aligned.py \
  --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
  --topomlp_config configs/topomlp_atlas_aligned.py \
  --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
  --data_json data/atlas_nuscenes_val.json \
  --data_root $DATA_ROOT \
  --batch_size 1 --max_new_tokens 2700 --bf16

离线 fallback:使用预提取 token 评估(速度更快,但 det 无 temporal memory): bash scripts/eval_checkpoint_offline.sh <checkpoint> <data_json>

快速验证脚本work_dirs/_quick_eval_*.py)使用离线 token,仅用于开发调试,不等价于主在线评测,不应用于正式结果。其口径与主评测存在差异:planning 解析更宽松、不走 live encoders/temporal memory、不自动检测 LoRA。

评测协议说明

  • 检测:micro-averaged F1 @ 0.5/1.0/2.0/4.0m,BEV 中心距离匹配。
  • 车道线:使用 OpenLane-V2 官方 F-Score 评测器(openlanev2 为必需依赖,缺失时直接报错,不再退化为 Chamfer)。
  • 规划:L2 误差 + 碰撞率。规划数据含 gt_boxes_3d_per_timestep 字段时使用 ST-P3 对齐的 per-timestep 碰撞检测;旧数据自动退化为静态 box 检测。
  • 在线主评测(eval_atlas.py)需要 mmcvmmdet3dopenlanev2 三个关键依赖,缺失时会在启动前报错。

参考

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Paper for guoyb0/Atlas-online