"""
Memory monitoring and management utilities for production deployment.
"""

import gc
import logging
import os
import threading
import time
import tracemalloc
from functools import wraps
from typing import Any, Callable, Dict, Optional, TypeVar, cast

logger = logging.getLogger(__name__)

# Environment flag to enable deeper / more frequent memory diagnostics
MEMORY_DEBUG = os.getenv("MEMORY_DEBUG", "0") not in (None, "0", "false", "False")
ENABLE_TRACEMALLOC = os.getenv("ENABLE_TRACEMALLOC", "0") not in (
    None,
    "0",
    "false",
    "False",
)
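
# Both flags are read once at import time; they would typically be enabled via the
# environment before the process starts (illustrative command, not part of this module):
#
#     MEMORY_DEBUG=1 ENABLE_TRACEMALLOC=1 python -m myapp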

# Memory milestone thresholds (MB) which trigger enhanced logging once per run
MEMORY_THRESHOLDS = [300, 400, 450, 500]
_crossed_thresholds: "set[int]" = set()  # type: ignore[type-arg]

_tracemalloc_started = False
_periodic_thread_started = False
_periodic_thread: Optional[threading.Thread] = None


def get_memory_usage() -> float:
    """
    Get current memory usage in MB.
    Falls back to basic approach if psutil is not available.
    """
    try:
        import psutil

        return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    except ImportError:
        # Fallback: use tracemalloc (only meaningful if tracing has been started)
        try:
            current, _peak = tracemalloc.get_traced_memory()
            return current / 1024 / 1024
        except Exception:
            return 0.0


def log_memory_usage(context: str = "") -> float:
    """Log current memory usage with context and return the memory value."""
    memory_mb = get_memory_usage()
    if context:
        logger.info(f"Memory usage ({context}): {memory_mb:.1f}MB")
    else:
        logger.info(f"Memory usage: {memory_mb:.1f}MB")
    return memory_mb


def _collect_detailed_stats() -> Dict[str, Any]:
    """Collect additional (lightweight) diagnostics; guarded by MEMORY_DEBUG."""
    stats: Dict[str, Any] = {}
    try:
        import psutil  # type: ignore

        p = psutil.Process(os.getpid())
        with p.oneshot():
            mem = p.memory_info()
            stats["rss_mb"] = mem.rss / 1024 / 1024
            stats["vms_mb"] = mem.vms / 1024 / 1024
            stats["num_threads"] = p.num_threads()
            stats["open_files"] = len(p.open_files()) if hasattr(p, "open_files") else None
    except Exception:
        pass
    # tracemalloc snapshot (only if already tracing to avoid overhead)
    if tracemalloc.is_tracing():
        try:
            current, peak = tracemalloc.get_traced_memory()
            stats["tracemalloc_current_mb"] = current / 1024 / 1024
            stats["tracemalloc_peak_mb"] = peak / 1024 / 1024
        except Exception:
            pass
    # GC counts are cheap
    try:
        stats["gc_counts"] = gc.get_count()
    except Exception:
        pass
    return stats


def log_memory_checkpoint(context: str, force: bool = False):
    """Log a richer memory diagnostic line if MEMORY_DEBUG is enabled or force=True.

    Args:
        context: Label for where in code we are capturing this
        force: Override MEMORY_DEBUG gate
    """
    if not (MEMORY_DEBUG or force):
        return
    base = get_memory_usage()
    stats = _collect_detailed_stats()
    logger.info(
        "[MEMORY CHECKPOINT] %s | rss=%.1fMB details=%s",
        context,
        base,
        stats,
    )

    # Automatic milestone snapshot logging
    _maybe_log_milestone(base, context)

    # If tracemalloc enabled and memory above 380MB (pre-crit), log top allocations
    if ENABLE_TRACEMALLOC and base > 380:
        log_top_tracemalloc(f"high_mem_{context}")


def start_tracemalloc(nframes: int = 25):
    """Start tracemalloc if enabled via environment flag."""
    global _tracemalloc_started
    if ENABLE_TRACEMALLOC and not _tracemalloc_started:
        try:
            tracemalloc.start(nframes)
            _tracemalloc_started = True
            logger.info("tracemalloc started (nframes=%d)", nframes)
        except Exception as e:  # pragma: no cover
            logger.warning(f"Failed to start tracemalloc: {e}")


def log_top_tracemalloc(label: str, limit: int = 10):
    """Log top memory allocation traces if tracemalloc is running."""
    if not tracemalloc.is_tracing():
        return
    try:
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics("lineno")
        logger.info("[TRACEMALLOC] Top %d allocations (%s)", limit, label)
        for stat in top_stats[:limit]:
            logger.info("[TRACEMALLOC] %s", stat)
    except Exception as e:  # pragma: no cover
        logger.debug(f"Failed logging tracemalloc stats: {e}")


def memory_summary(include_tracemalloc: bool = True) -> Dict[str, Any]:
    """Return a dictionary summary of current memory diagnostics."""
    summary: Dict[str, Any] = {}
    summary["rss_mb"] = get_memory_usage()
    # Include which milestones crossed
    summary["milestones_crossed"] = sorted(list(_crossed_thresholds))
    stats = _collect_detailed_stats()
    summary.update(stats)
    if include_tracemalloc and tracemalloc.is_tracing():
        try:
            current, peak = tracemalloc.get_traced_memory()
            summary["tracemalloc_current_mb"] = current / 1024 / 1024
            summary["tracemalloc_peak_mb"] = peak / 1024 / 1024
        except Exception:
            pass
    return summary


def start_periodic_memory_logger(interval_seconds: int = 60):
    """Start a background thread that logs memory every interval_seconds."""
    global _periodic_thread_started, _periodic_thread
    if _periodic_thread_started:
        return

    def _runner():
        logger.info(
            ("Periodic memory logger started (interval=%ds, " "debug=%s, tracemalloc=%s)"),
            interval_seconds,
            MEMORY_DEBUG,
            tracemalloc.is_tracing(),
        )
        while True:
            try:
                log_memory_checkpoint("periodic", force=True)
            except Exception:  # pragma: no cover
                logger.debug("Periodic memory logger iteration failed", exc_info=True)
            time.sleep(interval_seconds)

    _periodic_thread = threading.Thread(target=_runner, name="PeriodicMemoryLogger", daemon=True)
    _periodic_thread.start()
    _periodic_thread_started = True
    logger.info("Periodic memory logger thread started")


R = TypeVar("R")


def memory_monitor(func: Callable[..., R]) -> Callable[..., R]:
    """Decorator to monitor memory usage of functions."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> R:
        memory_before = get_memory_usage()
        result = func(*args, **kwargs)
        memory_after = get_memory_usage()
        memory_diff = memory_after - memory_before

        logger.info(
            f"Memory change in {func.__name__}: "
            f"{memory_before:.1f}MB -> {memory_after:.1f}MB "
            f"(+{memory_diff:.1f}MB)"
        )
        return result

    return cast(Callable[..., R], wrapper)
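

# Example (illustrative): decorating a memory-heavy helper so its before/after RSS is
# logged on every call; `build_index` is a hypothetical function, not part of this module.
#
#     @memory_monitor
#     def build_index(documents):
#         return [d.lower() for d in documents]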


def force_garbage_collection():
    """Force garbage collection and log memory freed."""
    memory_before = get_memory_usage()

    # Force garbage collection
    collected = gc.collect()

    memory_after = get_memory_usage()
    memory_freed = memory_before - memory_after

    logger.info(f"Garbage collection: freed {memory_freed:.1f}MB, " f"collected {collected} objects")


def check_memory_threshold(threshold_mb: float = 400) -> bool:
    """
    Check if memory usage exceeds threshold.

    Args:
        threshold_mb: Memory threshold in MB (default 400MB for 512MB limit)

    Returns:
        True if memory usage is above threshold
    """
    current_memory = get_memory_usage()
    if current_memory > threshold_mb:
        logger.warning(f"Memory usage {current_memory:.1f}MB exceeds threshold {threshold_mb}MB")
        return True
    return False
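

# Sketch of a guard a caller might place before a large batch; the 400MB default matches
# the module's assumption of a ~512MB container limit (see the docstring above):
#
#     if check_memory_threshold(400):
#         clean_memory("pre-batch guard")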


def clean_memory(context: str = ""):
    """
    Clean memory and force garbage collection with context logging.

    Args:
        context: Description of when/why cleanup is happening
    """
    memory_before = get_memory_usage()

    # Force garbage collection
    collected = gc.collect()

    memory_after = get_memory_usage()
    memory_freed = memory_before - memory_after

    if context:
        logger.info(
            f"Memory cleanup ({context}): "
            f"{memory_before:.1f}MB -> {memory_after:.1f}MB "
            f"(freed {memory_freed:.1f}MB, collected {collected} objects)"
        )
    else:
        logger.info(f"Memory cleanup: freed {memory_freed:.1f}MB, collected {collected} objects")


def optimize_memory():
    """
    Perform memory optimization operations.
    Called when memory usage gets high.
    """
    logger.info("Performing memory optimization...")

    # Force garbage collection
    force_garbage_collection()

    # Clear any model caches if they exist
    try:
        from src.embedding.embedding_service import EmbeddingService

        if hasattr(EmbeddingService, "_model_cache"):
            cache_attr = getattr(EmbeddingService, "_model_cache")
            try:
                cache_size = len(cache_attr)
                # Keep at least one model cached
                if cache_size > 1:
                    keys = list(cache_attr.keys())
                    for key in keys[:-1]:
                        del cache_attr[key]
                    logger.info(
                        "Cleared %d cached models, kept 1",
                        cache_size - 1,
                    )
            except Exception as e:  # pragma: no cover
                logger.debug("Failed clearing model cache: %s", e)
    except Exception as e:
        logger.debug("Could not clear model cache: %s", e)


class MemoryManager:
    """Context manager for memory-intensive operations."""

    def __init__(self, operation_name: str = "operation", threshold_mb: float = 400):
        self.operation_name = operation_name
        self.threshold_mb = threshold_mb
        self.start_memory: Optional[float] = None

    def __enter__(self):
        self.start_memory = get_memory_usage()
        logger.info(f"Starting {self.operation_name} (Memory: {self.start_memory:.1f}MB)")

        # Check if we're already near the threshold
        if self.start_memory > self.threshold_mb:
            logger.warning("Starting operation with high memory usage")
            optimize_memory()

        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        end_memory = get_memory_usage()
        memory_diff = end_memory - (self.start_memory or 0)

        logger.info(
            f"Completed {self.operation_name} "
            f"(Memory: {self.start_memory:.1f}MB -> {end_memory:.1f}MB, "
            f"Change: {memory_diff:+.1f}MB)"
        )

        # If memory usage increased significantly, trigger cleanup
        if memory_diff > 50:  # More than 50MB increase
            logger.info("Large memory increase detected, running cleanup")
            force_garbage_collection()

            # Capture a post-cleanup checkpoint if deep debugging enabled
            log_memory_checkpoint(f"post_cleanup_{self.operation_name}")


# ---------- Milestone & force-clean helpers ---------- #


def _maybe_log_milestone(current_mb: float, context: str):
    """Internal: log when crossing defined memory thresholds."""
    for threshold in MEMORY_THRESHOLDS:
        if current_mb >= threshold and threshold not in _crossed_thresholds:
            _crossed_thresholds.add(threshold)
            logger.warning(
                "[MEMORY MILESTONE] %.1fMB crossed threshold %dMB " "(context=%s)",
                current_mb,
                threshold,
                context,
            )
            # Provide immediate snapshot & optionally top allocations
            details = memory_summary(include_tracemalloc=True)
            logger.info("[MEMORY SNAPSHOT @%dMB] summary=%s", threshold, details)
            if ENABLE_TRACEMALLOC and tracemalloc.is_tracing():
                log_top_tracemalloc(f"milestone_{threshold}MB")


def force_clean_and_report(label: str = "manual") -> Dict[str, Any]:
    """Force GC + optimization and return post-clean summary."""
    logger.info("Force clean invoked (%s)", label)
    force_garbage_collection()
    optimize_memory()
    summary = memory_summary(include_tracemalloc=True)
    logger.info("Post-clean memory summary (%s): %s", label, summary)
    return summary
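

if __name__ == "__main__":
    # Minimal manual smoke test, assuming only the standard library (psutil is optional);
    # illustrative only, not part of the production API.
    logging.basicConfig(level=logging.INFO)
    start_tracemalloc()
    log_memory_usage("demo start")

    with MemoryManager("demo allocation", threshold_mb=400):
        buffers = [bytearray(1024) for _ in range(10_000)]  # roughly 10MB of throwaway data
        log_memory_checkpoint("after allocation", force=True)
        del buffers

    force_clean_and_report("demo")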