import torch
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, AutoProcessor
import gc

MODEL_ID = "arjunanand13/gas_pipe_llava_finetunedv3"
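
# Release Python- and CUDA-side memory between requests so repeated
# inferences do not accumulate allocations on the GPU.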
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
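
# Sample frames evenly across the video, convert BGR -> RGB, and resize
# each to 112x112 so that four frames tile into a 224x224 grid below.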
def extract_frames_from_video(video_path, num_frames=4):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Clamp to the number of frames actually available in the video.
    num_frames = min(num_frames, total_frames)
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if ret:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(frame_rgb)
            frame_resized = frame_pil.resize((112, 112), Image.Resampling.LANCZOS)
            frames.append(frame_resized)
    cap.release()
    # Pad to exactly 4 frames (the 2x2 grid) by repeating the last frame,
    # or with black frames if decoding produced nothing.
    while len(frames) < 4:
        if frames:
            frames.append(frames[-1].copy())
        else:
            frames.append(Image.new('RGB', (112, 112), color='black'))
    return frames[:4]
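
# Paste the sampled frames into a single rows x cols image; the app feeds
# the model one 2x2 tile of four 112x112 frames rather than separate images.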
def create_frame_grid(frames, grid_size=(2, 2)):
    cols, rows = grid_size
    frame_size = 112
    grid_width = frame_size * cols
    grid_height = frame_size * rows
    grid_image = Image.new('RGB', (grid_width, grid_height))
    for i, frame in enumerate(frames):
        row = i // cols
        col = i % cols
        x = col * frame_size
        y = row * frame_size
        grid_image.paste(frame, (x, y))
    return grid_image
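
# Load the fine-tuned LLaVA-Next checkpoint with 4-bit NF4 quantization
# (bitsandbytes 4-bit loading generally requires a CUDA GPU).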
def load_model():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.uint8
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    processor.tokenizer.padding_side = "right"
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    # use_cache=False is a training-time setting (needed with gradient
    # checkpointing); enable the KV cache for faster generation at inference.
    model.config.use_cache = True
    model.eval()
    return model, processor
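
# Load once at import time so the model is shared across Gradio requests.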
model, processor = load_model()
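
# Full pipeline for one video: sample frames, tile them into a grid, run the
# model with the fine-tuning prompt, and decode only the newly generated tokens.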
def predict_gas_pipe_quality(video_path):
    # video_input.change also fires when the video is cleared; skip None input.
    if video_path is None:
        return None, ""
    try:
        frames = extract_frames_from_video(video_path, num_frames=4)
        grid_image = create_frame_grid(frames, grid_size=(2, 2))
        prompt = "[INST] <image>\nGas pipe test result? [/INST]"
        inputs = processor(text=prompt, images=grid_image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs["pixel_values"],
                image_sizes=inputs["image_sizes"],
                max_new_tokens=16,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id
            )
        # Strip the prompt tokens so only the generated answer is decoded.
        prediction = processor.batch_decode(
            generated_ids[:, inputs["input_ids"].size(1):],
            skip_special_tokens=True
        )[0].strip()
        clear_memory()
        return grid_image, prediction
    except Exception as e:
        clear_memory()
        return None, f"Error: {str(e)}"
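
# Build the Gradio UI: video upload and Analyze button on the left, the
# extracted frame grid and model output on the right, plus example clips.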
def create_interface():
    with gr.Blocks(title="Gas Pipe Quality Control") as iface:
        gr.Markdown("# Gas Pipe Quality Control")
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                frame_grid = gr.Image(label="Extracted Frames")
                result_output = gr.Textbox(label="Model Output", lines=3)
        gr.Examples(
            examples=[
                ["13.mp4"],
                ["14.mp4"],
                ["04.mp4"],
                ["07_part1.mp4"],
                ["09_part1.mp4"],
                ["29_part1.mp4"]
            ],
            inputs=video_input,
            outputs=[frame_grid, result_output],
            fn=predict_gas_pipe_quality,
            cache_examples=False
        )
        analyze_btn.click(
            fn=predict_gas_pipe_quality,
            inputs=video_input,
            outputs=[frame_grid, result_output]
        )
        video_input.change(
            fn=predict_gas_pipe_quality,
            inputs=video_input,
            outputs=[frame_grid, result_output]
        )
    return iface

if __name__ == "__main__":
    iface = create_interface()
    iface.launch(share=True)