delinqu committed on
Commit 0dd7ac3 · verified · 1 Parent(s): bcb1c1e

Upload processing_eo1.py with huggingface_hub

Files changed (1)
  1. processing_eo1.py +59 -49
processing_eo1.py CHANGED
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from typing import Union
 
 import numpy as np
 import torch
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.constants import OBS_STATE
 from lerobot.datasets.utils import cast_stats_to_numpy
 from lerobot.policies.normalize import Normalize, Unnormalize
 from transformers.feature_extraction_utils import BatchFeature
@@ -32,6 +34,8 @@ from transformers.processing_utils import (
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 from transformers.video_utils import VideoInput
 
+os.environ["TOKENIZERS_PARALLELISM"] = "0"
+
 """constants"""
 DEFAULT_IMAGE_TOKEN = "<|image_pad|>"
 DEFAULT_VIDEO_TOKEN = "<|video_pad|>"
@@ -48,8 +52,8 @@ DEFAULT_STATE_TOKEN = "<|state_pad|>"
 STATE_END_TOKEN = "<|state_end|>"
 TASK_VLA_TOKEN = "<|vla|>"
 
+
 RobotInput = Union[np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]]
-RobotIDInput = Union[str, list[str]]
 
 
 class EO1VisionVideosProcessorKwargs(VideosKwargs, total=False):
@@ -99,22 +103,14 @@ class EO1VisionProcessor(ProcessorMixin):
         robot_config=None,
         **kwargs,
     ):
-        self.image_token = (
-            DEFAULT_IMAGE_TOKEN if not hasattr(tokenizer, "image_token") else tokenizer.image_token
-        )
-        self.video_token = (
-            DEFAULT_VIDEO_TOKEN if not hasattr(tokenizer, "video_token") else tokenizer.video_token
-        )
-        self.action_token = (
-            DEFAULT_ACTION_TOKEN if not hasattr(tokenizer, "action_token") else tokenizer.action_token
-        )
-        self.state_token = (
-            DEFAULT_STATE_TOKEN if not hasattr(tokenizer, "state_token") else tokenizer.state_token
-        )
+        self.image_token = getattr(tokenizer, "image_token", DEFAULT_IMAGE_TOKEN)
+        self.video_token = getattr(tokenizer, "video_token", DEFAULT_VIDEO_TOKEN)
+        self.action_token = getattr(tokenizer, "action_token", DEFAULT_ACTION_TOKEN)
+        self.state_token = getattr(tokenizer, "state_token", DEFAULT_STATE_TOKEN)
 
         # robot policy
         self.action_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN) or 151666
-        self.action_pass_id = tokenizer.convert_tokens_to_ids(PASS_ACTION_TOKEN) or 151672
+        self.action_pass_id = tokenizer.convert_tokens_to_ids(PASS_ACTION_TOKEN) or 151667
         self.robot_config = robot_config or {}
         self.set_normalization(self.robot_config)
 
@@ -126,15 +122,14 @@ class EO1VisionProcessor(ProcessorMixin):
             robot_config.get("stats"),
            robot_config.get("state_mode"),
         )
-        if features is None or stats is None or state_mode is None:
+        if None in [features, stats, state_mode]:
             return
         else:
             normalization_mapping = {
                 "STATE": NormalizationMode(state_mode),
                 "ACTION": NormalizationMode(state_mode),
             }
-            self.robot_config = dict(robot_config)
-            self.normalize_inputs, self.unnormalize_outputs = {}, {}
+            normalize_inputs, unnormalize_outputs = {}, {}
             for repo_id, fea in features.items():
                 stat = cast_stats_to_numpy(stats[repo_id])
                 fea = dataset_to_policy_features(fea)
@@ -142,12 +137,11 @@ class EO1VisionProcessor(ProcessorMixin):
                 input_features = {k: v for k, v in fea.items() if v.type == FeatureType.STATE}
                 output_features = {k: v for k, v in fea.items() if v.type == FeatureType.ACTION}
 
-                self.normalize_inputs[repo_id] = Normalize(input_features, normalization_mapping, stat)
-                self.unnormalize_outputs[repo_id] = Unnormalize(output_features, normalization_mapping, stat)
+                normalize_inputs[repo_id] = Normalize(input_features, normalization_mapping, stat)
+                unnormalize_outputs[repo_id] = Unnormalize(output_features, normalization_mapping, stat)
 
-        self.select_video_keys = robot_config.get("select_video_keys")
-        self.select_state_keys = robot_config.get("select_state_keys")
-        self.select_action_keys = robot_config.get("select_action_keys")
+            self.robot_config = dict(robot_config)
+            self.normalize_inputs, self.unnormalize_outputs = normalize_inputs, unnormalize_outputs
 
     def __call__(
         self,
@@ -233,7 +227,6 @@ class EO1VisionProcessor(ProcessorMixin):
            )
            text[i] = text[i].replace("<|placeholder|>", self.action_token)
 
-        # state tokens
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        if return_mm_token_type_ids:
@@ -244,10 +237,11 @@
 
        # robot inputs
        robot_inputs = {}
+
        if states is not None:
            if isinstance(states, list):
                states = torch.stack(states, dim=0)
-            if states.ndim == 2:
+            if states.ndim == 1:
                states = states.unsqueeze(0)
            robot_inputs.update({"states": states})
 
@@ -267,22 +261,31 @@
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-        return names_from_processor + ["second_per_grid_ts"] + ["actions"]
+        return names_from_processor + ["second_per_grid_ts"] + ["states", "actions"]
 
    @torch.no_grad
-    def select_action(self, model, batch: dict, **kwargs):
-        # normalize batch
+    def _prepare_robot_inputs(self, batch: dict):
+        """Prepare model inputs from raw robot batch"""
        batch_messages = []
        batch_states = []
        max_state_dim = self.robot_config.get("max_state_dim", 32)
 
-        # normalize robot inputs
-        for i, repo_id in enumerate(batch["repo_id"]):
+        state_keys = [x for x in batch.keys() if x.startswith(OBS_STATE)]
+        batch_size = len(batch[state_keys[0]])
+
+        if "repo_id" in batch:
+            repo_ids = batch.pop("repo_id")
+        else:
+            print("no repo_id found, use the first one in normalize_inputs")
+            repo_ids = list(self.normalize_inputs.keys())[0]
+        repo_ids = [repo_ids] * batch_size if isinstance(repo_ids, str) else repo_ids
+
+        for i, repo_id in enumerate(repo_ids):
            mini_batch = {k: v[i] for k, v in batch.items()}
 
            normalize_inputs = self.normalize_inputs[repo_id]
-            select_video_keys = self.select_video_keys[repo_id]
-            select_state_keys = self.select_state_keys[repo_id]
+            select_video_keys = self.robot_config["select_video_keys"][repo_id]
+            select_state_keys = self.robot_config["select_state_keys"][repo_id]
 
            for k in normalize_inputs.features:
                if not isinstance(mini_batch[k], torch.Tensor):
@@ -296,31 +299,20 @@
                    "role": "user",
                    "content": [
                        *({"type": "image", "image": mini_batch[k]} for k in select_video_keys),
-                        {"type": "state", "state": states},
-                        {"type": "text", "text": f"{mini_batch['task']}{TASK_VLA_TOKEN}"},  # add task token
+                        {"type": "state", "state": []},  # chat template state token
+                        {"type": "text", "text": f"{mini_batch['task']}{TASK_VLA_TOKEN}"},
                    ],
                }
            ]
            batch_messages += [messages]
+        return batch_messages, batch_states, repo_ids
 
-        inputs = self.apply_chat_template(
-            batch_messages,
-            states=batch_states,
-            add_generation_prompt=True,
-            noise_prompt=f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}",
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        ).to(model.device)
-
-        outputs = model.generate(**inputs, max_new_tokens=128, return_dict_in_generate=True)
-        actions = outputs.actions.cpu()
-
-        # unnormalize actions
+    def _process_robot_outputs(self, repo_ids: list[str], actions: torch.Tensor):
+        """Process model outputs back to robot format"""
        output_actions = []
-        for i, repo_id in enumerate(batch["repo_id"]):
+        for i, repo_id in enumerate(repo_ids):
            unnormalize_outputs = self.unnormalize_outputs[repo_id]
-            select_action_keys = self.select_action_keys[repo_id]
+            select_action_keys = self.robot_config["select_action_keys"][repo_id]
            features = unnormalize_outputs.features
            cum_dims = [0] + np.cumsum([features[k].shape[0] for k in select_action_keys]).tolist()
            origin_action = torch.tensor(actions[i], dtype=torch.float32)[..., : cum_dims[-1]]
@@ -331,7 +323,25 @@
            unnorm_actions = torch.concat([unnorm_actions[k] for k in select_action_keys], -1)
            output_actions.append(unnorm_actions)
        output_actions = torch.stack(output_actions, dim=0)
+        return output_actions
+
+    @torch.no_grad
+    def select_action(self, model, batch: dict, **kwargs):
+        batch_messages, batch_states, repo_ids = self._prepare_robot_inputs(batch)
+
+        noise_prompt = f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}"
+        inputs = self.apply_chat_template(
+            batch_messages,
+            states=batch_states,
+            add_generation_prompt=True,
+            noise_prompt=noise_prompt,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
 
+        actions = model.sample_actions(**inputs)[0].cpu()
+        output_actions = self._process_robot_outputs(repo_ids, actions)
        return BatchFeature({"action": output_actions})
 
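For reference, a minimal usage sketch of the inference path refactored in this commit. This is not code from the repo: the checkpoint path, dataset repo id, camera/state keys, and tensor shapes are placeholders, and it assumes the checkpoint exposes AutoProcessor/AutoModel mappings via trust_remote_code and that the model provides the sample_actions method called by select_action above. The robot_config layout is inferred from set_normalization and the self.robot_config[...][repo_id] lookups in _prepare_robot_inputs/_process_robot_outputs.

# A minimal sketch, assuming placeholder ids/paths and a lerobot install
# matching the imports above (lerobot.datasets, lerobot.constants).
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from transformers import AutoModel, AutoProcessor

REPO_ID = "my-org/my-lerobot-dataset"  # placeholder dataset repo id
meta = LeRobotDatasetMetadata(REPO_ID)  # supplies features/stats for normalization

robot_config = {
    # everything below is keyed by repo_id, matching set_normalization
    # and the per-sample lookups in _prepare_robot_inputs
    "features": {REPO_ID: meta.features},
    "stats": {REPO_ID: meta.stats},
    "state_mode": "MEAN_STD",  # a lerobot NormalizationMode value
    "select_video_keys": {REPO_ID: ["observation.images.top"]},
    "select_state_keys": {REPO_ID: ["observation.state"]},
    "select_action_keys": {REPO_ID: ["action"]},
    "max_state_dim": 32,
}

# placeholder checkpoint path; set_normalization builds the per-repo_id
# Normalize/Unnormalize instances, as in __init__ above
processor = AutoProcessor.from_pretrained("path/to/eo1-checkpoint", trust_remote_code=True)
processor.set_normalization(robot_config)
model = AutoModel.from_pretrained("path/to/eo1-checkpoint", trust_remote_code=True).eval()

batch = {
    "repo_id": [REPO_ID],  # optional; without it the first normalizer is the fallback
    "task": ["pick up the cube"],
    "observation.state": torch.zeros(1, 7),  # discovered via the OBS_STATE prefix
    "observation.images.top": torch.zeros(1, 3, 224, 224),
}

# select_action: _prepare_robot_inputs -> apply_chat_template ->
# model.sample_actions -> _process_robot_outputs
out = processor.select_action(model, batch)
actions = out["action"]  # stacked, unnormalized per-sample action chunks

Keying every per-sample lookup (normalizer and video/state/action key selection) by repo_id is what lets a single processor serve mixed batches drawn from several LeRobot datasets.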