"""
Use MediaPipe to detect poses in images and extract landmark coordinates.

Features:
1. Run MediaPipe pose detection on images in the train folder
2. Use the nose as the head reference point (headPos)
3. Process coordinates as: pos = (pos - headPos) * 100 and round to 2 decimals
4. Save processed landmarks into JSON files named after the image files

Usage:
    python pose_detection.py [--input INPUT_DIR] [--output OUTPUT_DIR] [--batch-size N]
"""
|
import os
import json
import argparse
from pathlib import Path

import cv2
import mediapipe as mp


class PoseDetector:
    def __init__(self):
        """Initialize MediaPipe pose detector."""
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5
        )
|
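        # Names of MediaPipe's 33 pose landmarks, in model index order.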
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index'
        ]
|
    def get_head_position(self, landmarks):
        """
        Compute the head reference position (use the nose landmark).

        Args:
            landmarks: MediaPipe detected landmarks

        Returns:
            tuple: (x, y, z) head coordinates
        """
|
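        # Landmark index 0 is the nose in MediaPipe's pose model.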
        nose = landmarks[0]
        return (nose.x, nose.y, nose.z)
|
    def process_landmarks(self, landmarks, head_pos):
        """
        Process landmarks: pos = (pos - headPos) * 100 and round to 2 decimals.

        Args:
            landmarks: MediaPipe detected landmarks
            head_pos: head coordinates (x, y, z)

        Returns:
            dict: processed landmarks dictionary
        """
        processed_landmarks = {}
        head_pos_x = head_pos[0]
        head_pos_y = head_pos[1]
        head_pos_z = head_pos[2]

        for i, landmark in enumerate(landmarks):
            if i < len(self.landmark_names):
                name = self.landmark_names[i]
|
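                # Shift each coordinate relative to the head, scale by 100, and round.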
                rel_x = round((landmark.x - head_pos_x) * 100, 2)
                rel_y = round((landmark.y - head_pos_y) * 100, 2)
                rel_z = round((landmark.z - head_pos_z) * 100, 2)

                processed_landmarks[name] = {
                    'x': rel_x,
                    'y': rel_y,
                    'z': rel_z,
                    'visibility': round(landmark.visibility, 3)
                }

        return processed_landmarks
|
    def detect_pose(self, image_path):
        """
        Detect pose for a single image.

        Args:
            image_path: path to the image file

        Returns:
            dict: processed landmarks and metadata, or None on failure
        """
        try:
|
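            # Read the image; OpenCV loads it in BGR channel order.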
            image = cv2.imread(str(image_path))
            if image is None:
                print(f"Unable to read image: {image_path}")
                return None
|
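            # MediaPipe expects RGB input, so convert from BGR.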
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
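            # Run pose detection on the RGB image.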
            results = self.pose.process(image_rgb)
            if results.pose_landmarks is None:
                print(f"No pose detected: {image_path}")
                return None
|
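            # Normalize the raw landmarks relative to the head position.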
            landmarks = results.pose_landmarks.landmark
            head_pos = self.get_head_position(landmarks)
            processed_landmarks = self.process_landmarks(landmarks, head_pos)
|
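            # The parent directory name is used as the class label.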
            label = image_path.parent.name

            result = {
                'image_path': str(image_path),
                'image_name': image_path.name,
                'label': label,
                'head_position': {
                    'x': round(head_pos[0], 4),
                    'y': round(head_pos[1], 4),
                    'z': round(head_pos[2], 4)
                },
                'landmarks': processed_landmarks,
                'total_landmarks': len(processed_landmarks)
            }

            return result

        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None
|
    def close(self):
        """Close MediaPipe resources."""
        self.pose.close()
|
|
def process_all_training_data(input_dir, output_dir, batch_size=100):
    """
    Process all images in the training dataset and write JSON files.

    Args:
        input_dir: input images directory (TrainData/train)
        output_dir: output JSON directory (PoseData)
        batch_size: progress report batch size
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
|
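    # Image file extensions to process.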
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()
|
    try:
        total_images = 0
        success_count = 0
        failed_count = 0
        label_stats = {}

        print(f"Starting processing dataset: {input_path}")
        print(f"Output directory: {output_path}")
|
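        # First pass: collect label_* directories and count their images.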
        print("Counting images...")
        label_dirs = []
        for item in input_path.iterdir():
            if item.is_dir() and item.name.startswith('label_'):
                label = item.name
                image_files = [f for f in item.iterdir()
                               if f.is_file() and f.suffix.lower() in image_extensions]
                if image_files:
                    label_dirs.append((item, label, image_files))
                    total_images += len(image_files)
                    label_stats[label] = {'total': len(image_files), 'success': 0, 'failed': 0}

        print(f"Found {len(label_dirs)} label directories, total {total_images} images")
        for label, stats in label_stats.items():
            print(f"  {label}: {stats['total']} images")

        print("\nStarting to process images...")
|
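        # Second pass: run detection and write one JSON file per image.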
        for label_dir, label_name, image_files in label_dirs:
            print(f"\n--- Processing {label_name} ({len(image_files)} images) ---")
|
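            # Mirror the label directory in the output directory.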
            output_label_dir = output_path / label_name
            output_label_dir.mkdir(parents=True, exist_ok=True)

            for i, image_file in enumerate(image_files, 1):
                json_filename = image_file.stem + '.json'
                json_path = output_label_dir / json_filename
|
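                # detect_pose returns None on a read failure or when no pose is found.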
                result = detector.detect_pose(image_file)
                if result is not None:
                    try:
                        with open(json_path, 'w', encoding='utf-8') as f:
                            json.dump(result, f, ensure_ascii=False, indent=2)
                        success_count += 1
                        label_stats[label_name]['success'] += 1
|
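                        # Report progress every batch_size successful images.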
                        if success_count % batch_size == 0:
                            progress = (success_count / total_images) * 100 if total_images else 0
                            print(f"  Progress: {success_count}/{total_images} ({progress:.1f}%) - Current: {label_name} {i}/{len(image_files)}")
                    except Exception as e:
                        print(f"  Failed to save JSON {json_path}: {e}")
                        failed_count += 1
                        label_stats[label_name]['failed'] += 1
                else:
                    failed_count += 1
                    label_stats[label_name]['failed'] += 1
                    if failed_count % 10 == 0:
                        print(f"  Detection failed: {image_file.name}")
|
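            # Per-label summary once this directory is done.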
            stats = label_stats[label_name]
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f"  {label_name} Done: Success {stats['success']}, Failed {stats['failed']}, Success rate: {success_rate:.1f}%")
|
        print("\n" + "=" * 60)
        print("Processing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        total_success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Overall success rate: {total_success_rate:.1f}%")
|
        print("\nPer-label statistics:")
        for label, stats in label_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f"  {label}: {stats['success']}/{stats['total']} ({success_rate:.1f}%)")
|
        print(f"\nJSON files saved to: {output_path.absolute()}")
        print("Directory structure:")
        print("PoseData/")
        for label in sorted(label_stats.keys()):
            print(f"├── {label}/")
            print("│   └── *.json")
|
    finally:
        detector.close()
|
|
def process_directory(input_dir, output_dir):
    """
    Process all images in a directory tree and write JSON files.

    Generic alternative to process_all_training_data: mirrors the input
    directory structure in the output directory.

    Args:
        input_dir: input images directory
        output_dir: output JSON directory
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
|
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()
|
    try:
        total_images = 0
        success_count = 0
        failed_count = 0

        print(f"Starting to process directory: {input_path}")
        print(f"Output directory: {output_path}")
|
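        # Walk the tree, mirroring each subdirectory in the output.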
        for root, dirs, files in os.walk(input_path):
            root_path = Path(root)
            relative_path = root_path.relative_to(input_path)
            current_output_dir = output_path / relative_path
            current_output_dir.mkdir(parents=True, exist_ok=True)
|
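            # Keep only image files from this directory.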
            image_files = [f for f in files if Path(f).suffix.lower() in image_extensions]

            if image_files:
                print(f"\nProcessing directory: {root_path}")
                print(f"Found {len(image_files)} images")
|
                for filename in image_files:
                    total_images += 1
                    image_path = root_path / filename
                    json_filename = Path(filename).stem + '.json'
                    json_path = current_output_dir / json_filename
|
                    result = detector.detect_pose(image_path)
                    if result is not None:
                        try:
                            with open(json_path, 'w', encoding='utf-8') as f:
                                json.dump(result, f, ensure_ascii=False, indent=2)
                            success_count += 1

                            if success_count % 50 == 0:
                                print(f"Successfully processed {success_count} images...")
                        except Exception as e:
                            print(f"Failed to save JSON {json_path}: {e}")
                            failed_count += 1
                    else:
                        failed_count += 1
|
        print("\nProcessing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Success rate: {success_rate:.1f}%")
|
    finally:
        detector.close()
|
|
def main():
    parser = argparse.ArgumentParser(description="Run MediaPipe pose detection and save landmark data")
    parser.add_argument("--input", "-i", default="TrainData/train",
                        help="input images directory (default: TrainData/train)")
    parser.add_argument("--output", "-o", default="PoseData",
                        help="output JSON directory (default: PoseData)")
    parser.add_argument("--batch-size", "-b", type=int, default=100,
                        help="batch size for progress reporting (default: 100)")

    args = parser.parse_args()
|
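    # Validate the input directory before doing any work.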
    if not Path(args.input).exists():
        print(f"Error: input directory does not exist: {args.input}")
        return
|
    print("MediaPipe pose detection tool")
    print("=" * 60)
    print(f"Input directory: {args.input}")
    print(f"Output directory: {args.output}")
    print("Processing rule: pos = (pos - headPos) * 100, round to 2 decimals")
    print("Head reference: nose")
    print(f"Batch size: show progress every {args.batch_size} images")
    print("=" * 60)
|
    process_all_training_data(args.input, args.output, args.batch_size)


if __name__ == "__main__":
    main()