Spaces:
Sleeping
Sleeping
File size: 12,520 Bytes
32e4125 0a7f9b4 32e4125 0a7f9b4 32e4125 0a7f9b4 32e4125 f75da29 32e4125 f75da29 32e4125 0a7f9b4 159faf0 0a7f9b4 159faf0 0a7f9b4 159faf0 0a7f9b4 32e4125 0a7f9b4 32e4125 0a7f9b4 32e4125 159faf0 32e4125 159faf0 32e4125 f75da29 159faf0 f75da29 32e4125 0a7f9b4 32e4125 0a7f9b4 32e4125 159faf0 32e4125 0a7f9b4 32e4125 0a7f9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 |
"""
Memory monitoring and management utilities for production deployment.
"""
import gc
import logging
import os
import threading
import time
import tracemalloc
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast
# Module-level logger used by every diagnostic helper in this file.
logger = logging.getLogger(__name__)
# Environment flag to enable deeper / more frequent memory diagnostics
# (any value other than unset/"0"/"false"/"False" turns it on).
MEMORY_DEBUG = os.getenv("MEMORY_DEBUG", "0") not in (None, "0", "false", "False")
# Same truthiness convention: opt in to tracemalloc allocation tracing.
ENABLE_TRACEMALLOC = os.getenv("ENABLE_TRACEMALLOC", "0") not in (
    None,
    "0",
    "false",
    "False",
)
# Memory milestone thresholds (MB) which trigger enhanced logging once per run
MEMORY_THRESHOLDS = [300, 400, 450, 500]
# Thresholds already reported — ensures each milestone logs at most once per process.
_crossed_thresholds: "set[int]" = set()  # type: ignore[type-arg]
# One-shot guards/handles for tracemalloc startup and the periodic logger thread.
_tracemalloc_started = False
_periodic_thread_started = False
_periodic_thread: Optional[threading.Thread] = None
def get_memory_usage() -> float:
    """Return this process's current memory footprint in MB.

    Prefers psutil's RSS figure; without psutil it falls back to the
    tracemalloc counter, and finally to 0.0 when nothing can be measured.
    """
    try:
        import psutil
    except ImportError:
        # Best effort without psutil: get_traced_memory() reports (0, 0)
        # when tracing is off, so this degrades gracefully to 0.0.
        try:
            traced_now, _traced_peak = tracemalloc.get_traced_memory()
            return traced_now / 1024 / 1024
        except Exception:
            return 0.0
    return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
def log_memory_usage(context: str = "") -> float:
    """Emit one info-level line with the current memory usage and return it (MB)."""
    memory_mb = get_memory_usage()
    # Include the context label only when one was supplied.
    message = (
        f"Memory usage ({context}): {memory_mb:.1f}MB"
        if context
        else f"Memory usage: {memory_mb:.1f}MB"
    )
    logger.info(message)
    return memory_mb
def _collect_detailed_stats() -> Dict[str, Any]:
    """Collect additional (lightweight) diagnostics; guarded by MEMORY_DEBUG."""
    stats: Dict[str, Any] = {}
    # Process-level numbers via psutil, skipped entirely when unavailable.
    try:
        import psutil  # type: ignore

        proc = psutil.Process(os.getpid())
        with proc.oneshot():
            info = proc.memory_info()
            stats["rss_mb"] = info.rss / 1024 / 1024
            stats["vms_mb"] = info.vms / 1024 / 1024
            stats["num_threads"] = proc.num_threads()
            stats["open_files"] = (
                len(proc.open_files()) if hasattr(proc, "open_files") else None
            )
    except Exception:
        pass
    # Read tracemalloc only when it is already tracing, to avoid overhead.
    if tracemalloc.is_tracing():
        try:
            traced_now, traced_peak = tracemalloc.get_traced_memory()
            stats["tracemalloc_current_mb"] = traced_now / 1024 / 1024
            stats["tracemalloc_peak_mb"] = traced_peak / 1024 / 1024
        except Exception:
            pass
    # GC generation counts are cheap to read.
    try:
        stats["gc_counts"] = gc.get_count()
    except Exception:
        pass
    return stats
def log_memory_checkpoint(context: str, force: bool = False):
    """Log a richer memory diagnostic line if MEMORY_DEBUG is enabled or force=True.

    Args:
        context: Label for where in code we are capturing this
        force: Override MEMORY_DEBUG gate
    """
    if not MEMORY_DEBUG and not force:
        return
    rss_mb = get_memory_usage()
    logger.info(
        "[MEMORY CHECKPOINT] %s | rss=%.1fMB details=%s",
        context,
        rss_mb,
        _collect_detailed_stats(),
    )
    # Record any newly crossed milestone thresholds.
    _maybe_log_milestone(rss_mb, context)
    # Above 380MB (pre-critical) also dump the biggest allocation sites.
    if ENABLE_TRACEMALLOC and rss_mb > 380:
        log_top_tracemalloc(f"high_mem_{context}")
def start_tracemalloc(nframes: int = 25):
    """Start tracemalloc if enabled via environment flag (idempotent)."""
    global _tracemalloc_started
    # Guard: feature disabled, or already started earlier in this process.
    if not ENABLE_TRACEMALLOC or _tracemalloc_started:
        return
    try:
        tracemalloc.start(nframes)
    except Exception as e:  # pragma: no cover
        logger.warning(f"Failed to start tracemalloc: {e}")
    else:
        _tracemalloc_started = True
        logger.info("tracemalloc started (nframes=%d)", nframes)
def log_top_tracemalloc(label: str, limit: int = 10):
    """Log top memory allocation traces if tracemalloc is running."""
    if not tracemalloc.is_tracing():
        return
    try:
        ranked = tracemalloc.take_snapshot().statistics("lineno")
        logger.info("[TRACEMALLOC] Top %d allocations (%s)", limit, label)
        for entry in ranked[:limit]:
            logger.info("[TRACEMALLOC] %s", entry)
    except Exception as e:  # pragma: no cover
        logger.debug(f"Failed logging tracemalloc stats: {e}")
def memory_summary(include_tracemalloc: bool = True) -> Dict[str, Any]:
    """Return a dictionary summary of current memory diagnostics."""
    summary: Dict[str, Any] = {
        "rss_mb": get_memory_usage(),
        # Which milestone thresholds have been crossed so far this run.
        "milestones_crossed": sorted(_crossed_thresholds),
    }
    summary.update(_collect_detailed_stats())
    if include_tracemalloc and tracemalloc.is_tracing():
        try:
            traced_now, traced_peak = tracemalloc.get_traced_memory()
            summary["tracemalloc_current_mb"] = traced_now / 1024 / 1024
            summary["tracemalloc_peak_mb"] = traced_peak / 1024 / 1024
        except Exception:
            pass
    return summary
def start_periodic_memory_logger(interval_seconds: int = 60):
    """Start a daemon thread that logs a memory checkpoint every interval_seconds."""
    global _periodic_thread_started, _periodic_thread
    # Idempotent: only one background logger per process.
    if _periodic_thread_started:
        return

    def _loop():
        logger.info(
            "Periodic memory logger started (interval=%ds, debug=%s, tracemalloc=%s)",
            interval_seconds,
            MEMORY_DEBUG,
            tracemalloc.is_tracing(),
        )
        while True:
            try:
                log_memory_checkpoint("periodic", force=True)
            except Exception:  # pragma: no cover
                logger.debug("Periodic memory logger iteration failed", exc_info=True)
            time.sleep(interval_seconds)

    _periodic_thread = threading.Thread(
        target=_loop, name="PeriodicMemoryLogger", daemon=True
    )
    _periodic_thread.start()
    _periodic_thread_started = True
    logger.info("Periodic memory logger thread started")
R = TypeVar("R")


def memory_monitor(func: Callable[..., R]) -> Callable[..., R]:
    """Decorator that logs memory usage before/after each call to *func*.

    The wrapped function's arguments and return value pass through unchanged.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> R:
        memory_before = get_memory_usage()
        result = func(*args, **kwargs)
        memory_after = get_memory_usage()
        memory_diff = memory_after - memory_before
        # Use the :+ format spec so decreases render as "-x.xMB" instead of
        # "+-x.xMB"; this matches the MemoryManager exit log in this module.
        logger.info(
            f"Memory change in {func.__name__}: "
            f"{memory_before:.1f}MB -> {memory_after:.1f}MB "
            f"({memory_diff:+.1f}MB)"
        )
        return result

    return cast(Callable[..., R], wrapper)
def force_garbage_collection():
    """Run a full GC pass and log how much memory it released."""
    before_mb = get_memory_usage()
    collected = gc.collect()
    freed_mb = before_mb - get_memory_usage()
    logger.info(
        f"Garbage collection: freed {freed_mb:.1f}MB, collected {collected} objects"
    )
def check_memory_threshold(threshold_mb: float = 400) -> bool:
    """
    Check if memory usage exceeds threshold.

    Args:
        threshold_mb: Memory threshold in MB (default 400MB for 512MB limit)

    Returns:
        True if memory usage is above threshold
    """
    current_memory = get_memory_usage()
    # Fast path: under the limit, nothing to report.
    if current_memory <= threshold_mb:
        return False
    logger.warning(
        f"Memory usage {current_memory:.1f}MB exceeds threshold {threshold_mb}MB"
    )
    return True
def clean_memory(context: str = ""):
    """
    Clean memory and force garbage collection with context logging.

    Args:
        context: Description of when/why cleanup is happening
    """
    before_mb = get_memory_usage()
    collected = gc.collect()
    after_mb = get_memory_usage()
    freed_mb = before_mb - after_mb
    # Build the message first so there is a single logging call site.
    if context:
        message = (
            f"Memory cleanup ({context}): "
            f"{before_mb:.1f}MB -> {after_mb:.1f}MB "
            f"(freed {freed_mb:.1f}MB, collected {collected} objects)"
        )
    else:
        message = f"Memory cleanup: freed {freed_mb:.1f}MB, collected {collected} objects"
    logger.info(message)
def optimize_memory():
    """
    Perform memory optimization operations.

    Forces garbage collection and trims the embedding model cache (if the
    service and its cache exist), keeping the most recently added model.
    All cache handling is best-effort and only ever logged at debug level.
    """
    logger.info("Performing memory optimization...")
    # Force garbage collection
    force_garbage_collection()
    # Clear any model caches if they exist. getattr with a default replaces
    # the previous hasattr/getattr pair (and the stray standalone
    # "# type: ignore[attr-defined]" line, which had no effect there).
    try:
        from src.embedding.embedding_service import EmbeddingService

        cache = getattr(EmbeddingService, "_model_cache", None)
        if cache is not None:
            try:
                cache_size = len(cache)
                # Keep at least one model cached
                if cache_size > 1:
                    for key in list(cache.keys())[:-1]:
                        del cache[key]
                    logger.info(
                        "Cleared %d cached models, kept 1",
                        cache_size - 1,
                    )
            except Exception as e:  # pragma: no cover
                logger.debug("Failed clearing model cache: %s", e)
    except Exception as e:
        logger.debug("Could not clear model cache: %s", e)
class MemoryManager:
    """Context manager for memory-intensive operations.

    Logs memory on entry and exit, triggers optimization when entering
    already above ``threshold_mb``, and runs cleanup after large growth.
    """

    def __init__(self, operation_name: str = "operation", threshold_mb: float = 400):
        # Label used in all log lines for this operation.
        self.operation_name = operation_name
        # MB level considered "high" when entering the context.
        self.threshold_mb = threshold_mb
        self.start_memory: Optional[float] = None

    def __enter__(self):
        self.start_memory = get_memory_usage()
        logger.info(f"Starting {self.operation_name} (Memory: {self.start_memory:.1f}MB)")
        # Check if we're already near the threshold
        if self.start_memory > self.threshold_mb:
            logger.warning("Starting operation with high memory usage")
            optimize_memory()
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        end_memory = get_memory_usage()
        # Normalize once: the old code defended the diff with "or 0" but
        # still formatted self.start_memory directly, which raised
        # TypeError when __exit__ ran without a prior __enter__.
        start_memory = self.start_memory if self.start_memory is not None else 0.0
        memory_diff = end_memory - start_memory
        logger.info(
            f"Completed {self.operation_name} "
            f"(Memory: {start_memory:.1f}MB -> {end_memory:.1f}MB, "
            f"Change: {memory_diff:+.1f}MB)"
        )
        # If memory usage increased significantly, trigger cleanup
        if memory_diff > 50:  # More than 50MB increase
            logger.info("Large memory increase detected, running cleanup")
            force_garbage_collection()
            # Capture a post-cleanup checkpoint if deep debugging enabled
            log_memory_checkpoint(f"post_cleanup_{self.operation_name}")
# ---------- Milestone & force-clean helpers ---------- #
def _maybe_log_milestone(current_mb: float, context: str):
    """Internal: log when crossing defined memory thresholds."""
    for threshold in MEMORY_THRESHOLDS:
        # Skip thresholds not yet reached or already reported once.
        if current_mb < threshold or threshold in _crossed_thresholds:
            continue
        _crossed_thresholds.add(threshold)
        logger.warning(
            "[MEMORY MILESTONE] %.1fMB crossed threshold %dMB (context=%s)",
            current_mb,
            threshold,
            context,
        )
        # Provide immediate snapshot & optionally top allocations
        details = memory_summary(include_tracemalloc=True)
        logger.info("[MEMORY SNAPSHOT @%dMB] summary=%s", threshold, details)
        if ENABLE_TRACEMALLOC and tracemalloc.is_tracing():
            log_top_tracemalloc(f"milestone_{threshold}MB")
def force_clean_and_report(label: str = "manual") -> Dict[str, Any]:
    """Run GC plus cache optimization, then log and return a memory summary."""
    logger.info("Force clean invoked (%s)", label)
    # Full cleanup pass: collector first, then cache trimming.
    force_garbage_collection()
    optimize_memory()
    post_clean = memory_summary(include_tracemalloc=True)
    logger.info("Post-clean memory summary (%s): %s", label, post_clean)
    return post_clean
|