Upload processing_eo1.py with huggingface_hub

processing_eo1.py (+59, -49)
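In brief: the update replaces the hasattr/ternary special-token lookups with getattr defaults, moves the per-dataset select_* key tables and the normalizers under robot_config, splits the old select_action into _prepare_robot_inputs and _process_robot_outputs helpers, and swaps model.generate for model.sample_actions in the inference path.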
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from typing import Union
 
 import numpy as np
 import torch
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.constants import OBS_STATE
 from lerobot.datasets.utils import cast_stats_to_numpy
 from lerobot.policies.normalize import Normalize, Unnormalize
 from transformers.feature_extraction_utils import BatchFeature
@@ -32,6 +34,8 @@ from transformers.processing_utils import (
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 from transformers.video_utils import VideoInput
 
+os.environ["TOKENIZERS_PARALLELISM"] = "0"
+
 """constants"""
 DEFAULT_IMAGE_TOKEN = "<|image_pad|>"
 DEFAULT_VIDEO_TOKEN = "<|video_pad|>"
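Setting TOKENIZERS_PARALLELISM at import time, before any tokenizer is constructed, disables the Rust tokenizers backend's internal parallelism, presumably to avoid its fork-safety warning when this processor runs inside forked DataLoader workers.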
@@ -48,8 +52,8 @@ DEFAULT_STATE_TOKEN = "<|state_pad|>"
 STATE_END_TOKEN = "<|state_end|>"
 TASK_VLA_TOKEN = "<|vla|>"
 
+
 RobotInput = Union[np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]]
-RobotIDInput = Union[str, list[str]]
 
 
 class EO1VisionVideosProcessorKwargs(VideosKwargs, total=False):
@@ -99,22 +103,14 @@ class EO1VisionProcessor(ProcessorMixin):
         robot_config=None,
         **kwargs,
     ):
-        self.image_token = (
-            DEFAULT_IMAGE_TOKEN if not hasattr(tokenizer, "image_token") else tokenizer.image_token
-        )
-        self.video_token = (
-            DEFAULT_VIDEO_TOKEN if not hasattr(tokenizer, "video_token") else tokenizer.video_token
-        )
-        self.action_token = (
-            DEFAULT_ACTION_TOKEN if not hasattr(tokenizer, "action_token") else tokenizer.action_token
-        )
-        self.state_token = (
-            DEFAULT_STATE_TOKEN if not hasattr(tokenizer, "state_token") else tokenizer.state_token
-        )
+        self.image_token = getattr(tokenizer, "image_token", DEFAULT_IMAGE_TOKEN)
+        self.video_token = getattr(tokenizer, "video_token", DEFAULT_VIDEO_TOKEN)
+        self.action_token = getattr(tokenizer, "action_token", DEFAULT_ACTION_TOKEN)
+        self.state_token = getattr(tokenizer, "state_token", DEFAULT_STATE_TOKEN)
 
         # robot policy
         self.action_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN) or 151666
-        self.action_pass_id = tokenizer.convert_tokens_to_ids(PASS_ACTION_TOKEN) or …
+        self.action_pass_id = tokenizer.convert_tokens_to_ids(PASS_ACTION_TOKEN) or 151667
         self.robot_config = robot_config or {}
         self.set_normalization(self.robot_config)
 
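The four collapsed assignments rely on three-argument getattr, which returns the attribute when the tokenizer defines it and the default otherwise, exactly the behavior of the removed hasattr ternaries. A minimal sketch of the equivalence (the Tok class is illustrative):

    class Tok:
        video_token = "<|video_pad|>"

    tok = Tok()
    # attribute present -> attribute value wins over the default
    assert getattr(tok, "video_token", "<|fallback|>") == "<|video_pad|>"
    # attribute absent -> default is returned, no AttributeError
    assert getattr(tok, "image_token", "<|image_pad|>") == "<|image_pad|>"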
@@ -126,15 +122,14 @@ class EO1VisionProcessor(ProcessorMixin):
             robot_config.get("stats"),
             robot_config.get("state_mode"),
         )
-        if …
+        if None in [features, stats, state_mode]:
             return
         else:
             normalization_mapping = {
                 "STATE": NormalizationMode(state_mode),
                 "ACTION": NormalizationMode(state_mode),
             }
-
-        self.normalize_inputs, self.unnormalize_outputs = {}, {}
+        normalize_inputs, unnormalize_outputs = {}, {}
         for repo_id, fea in features.items():
             stat = cast_stats_to_numpy(stats[repo_id])
             fea = dataset_to_policy_features(fea)
@@ -142,12 +137,11 @@ class EO1VisionProcessor(ProcessorMixin):
             input_features = {k: v for k, v in fea.items() if v.type == FeatureType.STATE}
             output_features = {k: v for k, v in fea.items() if v.type == FeatureType.ACTION}
 
-            self.normalize_inputs[repo_id] = Normalize(input_features, normalization_mapping, stat)
-            self.unnormalize_outputs[repo_id] = Unnormalize(output_features, normalization_mapping, stat)
+            normalize_inputs[repo_id] = Normalize(input_features, normalization_mapping, stat)
+            unnormalize_outputs[repo_id] = Unnormalize(output_features, normalization_mapping, stat)
 
-        self.select_video_keys = robot_config.get("select_video_keys")
-        self.select_state_keys = robot_config.get("select_state_keys")
-        self.select_action_keys = robot_config.get("select_action_keys")
+        self.robot_config = dict(robot_config)
+        self.normalize_inputs, self.unnormalize_outputs = normalize_inputs, unnormalize_outputs
 
     def __call__(
         self,
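After this hunk, everything repo-specific lives in self.robot_config: the select_* tables are looked up per repo_id on demand instead of being hoisted onto the processor. A hypothetical shape for that dict, inferred from the keys this file reads (repo ids and values are illustrative only, not from the commit):

    robot_config = {
        "features": {"lerobot/aloha_demo": {...}},      # per-repo dataset features
        "stats": {"lerobot/aloha_demo": {...}},         # per-repo normalization stats
        "state_mode": "MEAN_STD",                       # feeds NormalizationMode(state_mode)
        "max_state_dim": 32,
        "select_video_keys": {"lerobot/aloha_demo": ["observation.images.top"]},
        "select_state_keys": {"lerobot/aloha_demo": ["observation.state"]},
        "select_action_keys": {"lerobot/aloha_demo": ["action"]},
    }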
@@ -233,7 +227,6 @@ class EO1VisionProcessor(ProcessorMixin):
             )
             text[i] = text[i].replace("<|placeholder|>", self.action_token)
 
-        # state tokens
         return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
         if return_mm_token_type_ids:
|
|
| 244 |
|
| 245 |
# robot inputs
|
| 246 |
robot_inputs = {}
|
|
|
|
| 247 |
if states is not None:
|
| 248 |
if isinstance(states, list):
|
| 249 |
states = torch.stack(states, dim=0)
|
| 250 |
-
if states.ndim ==
|
| 251 |
states = states.unsqueeze(0)
|
| 252 |
robot_inputs.update({"states": states})
|
| 253 |
|
|
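The corrected ndim guard promotes a single state vector to a batch of one. The shape arithmetic, as a standalone sketch (the 14-dim state is an arbitrary example):

    import torch

    states = [torch.zeros(14), torch.ones(14)]   # list input
    stacked = torch.stack(states, dim=0)         # -> shape (2, 14)

    single = torch.zeros(14)                     # bare 1-D state
    if single.ndim == 1:
        single = single.unsqueeze(0)             # -> shape (1, 14)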
@@ -267,22 +261,31 @@ class EO1VisionProcessor(ProcessorMixin):
         tokenizer_input_names = self.tokenizer.model_input_names
         image_processor_input_names = self.image_processor.model_input_names
         names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-        return names_from_processor + ["second_per_grid_ts"] + ["actions"]
+        return names_from_processor + ["second_per_grid_ts"] + ["states", "actions"]
 
     @torch.no_grad
-    def select_action(self, model, batch: dict, **kwargs):
-        …
+    def _prepare_robot_inputs(self, batch: dict):
+        """Prepare model inputs from raw robot batch"""
         batch_messages = []
         batch_states = []
         max_state_dim = self.robot_config.get("max_state_dim", 32)
 
-        …
-        …
+        state_keys = [x for x in batch.keys() if x.startswith(OBS_STATE)]
+        batch_size = len(batch[state_keys[0]])
+
+        if "repo_id" in batch:
+            repo_ids = batch.pop("repo_id")
+        else:
+            print("no repo_id found, use the first one in normalize_inputs")
+            repo_ids = list(self.normalize_inputs.keys())[0]
+        repo_ids = [repo_ids] * batch_size if isinstance(repo_ids, str) else repo_ids
+
+        for i, repo_id in enumerate(repo_ids):
             mini_batch = {k: v[i] for k, v in batch.items()}
 
             normalize_inputs = self.normalize_inputs[repo_id]
-            select_video_keys = self.select_video_keys[repo_id]
-            select_state_keys = self.select_state_keys[repo_id]
+            select_video_keys = self.robot_config["select_video_keys"][repo_id]
+            select_state_keys = self.robot_config["select_state_keys"][repo_id]
 
             for k in normalize_inputs.features:
                 if not isinstance(mini_batch[k], torch.Tensor):
@@ -296,31 +299,20 @@ class EO1VisionProcessor(ProcessorMixin):
                 "role": "user",
                 "content": [
                     *({"type": "image", "image": mini_batch[k]} for k in select_video_keys),
-                    {"type": "state", "state": …},
-                    {"type": "text", "text": f"{mini_batch['task']}{TASK_VLA_TOKEN}"},
+                    {"type": "state", "state": []},  # chat template state token
+                    {"type": "text", "text": f"{mini_batch['task']}{TASK_VLA_TOKEN}"},
                 ],
             }
         ]
         batch_messages += [messages]
+        return batch_messages, batch_states, repo_ids
 
-        inputs = self.apply_chat_template(
-            batch_messages,
-            states=batch_states,
-            add_generation_prompt=True,
-            noise_prompt=f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}",
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        ).to(model.device)
-
-        outputs = model.generate(**inputs, max_new_tokens=128, return_dict_in_generate=True)
-        actions = outputs.actions.cpu()
-
-        # unnormalize actions
+    def _process_robot_outputs(self, repo_ids: list[str], actions: torch.Tensor):
+        """Process model outputs back to robot format"""
         output_actions = []
-        for i, repo_id in enumerate(…):
+        for i, repo_id in enumerate(repo_ids):
             unnormalize_outputs = self.unnormalize_outputs[repo_id]
-            select_action_keys = self.select_action_keys[repo_id]
+            select_action_keys = self.robot_config["select_action_keys"][repo_id]
             features = unnormalize_outputs.features
             cum_dims = [0] + np.cumsum([features[k].shape[0] for k in select_action_keys]).tolist()
             origin_action = torch.tensor(actions[i], dtype=torch.float32)[..., : cum_dims[-1]]
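The cum_dims bookkeeping in _process_robot_outputs splits one concatenated action vector back into per-key slices before unnormalization. A self-contained sketch with made-up feature dims:

    import numpy as np
    import torch

    dims = {"action.arm": 6, "action.gripper": 1}             # hypothetical dims
    keys = list(dims)
    cum = [0] + np.cumsum([dims[k] for k in keys]).tolist()   # [0, 6, 7]

    flat = torch.arange(7.0)                                  # concatenated action
    parts = {k: flat[..., cum[i] : cum[i + 1]] for i, k in enumerate(keys)}
    # parts["action.arm"] -> 6 values, parts["action.gripper"] -> 1 value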
@@ -331,7 +323,25 @@ class EO1VisionProcessor(ProcessorMixin):
             unnorm_actions = torch.concat([unnorm_actions[k] for k in select_action_keys], -1)
             output_actions.append(unnorm_actions)
         output_actions = torch.stack(output_actions, dim=0)
+        return output_actions
+
+    @torch.no_grad
+    def select_action(self, model, batch: dict, **kwargs):
+        batch_messages, batch_states, repo_ids = self._prepare_robot_inputs(batch)
+
+        noise_prompt = f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}"
+        inputs = self.apply_chat_template(
+            batch_messages,
+            states=batch_states,
+            add_generation_prompt=True,
+            noise_prompt=noise_prompt,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
 
+        actions = model.sample_actions(**inputs)[0].cpu()
+        output_actions = self._process_robot_outputs(repo_ids, actions)
         return BatchFeature({"action": output_actions})
 
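End to end, the refactored path is prepare inputs, apply the chat template, sample actions, then unnormalize per repo_id. A hypothetical call site (the model object, checkpoint setup, and observation keys below are assumptions for illustration, not part of this commit):

    import torch

    # processor = EO1VisionProcessor.from_pretrained(...)   # assumed setup
    # model = ...                                           # assumed EO1 policy model
    batch = {
        "observation.state": torch.zeros(1, 14),
        "observation.images.top": torch.zeros(1, 3, 224, 224),
        "task": ["pick up the cube"],
        "repo_id": ["lerobot/aloha_demo"],
    }
    actions = processor.select_action(model, batch)["action"]  # unnormalized actions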