Update handler.py
handler.py (CHANGED: +24 -7)
@@ -8,6 +8,11 @@ import tempfile
 import numpy as np
 from moviepy.editor import ImageSequenceClip
 import os
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class EndpointHandler:
     def __init__(self, path: str = ""):
@@ -34,22 +39,28 @@ class EndpointHandler:
         # Set default FPS
         self.fps = 24
 
-    def _create_video_file(self,
+    def _create_video_file(self, frames: torch.Tensor, fps: int = 24) -> bytes:
         """Convert frames to an MP4 video file.
 
         Args:
-
+            frames (torch.Tensor): Generated frames tensor
            fps (int): Frames per second for the output video
 
         Returns:
            bytes: MP4 video file content
        """
-        #
-
+        # Log frame information
+        num_frames = frames.shape[1]  # Shape should be [1, num_frames, channels, height, width]
+        duration = num_frames / fps
+        logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")
+
+        # Convert tensor to numpy array - remove batch dimension and rearrange to [num_frames, height, width, channels]
+        video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
         video_np = (video_np * 255).astype(np.uint8)
 
         # Get dimensions
-        height, width = video_np.shape
+        _, height, width, _ = video_np.shape
+        logger.info(f"Video dimensions: {width}x{height}")
 
         # Create temporary file
         output_path = tempfile.mktemp(suffix=".mp4")
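Note: the hunk above ends right where the temporary file is created. For orientation, the rest of _create_video_file plausibly continues along these lines (a sketch, not the commit's actual code; it assumes only the ImageSequenceClip import, the os import, and the output_path visible in the diff, and it keeps tempfile.mktemp to match the diff even though tempfile.NamedTemporaryFile would avoid mktemp's known race condition):

    # Sketch: encode the [num_frames, height, width, channels] uint8 array as MP4.
    clip = ImageSequenceClip(list(video_np), fps=fps)
    clip.write_videofile(output_path, codec="libx264", audio=False)

    # Read the encoded bytes back and clean up the temporary file.
    with open(output_path, "rb") as f:
        video_bytes = f.read()
    os.remove(output_path)
    return video_bytes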
@@ -103,6 +114,9 @@ class EndpointHandler:
             guidance_scale = data.get("guidance_scale", 7.5)
             num_inference_steps = data.get("num_inference_steps", 50)
 
+            logger.info(f"Generating video with prompt: '{prompt}'")
+            logger.info(f"Parameters: num_frames={num_frames}, fps={fps}, guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")
+
             # Check if image is provided for image-to-video generation
             image_data = data.get("image")
 
@@ -112,6 +126,7 @@
                 # Decode base64 image
                 image_bytes = base64.b64decode(image_data)
                 image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+                logger.info("Using image-to-video generation mode")
 
                 # Generate video from image
                 output = self.image_to_video(
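The presence of an "image" key is what routes a request into image-to-video mode, so a client base64-encodes the image before sending. A minimal sketch of assembling such a request (the key names are taken from the data.get(...) calls in this file; the file name and prompt are illustrative):

    import base64

    with open("input.jpg", "rb") as f:  # illustrative input file
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    data = {
        "prompt": "the scene slowly comes to life",
        "image": image_b64,  # including this key selects image-to-video mode
        "num_frames": 24,
        "fps": 24,
        "guidance_scale": 7.5,
        "num_inference_steps": 50,
    }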
@@ -121,8 +136,9 @@
                     guidance_scale=guidance_scale,
                     num_inference_steps=num_inference_steps,
                     output_type="pt"
-                ).frames[0]
+                ).frames  # Remove [0] to keep all frames
             else:
+                logger.info("Using text-to-video generation mode")
                 # Generate video from text only
                 output = self.text_to_video(
                     prompt=prompt,
@@ -130,7 +146,7 @@
                     guidance_scale=guidance_scale,
                     num_inference_steps=num_inference_steps,
                     output_type="pt"
-                ).frames[0]
+                ).frames  # Remove [0] to keep all frames
 
             # Convert frames to video file
             video_content = self._create_video_file(output, fps=fps)
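Because both pipelines now return .frames without the [0] index, the tensor handed to _create_video_file keeps its batch dimension, which is exactly what the frames.shape[1] frame count in the helper assumes. A quick shape check with dummy data (dimensions illustrative):

    import torch

    # Stand-in for pipeline output with output_type="pt":
    # [batch, num_frames, channels, height, width]
    frames = torch.rand(1, 16, 3, 256, 256)

    video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
    print(video_np.shape)  # (16, 256, 256, 3): [num_frames, height, width, channels]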
@@ -144,4 +160,5 @@
             }
 
         except Exception as e:
+            logger.error(f"Error generating video: {str(e)}")
             raise RuntimeError(f"Error generating video: {str(e)}")
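Putting it together, a local smoke test of the updated handler might look like this (a sketch that assumes the standard custom-handler convention of invoking the class as handler(data) with the flat dict read by the data.get(...) calls above; the path argument is left at its default):

    from handler import EndpointHandler

    handler = EndpointHandler(path="")  # path semantics as in __init__

    result = handler({
        "prompt": "a sailboat gliding across a calm sea at sunset",
        "num_frames": 24,
        "fps": 24,
    })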