# Provenance (recovered from a model-hub file-viewer page):
#   file:   YogaPoseClassify / pose_detection.py
#   author: pegasama
#   commit: b26156a (verified) - "train and test python script"
#   size:   14.1 kB
#!/usr/bin/env python3
"""
Use MediaPipe to detect poses in images and extract landmark coordinates.
Features:
1. Run MediaPipe pose detection on images in the `label_*` subfolders of the input directory (default: TrainData/train)
2. Use the nose as the head reference point (headPos)
3. Process coordinates as: pos = (pos - headPos) * 100 and round to 2 decimals
4. Save processed landmarks into JSON files named after the image files
Usage:
    python pose_detection.py [--input INPUT_DIR] [--output OUTPUT_DIR] [--batch-size N]
"""
import os
import json
import argparse
from pathlib import Path
import cv2
import mediapipe as mp
class PoseDetector:
    """Wrap a MediaPipe ``Pose`` model and convert its output to head-relative records.

    Instances can be used as a context manager so the MediaPipe resources are
    released even if processing raises::

        with PoseDetector() as detector:
            record = detector.detect_pose("some/label_x/image.jpg")
    """

    # Names of the MediaPipe pose landmarks, in the index order MediaPipe
    # returns them (index 0 is the nose, which is used as the head reference).
    LANDMARK_NAMES = (
        'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
        'right_eye_inner', 'right_eye', 'right_eye_outer',
        'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
        'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
        'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
        'left_index', 'right_index', 'left_thumb', 'right_thumb',
        'left_hip', 'right_hip', 'left_knee', 'right_knee',
        'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
        'left_foot_index', 'right_foot_index',
    )

    def __init__(self, static_image_mode=True, model_complexity=2,
                 min_detection_confidence=0.5):
        """Initialize the MediaPipe pose detector.

        Args:
            static_image_mode: treat every input as an independent still image
                (passed straight to ``mp.solutions.pose.Pose``). Default True.
            model_complexity: MediaPipe pose landmark model complexity
                (0, 1 or 2). Default 2.
            min_detection_confidence: minimum detection confidence passed to
                MediaPipe. Default 0.5.
        """
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=static_image_mode,
            model_complexity=model_complexity,
            enable_segmentation=False,
            min_detection_confidence=min_detection_confidence,
        )
        # Per-instance list copy so mutating it cannot affect other detectors.
        self.landmark_names = list(self.LANDMARK_NAMES)

    def get_head_position(self, landmarks):
        """Return the head reference point (the nose landmark).

        Args:
            landmarks: indexable sequence of landmarks exposing ``x``, ``y``
                and ``z`` attributes (e.g. MediaPipe ``pose_landmarks.landmark``).

        Returns:
            tuple: ``(x, y, z)`` of landmark 0 (the nose).
        """
        nose = landmarks[0]
        return (nose.x, nose.y, nose.z)

    def process_landmarks(self, landmarks, head_pos):
        """Express landmarks relative to the head: ``round((pos - head_pos) * 100, 2)``.

        Args:
            landmarks: sequence of landmarks exposing ``x``, ``y``, ``z`` and
                ``visibility`` attributes.
            head_pos: ``(x, y, z)`` head reference coordinates.

        Returns:
            dict: ``{landmark_name: {'x', 'y', 'z', 'visibility'}}``. Landmarks
            beyond the known names are ignored (``zip`` stops at the shorter
            sequence). Visibility is rounded to 3 decimals.
        """
        head_x, head_y, head_z = head_pos[0], head_pos[1], head_pos[2]
        return {
            name: {
                'x': round((lm.x - head_x) * 100, 2),
                'y': round((lm.y - head_y) * 100, 2),
                'z': round((lm.z - head_z) * 100, 2),
                'visibility': round(lm.visibility, 3),
            }
            for name, lm in zip(self.landmark_names, landmarks)
        }

    def detect_pose(self, image_path):
        """Detect the pose in one image and build its JSON-ready record.

        Args:
            image_path: path to the image (``str`` or ``pathlib.Path``). The
                name of the image's parent folder is used as its label.

        Returns:
            dict | None: a record with keys ``image_path``, ``image_name``,
            ``label``, ``head_position``, ``landmarks`` and ``total_landmarks``.
            Returns ``None`` (after printing a message) when the image cannot
            be read, no pose is detected, or any error occurs.
        """
        try:
            # Accept plain strings as well; .parent/.name below need a Path.
            image_path = Path(image_path)
            image = cv2.imread(str(image_path))
            if image is None:
                print(f"Unable to read image: {image_path}")
                return None
            # OpenCV decodes to BGR; MediaPipe expects RGB.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = self.pose.process(image_rgb)
            if results.pose_landmarks is None:
                print(f"No pose detected: {image_path}")
                return None
            landmarks = results.pose_landmarks.landmark
            head_pos = self.get_head_position(landmarks)
            processed_landmarks = self.process_landmarks(landmarks, head_pos)
            return {
                'image_path': str(image_path),
                'image_name': image_path.name,
                'label': image_path.parent.name,
                'head_position': {
                    'x': round(head_pos[0], 4),
                    'y': round(head_pos[1], 4),
                    'z': round(head_pos[2], 4),
                },
                'landmarks': processed_landmarks,
                'total_landmarks': len(processed_landmarks),
            }
        except Exception as e:
            # Per-image boundary: one bad file must not abort a batch run.
            print(f"Error processing image {image_path}: {e}")
            return None

    def close(self):
        """Release the MediaPipe resources held by this detector."""
        self.pose.close()

    def __enter__(self):
        """Return the detector itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the detector; exceptions from the ``with`` body propagate."""
        self.close()
        return False
def _success_rate(success, total):
    """Return ``success / total`` as a percentage, or 0 when ``total`` is 0."""
    return (success / total) * 100 if total > 0 else 0


def _collect_label_dirs(input_path, image_extensions):
    """Return ``(label_dir, label_name, image_files)`` for every non-empty ``label_*`` folder.

    Folders and files are sorted by name so repeated runs process images in
    the same order (``Path.iterdir`` yields entries in arbitrary order).
    """
    label_dirs = []
    for item in sorted(input_path.iterdir()):
        if not (item.is_dir() and item.name.startswith('label_')):
            continue
        image_files = sorted(
            f for f in item.iterdir()
            if f.is_file() and f.suffix.lower() in image_extensions
        )
        if image_files:
            label_dirs.append((item, item.name, image_files))
    return label_dirs


def _print_training_summary(output_path, total_images, success_count,
                            failed_count, label_stats):
    """Print overall and per-label statistics plus the output directory layout."""
    print("\n" + "=" * 60)
    print("Processing complete!")
    print(f"Total images: {total_images}")
    print(f"Successfully processed: {success_count}")
    print(f"Failed: {failed_count}")
    print(f"Overall success rate: {_success_rate(success_count, total_images):.1f}%")
    print("\nPer-label statistics:")
    for label, stats in label_stats.items():
        rate = _success_rate(stats['success'], stats['total'])
        print(f" {label}: {stats['success']}/{stats['total']} ({rate:.1f}%)")
    print(f"\nJSON files saved to: {output_path.absolute()}")
    print("Directory structure:")
    # Show the actual output folder name instead of a hard-coded "PoseData".
    resolved = output_path.resolve()
    print(f"{resolved.name or resolved}/")
    for label in sorted(label_stats):
        print(f"├── {label}/")
        print("│   └── *.json")


def process_all_training_data(input_dir, output_dir, batch_size=100):
    """Run pose detection on every image under ``input_dir/label_*/`` and save JSON files.

    Each image ``<input_dir>/<label>/<name>.<ext>`` yields
    ``<output_dir>/<label>/<name>.json``. Only sub-directories whose name starts
    with ``label_`` are scanned, and only files with a supported image
    extension are processed. Progress and per-label statistics are printed.

    Args:
        input_dir: directory containing the ``label_*`` image folders
            (e.g. ``TrainData/train``).
        output_dir: directory the JSON files are written to (created if missing).
        batch_size: print a progress line every ``batch_size`` successfully
            saved images; values below 1 disable the periodic progress lines.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    # A non-positive batch size previously crashed with ZeroDivisionError
    # (success_count % 0) on the first saved image.
    report_every = batch_size if batch_size > 0 else 0

    detector = PoseDetector()
    try:
        print(f"Starting processing dataset: {input_path}")
        print(f"Output directory: {output_path}")

        print("Counting images...")
        label_dirs = _collect_label_dirs(input_path, image_extensions)
        total_images = sum(len(files) for _, _, files in label_dirs)
        label_stats = {
            label: {'total': len(files), 'success': 0, 'failed': 0}
            for _, label, files in label_dirs
        }
        print(f"Found {len(label_dirs)} label directories, total {total_images} images")
        for label, stats in label_stats.items():
            print(f" {label}: {stats['total']} images")

        print("\nStarting to process images...")
        success_count = 0
        failed_count = 0
        for _label_dir, label_name, image_files in label_dirs:
            stats = label_stats[label_name]
            print(f"\n--- Processing {label_name} ({len(image_files)} images) ---")
            output_label_dir = output_path / label_name
            output_label_dir.mkdir(parents=True, exist_ok=True)

            for i, image_file in enumerate(image_files, 1):
                json_path = output_label_dir / (image_file.stem + '.json')
                result = detector.detect_pose(image_file)
                if result is None:
                    failed_count += 1
                    stats['failed'] += 1
                    # detect_pose already logs each failure; echo every 10th here.
                    if failed_count % 10 == 0:
                        print(f" Detection failed: {image_file.name}")
                    continue

                try:
                    with open(json_path, 'w', encoding='utf-8') as f:
                        json.dump(result, f, ensure_ascii=False, indent=2)
                except (OSError, TypeError, ValueError) as e:
                    print(f" Failed to save JSON {json_path}: {e}")
                    failed_count += 1
                    stats['failed'] += 1
                    continue

                success_count += 1
                stats['success'] += 1
                if report_every and success_count % report_every == 0:
                    progress = _success_rate(success_count, total_images)
                    print(f" Progress: {success_count}/{total_images} ({progress:.1f}%) - Current: {label_name} {i}/{len(image_files)}")

            rate = _success_rate(stats['success'], stats['total'])
            print(f" {label_name} Done: Success {stats['success']}, Failed {stats['failed']}, Success rate: {rate:.1f}%")

        _print_training_summary(output_path, total_images, success_count,
                                failed_count, label_stats)
    finally:
        detector.close()
def process_directory(input_dir, output_dir, batch_size=50):
    """Run pose detection on every image in a directory tree and save JSON files.

    Unlike ``process_all_training_data``, this function needs no ``label_*``
    folder naming. It walks the whole tree under ``input_dir`` and writes
    ``<output_dir>/<relative dir>/<name>.json`` for each image. Output
    sub-directories are created only for folders that contain images.

    Args:
        input_dir: root directory of the images.
        output_dir: root directory for the JSON files (created if missing).
        batch_size: print a progress line every ``batch_size`` successfully
            saved images (default 50, the previous fixed interval); values
            below 1 disable these lines.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    report_every = batch_size if batch_size > 0 else 0

    detector = PoseDetector()
    try:
        total_images = 0
        success_count = 0
        failed_count = 0
        print(f"Starting to process directory: {input_path}")
        print(f"Output directory: {output_path}")

        for root, dirs, files in os.walk(input_path):
            # Sorting dirs in place makes os.walk descend in a reproducible order.
            dirs.sort()
            root_path = Path(root)
            image_files = sorted(f for f in files if Path(f).suffix.lower() in image_extensions)
            if not image_files:
                continue

            current_output_dir = output_path / root_path.relative_to(input_path)
            current_output_dir.mkdir(parents=True, exist_ok=True)
            print(f"\nProcessing directory: {root_path}")
            print(f"Found {len(image_files)} images")

            for filename in image_files:
                total_images += 1
                image_path = root_path / filename
                json_path = current_output_dir / (Path(filename).stem + '.json')
                result = detector.detect_pose(image_path)
                if result is None:
                    failed_count += 1
                    continue
                try:
                    with open(json_path, 'w', encoding='utf-8') as f:
                        json.dump(result, f, ensure_ascii=False, indent=2)
                except (OSError, TypeError, ValueError) as e:
                    print(f"Failed to save JSON {json_path}: {e}")
                    failed_count += 1
                    continue
                success_count += 1
                if report_every and success_count % report_every == 0:
                    print(f"Successfully processed {success_count} images...")

        print("\nProcessing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        # Guard against ZeroDivisionError when the tree contains no images.
        success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Success rate: {success_rate:.1f}%")
    finally:
        detector.close()
def main(argv=None):
    """Command-line entry point: parse arguments and process the training dataset.

    Args:
        argv: argument list to parse, without the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Invalid arguments (a missing or non-directory input, or a non-positive
    batch size) are reported through ``parser.error``. It prints usage to
    stderr and exits with status 2.
    """
    parser = argparse.ArgumentParser(description="Run MediaPipe pose detection and save landmark data")
    parser.add_argument("--input", "-i", default="TrainData/train",
                        help="input images directory (default: TrainData/train)")
    parser.add_argument("--output", "-o", default="PoseData",
                        help="output JSON directory (default: PoseData)")
    parser.add_argument("--batch-size", "-b", type=int, default=100,
                        help="batch size for progress reporting (default: 100)")
    args = parser.parse_args(argv)

    # is_dir() rather than exists(): a regular file would pass exists() and
    # then fail later when its contents are listed.
    if not Path(args.input).is_dir():
        parser.error(f"input directory does not exist or is not a directory: {args.input}")
    if args.batch_size < 1:
        parser.error(f"--batch-size must be a positive integer, got {args.batch_size}")

    print("MediaPipe pose detection tool")
    print("=" * 60)
    print(f"Input directory: {args.input}")
    print(f"Output directory: {args.output}")
    print("Processing rule: pos = (pos - headPos) * 100, round to 2 decimals")
    print("Head reference: nose")
    print(f"Batch size: show progress every {args.batch_size} images")
    print("=" * 60)

    process_all_training_data(args.input, args.output, args.batch_size)
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0
    # unless argument validation or processing raises.
    raise SystemExit(main())