Spaces:
Running on Zero
Running on Zero
File size: 9,785 Bytes
"""Token-level N-gram model for context mixing.
Maintains order-1 through order-N context tables with interpolated
backoff smoothing. Used as a fast, lightweight predictor alongside
the LLM in an ensemble. All operations are deterministic for
lossless codec symmetry.
Uses flat numpy arrays for inner storage instead of nested Python
dicts. This eliminates millions of small dict objects (~3.6 GB →
~1 GB per worker) and replaces the O(K) Python iteration loop in
predict() with a single numpy fancy-indexing call that runs at
C level, drastically reducing GIL hold time with 8+ threads.
"""
import numpy as np
def _context_hash(context_tokens, order):
"""Deterministic 64-bit hash of the last *order* tokens.
Replaces tuple(context_tokens[-order:]) as dict key, eliminating
per-token tuple allocations and reducing GC pressure.
"""
h = 0
end = len(context_tokens)
for i in range(end - order, end):
h = (h * 49157 + context_tokens[i]) & 0xFFFFFFFFFFFFFFFF
return h
class NgramModel:
    """Interpolated N-gram model operating on token IDs.

    Uses iterative interpolation: starting from a Laplace-smoothed
    unigram distribution, each higher order is blended in with weight
    ``total / (total + ESCAPE)``, where ``total`` is the number of
    observations recorded for the current context.  Unseen contexts
    are skipped, so the estimate falls back smoothly to lower orders.

    The model updates online after each observed token, so it adapts
    to the specific document being compressed.

    Storage layout for each order ``1..max_order`` (index 0 of every
    per-order list is an unused ``None`` placeholder):

    * ``_slot_map[order]``: dict mapping context hash -> slot index.
      Dict insertion order provides deterministic FIFO eviction; seeing
      a context again does not move it.
    * ``_inner_ids[order][slot]`` / ``_inner_counts[order][slot]``:
      up to ``MAX_INNER_ENTRIES`` (token_id, count) pairs, kept in
      insertion order.
    * ``_inner_sizes[order][slot]``: number of pairs currently in use.
    * ``_ctx_totals[order][slot]``: observations credited to the
      context; kept equal to the sum of the slot's counts.
    """

    # Smoothing constant for interpolation weights.
    ESCAPE = 5
    # Maximum context entries per order.
    MAX_TABLE_ENTRIES = 500_000
    # Maximum unique continuations per context.
    MAX_INNER_ENTRIES = 64

    def __init__(self, max_order: int = 4, vocab_size: int = 49152):
        """Allocate empty tables.

        Args:
            max_order: Highest context length to model.
            vocab_size: Number of distinct token IDs; valid IDs are
                ``0 .. vocab_size - 1``.
        """
        self.max_order = max_order
        self.vocab_size = vocab_size

        # Order-0 (unigram) counts: dense array for fast vector ops.
        self.unigram_counts = np.zeros(vocab_size, dtype=np.float64)
        self.total_unigram = 0

        # Order 1..N: context_hash -> slot_index.
        # Python dict preserves insertion order for FIFO eviction.
        self._slot_map: list = [None] + [dict() for _ in range(max_order)]

        # Flat inner storage per order.  Each context maps to a "slot"
        # holding up to MAX_INNER_ENTRIES (token_id, count) pairs.
        # Only the first _inner_sizes[order][slot] columns of a row are
        # meaningful, which is why np.empty is sufficient here.
        table_shape = (self.MAX_TABLE_ENTRIES, self.MAX_INNER_ENTRIES)
        self._inner_ids: list = [None] + [
            np.empty(table_shape, dtype=np.int32)
            for _ in range(max_order)
        ]
        self._inner_counts: list = [None] + [
            np.empty(table_shape, dtype=np.int32)
            for _ in range(max_order)
        ]
        self._inner_sizes: list = [None] + [
            np.zeros(self.MAX_TABLE_ENTRIES, dtype=np.int16)
            for _ in range(max_order)
        ]
        self._ctx_totals: list = [None] + [
            np.zeros(self.MAX_TABLE_ENTRIES, dtype=np.int32)
            for _ in range(max_order)
        ]

        # Slot allocation: sequential counter + free list for recycling.
        self._next_slot = [0] * (max_order + 1)
        self._free_slots: list = [None] + [[] for _ in range(max_order)]

        # Pre-allocated buffers reused by every predict() call.
        self._buf = np.zeros(vocab_size, dtype=np.float64)
        self._probs = np.zeros(vocab_size, dtype=np.float64)

    def reset(self) -> None:
        """Reset all counts. Call when starting a new sequence."""
        self.unigram_counts[:] = 0
        self.total_unigram = 0
        self._slot_map = [None] + [dict() for _ in range(self.max_order)]
        self._next_slot = [0] * (self.max_order + 1)
        self._free_slots = [None] + [[] for _ in range(self.max_order)]
        # The id/count tables need no clearing: a slot's contents are
        # only read up to its recorded size, which is zeroed here.
        for order in range(1, self.max_order + 1):
            self._inner_sizes[order][:] = 0
            self._ctx_totals[order][:] = 0

    def predict(self, context_tokens: list[int]) -> np.ndarray:
        """Predict the next-token distribution given a context.

        The per-context continuation counts are scattered into a dense
        buffer with one NumPy fancy-indexing assignment rather than a
        Python loop over the entries.

        Args:
            context_tokens: List of preceding token IDs.

        Returns:
            Array of shape ``(vocab_size,)`` with probabilities summing
            to ~1.  This is the model's internal buffer: it is
            overwritten by the next ``predict()`` call, so copy it if
            the values must outlive that call.
        """
        # Start with unigram (Laplace-smoothed).
        probs = self._probs
        np.add(self.unigram_counts, 1.0, out=probs)
        probs /= (self.total_unigram + self.vocab_size)

        buf = self._buf
        for order in range(1, self.max_order + 1):
            if len(context_tokens) < order:
                break
            ctx = _context_hash(context_tokens, order)
            slot = self._slot_map[order].get(ctx)
            if slot is None:
                continue
            total = int(self._ctx_totals[order][slot])
            if total == 0:
                continue
            lam = total / (total + self.ESCAPE)

            # Build this order's distribution in the scratch buffer.
            buf[:] = 0
            size = int(self._inner_sizes[order][slot])
            ids = self._inner_ids[order][slot, :size]
            cts = self._inner_counts[order][slot, :size]
            buf[ids] = cts
            buf /= buf.sum()

            # Blend: probs = lam * order_k + (1 - lam) * probs.
            # The operation order is kept fixed so results stay
            # bit-for-bit reproducible between encoder and decoder.
            probs *= (1.0 - lam)
            buf *= lam
            probs += buf
        return probs

    def _alloc_slot(self, order: int) -> int:
        """Get a free slot index, recycling evicted slots first."""
        free = self._free_slots[order]
        if free:
            return free.pop()
        slot = self._next_slot[order]
        self._next_slot[order] += 1
        return slot

    def update(self, context_tokens: list[int], actual_token: int) -> None:
        """Update counts after observing a token.

        Must be called identically during compression and decompression
        to maintain codec symmetry.

        Args:
            context_tokens: Context that preceded the token.
            actual_token: The token that was actually observed.

        Raises:
            IndexError: If ``actual_token`` is outside
                ``0 .. vocab_size - 1``.  The model is left unchanged.
        """
        # A negative ID would otherwise wrap around via NumPy negative
        # indexing and silently credit a different token.
        if not 0 <= actual_token < self.vocab_size:
            raise IndexError(
                f"token id {actual_token} is out of range for "
                f"vocab_size {self.vocab_size}"
            )

        # Update unigram.
        self.unigram_counts[actual_token] += 1
        self.total_unigram += 1

        # Update higher orders.
        for order in range(1, self.max_order + 1):
            if len(context_tokens) < order:
                break
            ctx = _context_hash(context_tokens, order)
            slot_map = self._slot_map[order]
            slot = slot_map.get(ctx)

            if slot is None:
                # New context.  If the table is full, evict the oldest
                # context first and make its slot available for reuse.
                if len(slot_map) >= self.MAX_TABLE_ENTRIES:
                    evict_ctx = next(iter(slot_map))
                    evict_slot = slot_map.pop(evict_ctx)
                    self._free_slots[order].append(evict_slot)
                slot = self._alloc_slot(order)
                slot_map[ctx] = slot
                self._inner_ids[order][slot, 0] = actual_token
                self._inner_counts[order][slot, 0] = 1
                self._inner_sizes[order][slot] = 1
                self._ctx_totals[order][slot] = 1
                continue

            size = int(self._inner_sizes[order][slot])
            ids = self._inner_ids[order][slot]
            counts = self._inner_counts[order][slot]

            # Search for actual_token among the entries in use.
            mask = ids[:size] == actual_token
            if mask.any():
                # Token exists: increment its count.
                idx = int(np.argmax(mask))
                counts[idx] += 1
                self._ctx_totals[order][slot] += 1
            elif size < self.MAX_INNER_ENTRIES:
                # New token, space available: append.
                ids[size] = actual_token
                counts[size] = 1
                self._inner_sizes[order][slot] = size + 1
                self._ctx_totals[order][slot] += 1
            else:
                # Slot is full.  Equivalent to "append the new entry
                # with count 1, then drop the first minimum-count entry
                # in insertion order".  Either way the slot total is
                # unchanged.
                min_count = int(counts[:size].min())
                if min_count == 1:
                    # Drop the oldest count-1 entry (np.argmin returns
                    # the first minimum), shift the later entries left
                    # to preserve insertion order, and append the new
                    # token at the end.
                    min_idx = int(np.argmin(counts[:size]))
                    if min_idx < size - 1:
                        ids[min_idx:size - 1] = ids[min_idx + 1:size]
                        counts[min_idx:size - 1] = counts[min_idx + 1:size]
                    ids[size - 1] = actual_token
                    counts[size - 1] = 1
                # else: every existing count exceeds 1, so the new
                # entry would be the sole minimum and be dropped at
                # once -- nothing changes.
|