Commit · d942a8d
1 Parent(s): 3dcb152
First commit

- app.py +126 -0
- canvas.py +295 -0
- requirements.txt +8 -0
- src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png +0 -0
- src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png +0 -0
- src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png +0 -0
- src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png +0 -0
- src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png +0 -0
- src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png +0 -0
- src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png +0 -0
- src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png +0 -0
app.py ADDED
@@ -0,0 +1,126 @@
import re
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from io import BytesIO
from PIL import Image
from canvas import Idefics2Pipeline


def run_canvas(front_view, map_view, prompt):
    pipeline = Idefics2Pipeline.from_pretrained(
        "maum-ai/CANVAS-S"
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": prompt}]},
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "image"}],
        },
    ]

    print(front_view)

    images = [front_view, map_view]
    pred = pipeline([messages], [images], return_traj=False)
    pred_action = re.findall(r"<ACTION_(\d+)>", pred[0])
    pred_action = np.array(pred_action, dtype=np.int64)
    print(pred_action)
    pred_action_odom = pipeline.action_tokenizer.detokenize(pred_action).tolist()
    print(pred_action_odom)

    # Create a figure and axes
    fig, axes = plt.subplots(1, 1, figsize=(8, 6))

    # Scale factor for the arrow
    scale_factor = 0.2

    axes.plot(0, 0, marker="o", color="black", markersize=10)
    axes.invert_xaxis()

    for i, center in zip(pred_action, pred_action_odom):
        x, y, yaw = center
        axes.plot(y, x, marker="^", color="blue")
        axes.arrow(
            y,
            x,
            np.sin(yaw) * scale_factor,
            np.cos(yaw) * scale_factor,
            head_width=scale_factor * 0.3,
            head_length=scale_factor * 0.3,
            fc="k",
            ec="k",
        )
        axes.text(y, x, f"{i}", fontsize=10)
    axes.axis("equal")
    axes.grid(True)

    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)  # Rewind the buffer to the beginning
    pil_img = Image.open(buf)

    return pil_img


examples = [
    ["src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png", "src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png", """You are an indoor food-serving robot.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed.
a. Indoors, you should drive at a speed of 3-4km/h.
5. You must slow down(2km/h or lower) if a human or obstacle comes within 1.5m radius.
a. You must slow down(2km/h or lower) in areas where a human could suddenly appear from a blind spot."""],
    ["src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png", "src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png", """You are an outdoor speed-sprayer robot.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed."""],
    ["src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png", "src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png", """You are an outdoor last mile delivery robot.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed.
5. You must drive on the sidewalk.
a. If you need to cross the road, you must use the crosswalk."""],
    ["src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png", "src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png", """You are an outdoor self-driving robot taxi.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed.
5. You must drive on the road.
a. You should drive according to the left-hand-traffic law.
6. You should slow down before entering intersections, speed bumps, and crosswalks."""],
]

demo = gr.Interface(
    fn=run_canvas,
    inputs=[
        gr.Image(label="front_view", type="pil"),
        gr.Image(label="map_view", type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=gr.Image(label="generated waypoint"),
    title="CANVAS Demo",
    description="This is the demo of the CANVAS-S model from CANVAS: Commonsense-Aware Navigation System for Intuitive Human-Robot Interaction",
    examples=examples,
)

demo.launch()
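The Gradio UI is only a thin wrapper around the pipeline defined in canvas.py. The following is a minimal sketch (not part of the commit) of querying the pipeline directly to get raw waypoints instead of a plot; it assumes the bundled src/ example images are present and the maum-ai/CANVAS-S checkpoint can be downloaded, and it shortens the system prompt for brevity (app.py's examples hold the full driving instructions).

# Illustrative sketch: call Idefics2Pipeline directly, mirroring what run_canvas does.
from PIL import Image

from canvas import Idefics2Pipeline

pipeline = Idefics2Pipeline.from_pretrained("maum-ai/CANVAS-S")

prompt = "You are an indoor food-serving robot."  # shortened stand-in for the full instruction prompt
messages = [
    {"role": "system", "content": [{"type": "text", "text": prompt}]},
    {"role": "user", "content": [{"type": "image"}, {"type": "image"}]},
]
front_view = Image.open("src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png")
map_view = Image.open("src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png")

# With return_traj=True the pipeline parses the generated <ACTION_i> tokens and
# detokenizes them with the KMeans action tokenizer into [x, y, yaw] deltas.
waypoints = pipeline([messages], [[front_view, map_view]], return_traj=True)
print(waypoints)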
canvas.py ADDED
@@ -0,0 +1,295 @@
from pydantic import BaseModel
import numpy as np
import torch
import yaml  # needed by the YAML helpers in BaseModelYamlJsonMixin
from PIL import Image
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor, PreTrainedModel, ProcessorMixin
from typing import Optional
import re

import logging
import pickle
from pathlib import Path

from matplotlib import pyplot as plt
from sklearn.cluster import KMeans


class BaseModelYamlJsonMixin:
    """
    BaseModel with helper methods for loading and saving in YAML/JSON format.
    """

    @classmethod
    def from_yaml(cls, path: Path):
        with open(path, "r", encoding="utf-8") as f:
            return cls(**yaml.safe_load(f))

    def to_yaml(self: BaseModel, path: Path):
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.model_dump(), f)

    @classmethod
    def from_json(cls, path: Path):
        with open(path, "r", encoding="utf-8") as f:
            return cls.model_validate_json(f.read())

    def to_json(self: BaseModel, path: Path, indent: int = 4, *args, **kwargs):
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=indent, *args, **kwargs))


class BaseModelWithYamlJsonFromTo(BaseModel, BaseModelYamlJsonMixin):
    pass


class Idefics2TrainAdditionalConfig(BaseModel):
    """
    num_action_tokens (`int`, defaults to `32`):
        Number of action tokens to add to the tokenizer vocabulary.
    do_image_splitting (`bool`, *optional*, defaults to `False`):
        Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
        strategy was first introduced in https://arxiv.org/abs/2311.06607.
    lora_config (`dict`, defaults to recommended config from https://x.com/danielhanchen/status/1791900967472140583):
        Configuration for the LoRA model. If it is `None`, the model will not use LoRA.
    """

    # must be set to extend vocabulary of model + tokenizer
    num_action_tokens: int = -1  # it will be overwritten by the processor_config.yml
    # must be set to be used in pipeline
    num_actions: int = -1  # it will be overwritten by the processor_config.yml

    do_image_splitting: bool = True
    freeze_original_vocab: bool = False
    freeze_vision_model: bool = False
    freeze_connector: bool = False
    torch_dtype: str = "bfloat16"
    lora_config: dict | None = dict(
        r=256,
        lora_alpha=512,
        lora_dropout=0.1,
        target_modules="all-linear",
        use_rslora=True,
        init_lora_weights="gaussian",
        modules_to_save=["lm_head", "embed_tokens"],
    )
    model_name_or_path: str = "HuggingFaceM4/idefics2-8b"


class KMeansActionTokenizer:
    def __init__(self, action_count: int = 128):
        self.action_count = action_count
        self.kmeans = KMeans(n_clusters=self.action_count, random_state=np.random.RandomState(seed=42))

    @property
    def token_count(self):
        return self.action_count

    @classmethod
    def from_pretrained(cls, model_path: str | Path):
        model_path = Path(model_path)
        self = cls()
        with open(model_path / "tokenizer.pkl", "rb") as file:
            self.kmeans = pickle.load(file)
        self.action_count = self.kmeans.n_clusters
        # assert self.action_count == 32
        return self

    def save_pretrained(self, model_path: str | Path):
        model_path = Path(model_path)
        model_path.mkdir(exist_ok=True)
        with open(model_path / "tokenizer.pkl", "wb") as file:
            pickle.dump(self.kmeans, file)

    def train(self, actions):
        self.kmeans.fit(actions)

    def tokenize(self, action, padding=False, max_length=-1, truncation=False):
        # action: (K, 3) shape, adjusted delta_position and delta_yaw
        return [i for i in self.kmeans.predict(action)]

    def detokenize(self, tokens):
        # Token check: every token must be a valid cluster index.
        check = np.asarray(tokens)
        in_valid_range = (0 <= check) & (check < self.action_count)
        if not in_valid_range.all():
            logging.warning(f"Invalid tokens occur: {tokens}")
            # If invalid tokens occur, return the stop action.
            return np.asarray([[0.0, 0.0, 0.0] for _ in range(len(tokens))])
        return np.asarray([self.kmeans.cluster_centers_[t] for t in tokens])

    def visualize(self, figset=None):
        if figset is None:
            fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 16), dpi=300)
        else:
            fig, axes = figset
        FONT = {"fontsize": 20}

        axes[0].set_title("Center", fontdict=FONT)
        axes[1].set_title("Center_Rot", fontdict=FONT)

        labels = self.kmeans.labels_
        centers = self.kmeans.cluster_centers_

        # Plot each cluster center. A center is (x, y, yaw): plot the point (x, y)
        # and an arrow from (x, y) in the direction of yaw, scaled to the plot.
        scale_factor = 0.05
        for i, center in enumerate(centers):
            x, y, yaw = center
            axes[0].plot(x, y, "ro")
            axes[0].arrow(
                x,
                y,
                np.cos(yaw) * scale_factor,
                np.sin(yaw) * scale_factor,
                head_width=scale_factor * 0.3,
                head_length=scale_factor * 0.3,
                fc="k",
                ec="k",
            )
            axes[0].text(x, y, f"{i}", fontsize=10)
        axes[0].axis("equal")
        axes[0].grid(True)

        # Filter centers whose (x, y) lies within 0.05 of the origin (near-stop actions).
        _centers = centers[np.linalg.norm(centers[:, :2], axis=1) < 0.05]
        # print(f"action near zero: {_centers}")
        scale_factor = 0.1
        for center in _centers:
            x, y, yaw = center
            axes[1].plot(x, y, "ro")
            axes[1].arrow(
                x,
                y,
                np.cos(yaw) * scale_factor,
                np.sin(yaw) * scale_factor,
                head_width=scale_factor * 0.3,
                head_length=scale_factor * 0.3,
                fc="k",
                ec="k",
            )
        axes[1].axis("equal")
        axes[1].grid(True)

        return fig, axes


class Idefics2PipelineConfig(BaseModelWithYamlJsonFromTo):
    pipeline_class: str = "Idefics2Pipeline"
    train_additional_cfg: Idefics2TrainAdditionalConfig


class Idefics2Pipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        processor: ProcessorMixin,
        action_tokenizer: KMeansActionTokenizer,
        config: Idefics2PipelineConfig,
    ):
        self.model = model
        self.processor = processor
        self.action_tokenizer = action_tokenizer
        self.config = config

    def save_pretrained(
        self,
        save_directory: str,
    ):
        if not isinstance(save_directory, Path):
            save_directory = Path(save_directory)
        self.model.save_pretrained(save_directory)
        self.processor.save_pretrained(save_directory)
        self.action_tokenizer.save_pretrained(save_directory)
        self.config.to_json(f"{save_directory}/pipeline_config.json")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        if not isinstance(pretrained_model_name_or_path, Path):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        config = Idefics2PipelineConfig.model_validate_json(
            (pretrained_model_name_or_path / "pipeline_config.json").read_text()
        )
        model = Idefics2ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
        processor = Idefics2Processor.from_pretrained(pretrained_model_name_or_path)
        model.eval()
        action_tokenizer = KMeansActionTokenizer.from_pretrained(pretrained_model_name_or_path)
        return cls(model, processor, action_tokenizer, config)

    def to(self, device):
        return self.model.to(device)

    @torch.no_grad()
    def __call__(
        self,
        examples: list[dict],
        return_traj: Optional[bool] = False,
    ):
        """
        Call the model with examples.

        Args:
            examples: list of example, [B, example]
            return_traj: return trajectory if True
        """
        # NOTE: this overload is unused; the second __call__ definition below
        # (taking message_list and images_list) replaces it at class-creation time.
        raise NotImplementedError("Not implemented yet")

        # same as idefics2 data collator
        texts = []
        images = []
        for example in examples:
            image = example["images"]
            messages = example["messages"]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append(image)
        inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        generate_ids = self.model.generate(**inputs, max_new_tokens=self.config.num_actions)
        generated_text = self.processor.batch_decode(generate_ids, skip_special_tokens=True)

        if return_traj:
            return self.action_tokenizer.detokenize(generated_text)
        else:
            return generated_text

    @torch.no_grad()
    def __call__(
        self,
        message_list: list[list[dict]],
        images_list: list[list[Image.Image]],
        return_traj: Optional[bool] = False,
    ):
        """
        Call the model with messages and images.

        Args:
            message_list: list of messages, [B, messages]
            images_list: list of images, [B, images]
            return_traj: return trajectory if True
        """

        # we don't use batch inference for run model worker
        if len(message_list) != 1:
            raise ValueError("No batch api call allowed for Idefics2Pipeline")

        message = message_list[0]
        images = images_list[0]
        prompt = self.processor.apply_chat_template(message, add_generation_prompt=True)
        prompt = prompt.replace("<end_of_utterance>", "")
        # add space to match the training data
        prompt = prompt + " "
        inputs = self.processor(text=prompt, images=images, return_tensors="pt", padding=True)

        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generate_ids = self.model.generate(
            **inputs, max_new_tokens=self.config.train_additional_cfg.num_actions, top_k=1
        )
        generated_texts = self.processor.batch_decode(generate_ids, skip_special_tokens=True)
        if return_traj:
            pred_action = re.findall(r"<ACTION_(\d+)>", generated_texts[0])
            # pred_action = pred_action if len(pred_action) == self.config.num_actions else [-1] * self.config.num_actions
            pred_action = np.array(pred_action, dtype=np.int64)
            return self.action_tokenizer.detokenize(pred_action).tolist()
        else:
            return generated_texts
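KMeansActionTokenizer is the piece that turns continuous odometry into the discrete <ACTION_i> vocabulary the model generates. Below is an illustrative round-trip sketch (not part of the commit) using synthetic odometry, just to show how tokenize/detokenize behave; the data and shapes are assumptions for the example only.

# Illustrative sketch of the action-token round trip with synthetic odometry.
import numpy as np

from canvas import KMeansActionTokenizer

rng = np.random.default_rng(0)
# Fake (delta_x, delta_y, delta_yaw) steps standing in for real driving odometry.
odometry = rng.normal(scale=[0.5, 0.5, 0.3], size=(1000, 3))

tokenizer = KMeansActionTokenizer(action_count=128)
tokenizer.train(odometry)                  # fit the KMeans codebook

tokens = tokenizer.tokenize(odometry[:4])  # cluster indices, i.e. the i in <ACTION_i>
recovered = tokenizer.detokenize(tokens)   # nearest cluster centers, shape (4, 3)
print(tokens, recovered)

# Out-of-range tokens fall back to the stop action [0, 0, 0] with a warning.
print(tokenizer.detokenize([999]))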
requirements.txt ADDED
@@ -0,0 +1,8 @@
transformers==4.46.1
datasets==3.1.0
pillow==10.4.0
numpy==2.1.3
torch==2.4.0
pydantic==2.9.2
scikit-learn==1.5.2
matplotlib==3.9.3
src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png ADDED
src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png ADDED
src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png ADDED
src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png ADDED
src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png ADDED
src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png ADDED
src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png ADDED
src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png ADDED