add inference.py

inference.py (ADDED, +185 -0)
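
# Terra world-model inference: tokenize a few conditioning frames, autoregressively
# generate future frame tokens round by round under an action trajectory, then
# decode the tokens back into images and write them out as an MP4.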
import argparse
import json
import random
from pathlib import Path

import imageio
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel
from tqdm import tqdm

# Constants
IMAGE_SIZE = (288, 512)
N_FRAMES_PER_ROUND = 25
MAX_NUM_FRAMES = 50
N_TOKENS_PER_FRAME = 576
TRAJ_TEMPLATE_PATH = Path("./assets/template_trajectory.json")
PATH_START_ID = 9
PATH_POINT_INTERVAL = 10
N_ACTION_TOKENS = 6

# Change here if you want to use your own images.
CONDITIONING_FRAMES_DIR = Path("./assets/conditioning_frames")
CONDITIONING_FRAMES_PATH_LIST = [
    CONDITIONING_FRAMES_DIR / "001.png",
    CONDITIONING_FRAMES_DIR / "002.png",
    CONDITIONING_FRAMES_DIR / "003.png",
]

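
# Seed all RNGs (Python, NumPy, PyTorch CPU and CUDA) so sampling is reproducible.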
def set_random_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

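
# Resize to (H, W), scale pixel values from [0, 255] to [-1, 1], and return
# a (1, C, H, W) float tensor.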
def preprocess_image(image: Image.Image, size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    H, W = size
    image = image.convert("RGB")
    image = image.resize((W, H))
    image_array = np.array(image)
    image_array = (image_array / 127.5 - 1.0).astype(np.float32)
    return torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0).float()

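
# Invert the preprocessing: clamp to [-1, 1], map back to [0, 255] uint8,
# and move channels last for image I/O.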
def to_np_images(images: torch.Tensor) -> np.ndarray:
    images = images.detach().cpu()
    images = torch.clamp(images, -1.0, 1.0)
    images = (images + 1.0) / 2.0
    images = images.permute(0, 2, 3, 1).numpy()
    return (255 * images).astype(np.uint8)

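
# Load and preprocess the conditioning frames into a single (N, C, H, W) batch.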
def load_images(file_path_list: list[Path], size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    images = []
    for file_path in file_path_list:
        image = Image.open(file_path)
        image = preprocess_image(image, size)
        images.append(image)
    return torch.cat(images, dim=0)

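
# Write a stack of uint8 frames to an MP4 file via imageio.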
def save_images_to_mp4(images: np.ndarray, output_path: Path, fps: int = 10):
    writer = imageio.get_writer(output_path, fps=fps)
    for img in images:
        writer.append_data(img)
    writer.close()

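
# Number of generation rounds: ceiling division of the frames still needed
# beyond the initial conditioning frames by the net new frames per round
# (N_FRAMES_PER_ROUND minus the overlap carried over between rounds).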
def determine_num_rounds(num_frames: int, num_overlapping_frames: int, n_initial_frames: int) -> int:
    n_rounds = (num_frames - n_initial_frames) // (N_FRAMES_PER_ROUND - num_overlapping_frames)
    if (num_frames - n_initial_frames) % (N_FRAMES_PER_ROUND - num_overlapping_frames) > 0:
        n_rounds += 1
    return n_rounds

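
# Build the per-frame action tensor for a command from the trajectory template:
# subsample each frame's instruction trajectory (every path_point_interval-th
# waypoint starting at path_start_id, keeping n_action_tokens points), swap the
# two waypoint coordinates, and append the matching timestep as a third column.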
def prepare_action(
    traj_template: dict,
    cmd: str,
    path_start_id: int,
    path_point_interval: int,
    n_action_tokens: int = 5,
    start_index: int = 0,
    n_frames: int = 25,
) -> torch.Tensor:
    trajs = traj_template[cmd]["instruction_trajs"]
    actions = []
    timesteps = np.arange(0.0, 3.0, 0.05)
    for i in range(start_index, start_index + n_frames):
        traj = trajs[i][path_start_id::path_point_interval][:n_action_tokens]
        action = np.array(traj)
        timestep = timesteps[path_start_id::path_point_interval][:n_action_tokens]
        action = np.concatenate([
            action[:, [1, 0]],
            timestep.reshape(-1, 1)
        ], axis=1)
        actions.append(torch.tensor(action))
    return torch.cat(actions, dim=0)

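
# Entry point: parse arguments, load the models, generate frame tokens round by
# round, and decode them into a video.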
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--cmd", type=str, default="curving_to_left/curving_to_left_moderate")
    parser.add_argument("--num_frames", type=int, default=25)
    parser.add_argument("--num_overlapping_frames", type=int, default=3)
    args = parser.parse_args()

    assert args.num_frames <= MAX_NUM_FRAMES, f"`num_frames` should be less than or equal to {MAX_NUM_FRAMES}"
    assert args.num_overlapping_frames < N_FRAMES_PER_ROUND, f"`num_overlapping_frames` should be less than {N_FRAMES_PER_ROUND}"

    set_random_seed(args.seed)
    if args.output_dir is None:
        output_dir = Path(f"./outputs/{args.cmd}")
    else:
        output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

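    # Load the LFQ tokenizer and the autoregressive world model from the
    # turing-motors/Terra Hub repo, then tokenize the conditioning frames
    # into the prompt token sequence.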
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoModel.from_pretrained("turing-motors/Terra", subfolder="lfq_tokenizer_B_256", trust_remote_code=True).to(device).eval()
    model = AutoModel.from_pretrained("turing-motors/Terra", subfolder="world_model", trust_remote_code=True).to(device).eval()

    conditioning_frames = load_images(CONDITIONING_FRAMES_PATH_LIST, IMAGE_SIZE).to(device)
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        input_ids = tokenizer.tokenize(conditioning_frames).detach().unsqueeze(0)

    num_rounds = determine_num_rounds(args.num_frames, args.num_overlapping_frames, len(CONDITIONING_FRAMES_PATH_LIST))
    print(f"Number of generation rounds: {num_rounds}")

    with open(TRAJ_TEMPLATE_PATH) as f:
        traj_template = json.load(f)

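    # Sliding-window generation: each round extends the sequence up to
    # N_FRAMES_PER_ROUND frames, and the tokens of the last
    # num_overlapping_frames frames are carried over as context for the next
    # round so consecutive chunks stay consistent.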
    all_outputs = []
    for round_idx in range(num_rounds):
        start_index = round_idx * (N_FRAMES_PER_ROUND - args.num_overlapping_frames)
        num_frames_for_round = min(N_FRAMES_PER_ROUND, args.num_frames - start_index)
        actions = prepare_action(
            traj_template, args.cmd, PATH_START_ID, PATH_POINT_INTERVAL, N_ACTION_TOKENS, start_index, num_frames_for_round
        ).unsqueeze(0).to(device).float()
        if round_idx == 0:
            num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - len(CONDITIONING_FRAMES_PATH_LIST))
        else:
            num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - args.num_overlapping_frames)
        progress_bar = tqdm(total=num_generated_tokens, desc=f"Round {round_idx + 1}")
        with torch.inference_mode(), torch.autocast(device_type="cuda"):
            output_tokens = model.generate(
                input_ids=input_ids,
                actions=actions,
                do_sample=True,
                max_length=N_TOKENS_PER_FRAME * num_frames_for_round,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                pad_token_id=None,
                eos_token_id=None,
                progress_bar=progress_bar,
            )
        if round_idx == 0:
            all_outputs.append(output_tokens[0])
        else:
            all_outputs.append(output_tokens[0, args.num_overlapping_frames * N_TOKENS_PER_FRAME:])
        input_ids = output_tokens[:, -args.num_overlapping_frames * N_TOKENS_PER_FRAME:]
        progress_bar.close()

    output_ids = torch.cat(all_outputs)

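    # The tokenizer downsamples by the product of its ch_mult factors; with the
    # default 288x512 input, N_TOKENS_PER_FRAME = 576 corresponds to an 18x32
    # token grid (16x downsampling).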
    # Calculate the shape of the latent tensor
    downsample_ratio = 1
    for coef in tokenizer.config.encoder_decoder_config["ch_mult"]:
        downsample_ratio *= coef
    h = IMAGE_SIZE[0] // downsample_ratio
    w = IMAGE_SIZE[1] // downsample_ratio
    c = tokenizer.config.encoder_decoder_config["z_channels"]
    latent_shape = (len(output_ids) // N_TOKENS_PER_FRAME, h, w, c)

    # Decode the latent tensor to images
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        reconstructed = tokenizer.decode_tokens(output_ids, latent_shape)
    reconstructed_images = to_np_images(reconstructed)
    save_images_to_mp4(reconstructed_images, output_dir / "generated.mp4", fps=10)
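
# Example invocation (a sketch using the defaults defined above; `num_frames`
# must not exceed MAX_NUM_FRAMES):
#   python inference.py --cmd curving_to_left/curving_to_left_moderate --num_frames 50 --num_overlapping_frames 3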