# File size: 9,785 Bytes
# 5b8133e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Token-level N-gram model for context mixing.



Maintains order-1 through order-N context tables with interpolated
backoff smoothing. Used as a fast, lightweight predictor alongside
the LLM in an ensemble. All operations are deterministic for
lossless codec symmetry.

Uses flat numpy arrays for inner storage instead of nested Python
dicts. This eliminates millions of small dict objects (~3.6 GB →
~1 GB per worker) and replaces the O(K) Python iteration loop in
predict() with a single numpy fancy-indexing call that runs at
C level, drastically reducing GIL hold time with 8+ threads.

"""

import numpy as np


def _context_hash(context_tokens, order):
    """Deterministic 64-bit hash of the last *order* tokens.

    Replaces tuple(context_tokens[-order:]) as dict key, eliminating
    per-token tuple allocations and reducing GC pressure.

    The hash is a polynomial in base 49157 reduced modulo 2**64. Since
    49157**4 < 2**64, contexts of up to four tokens whose IDs are all in
    [0, 49157) map to distinct values; longer contexts or larger IDs can
    collide.

    Args:
        context_tokens: Indexable sequence of integer token IDs (Python
            ints or NumPy integers).
        order: Number of trailing tokens to hash. If it exceeds
            len(context_tokens) the whole sequence is hashed, mirroring
            the slice it replaces; an order <= 0 hashes nothing.

    Returns:
        A Python int in [0, 2**64); 0 when no tokens are hashed.
    """
    end = len(context_tokens)
    # Clamp the start: a negative start index would wrap around and hash
    # tokens from the wrong end (or raise IndexError) instead of behaving
    # like context_tokens[-order:].
    start = max(end - order, 0)
    h = 0
    for i in range(start, end):
        # int() keeps the arithmetic in arbitrary-precision Python ints even
        # when the tokens are fixed-width NumPy scalars, which cannot be
        # masked with a 64-bit all-ones constant.
        h = (h * 49157 + int(context_tokens[i])) & 0xFFFFFFFFFFFFFFFF
    return h


class NgramModel:
    """Interpolated N-gram model operating on token IDs.

    Uses iterative interpolation: higher-order models are progressively
    blended with lower orders, weighted by context frequency. Unseen
    contexts fall back smoothly to lower orders down to unigram.

    The model updates online after each observed token, so it adapts
    to the specific document being compressed.

    Inner storage uses flat numpy arrays indexed by slot number.
    The outer dict (context_hash → slot) preserves insertion order
    for deterministic FIFO eviction.
    """

    # Smoothing constant for interpolation weights.
    ESCAPE = 5

    # Maximum context entries per order.
    MAX_TABLE_ENTRIES = 500_000

    # Maximum unique continuations per context.
    MAX_INNER_ENTRIES = 64

    def __init__(self, max_order: int = 4, vocab_size: int = 49152):
        """Allocate empty count tables.

        Args:
            max_order: Highest context length tracked; tables exist for
                orders 1..max_order (index 0 of every per-order list is
                an unused None placeholder).
            vocab_size: Number of token IDs; length of the unigram table
                and of every distribution returned by predict().
        """
        self.max_order = max_order
        self.vocab_size = vocab_size
        rows = self.MAX_TABLE_ENTRIES
        cols = self.MAX_INNER_ENTRIES

        # Order-0 (unigram) counts: dense array for fast vector ops.
        self.unigram_counts = np.zeros(vocab_size, dtype=np.float64)
        self.total_unigram = 0

        # Order 1..N: context_hash → slot_index.
        # Python dict preserves insertion order for FIFO eviction.
        self._slot_map: list = [None] + [{} for _ in range(max_order)]

        # Flat inner storage per order.  Each context maps to a "slot"
        # (one row) holding up to MAX_INNER_ENTRIES (token_id, count)
        # pairs.  Entries within a slot are kept in insertion order so
        # that argmin tie-breaking matches the old dict-based behavior.
        # Rows are left uninitialised (np.empty); _inner_sizes says how
        # many leading entries of a row are valid.
        self._inner_ids: list = [None] + [
            np.empty((rows, cols), dtype=np.int32) for _ in range(max_order)
        ]
        self._inner_counts: list = [None] + [
            np.empty((rows, cols), dtype=np.int32) for _ in range(max_order)
        ]
        self._inner_sizes: list = [None] + [
            np.zeros(rows, dtype=np.int16) for _ in range(max_order)
        ]
        # Sum of the valid counts in each slot.
        self._ctx_totals: list = [None] + [
            np.zeros(rows, dtype=np.int32) for _ in range(max_order)
        ]

        # Slot allocation: sequential counter + free list for recycling.
        self._next_slot = [0] * (max_order + 1)
        self._free_slots: list = [None] + [[] for _ in range(max_order)]

        # Pre-allocated output buffer reused by every predict() call.
        self._probs = np.zeros(vocab_size, dtype=np.float64)
        # Dense scratch vector.  predict() no longer needs it; it is still
        # allocated in case other code in the project reads the attribute.
        self._buf = np.zeros(vocab_size, dtype=np.float64)

    def reset(self):
        """Reset all counts. Call when starting a new sequence."""
        self.unigram_counts[:] = 0
        self.total_unigram = 0
        self._slot_map = [None] + [{} for _ in range(self.max_order)]
        self._next_slot = [0] * (self.max_order + 1)
        self._free_slots = [None] + [[] for _ in range(self.max_order)]
        # The id/count rows need no clearing: a size of 0 marks every
        # entry of a row as invalid, and new contexts overwrite entry 0.
        for order in range(1, self.max_order + 1):
            self._inner_sizes[order][:] = 0
            self._ctx_totals[order][:] = 0

    def predict(self, context_tokens: list[int]) -> np.ndarray:
        """Predict next-token distribution given context.

        Starts from the Laplace-smoothed unigram distribution and, for
        each order k = 1..max_order whose context has been seen, blends
        in that context's empirical distribution:

            probs = lam * order_k + (1 - lam) * probs,
            lam = total / (total + ESCAPE)

        where total is the number of observations stored for the context.
        Only the (at most MAX_INNER_ENTRIES) continuations recorded for
        the context are touched individually; the rest of the vector is
        merely rescaled.

        Args:
            context_tokens: List of preceding token IDs.

        Returns:
            numpy array of shape (vocab_size,) with probabilities summing
            to ~1.  This is the model's internal buffer and is overwritten
            by the next predict() call; copy it if it must be kept.
        """
        # Start with unigram (Laplace-smoothed).
        probs = self._probs
        np.add(self.unigram_counts, 1.0, out=probs)
        probs /= (self.total_unigram + self.vocab_size)

        for order in range(1, self.max_order + 1):
            if len(context_tokens) < order:
                break

            ctx = _context_hash(context_tokens, order)
            slot = self._slot_map[order].get(ctx)
            if slot is None:
                continue

            total = int(self._ctx_totals[order][slot])
            if total == 0:
                continue

            lam = total / (total + self.ESCAPE)
            size = int(self._inner_sizes[order][slot])
            ids = self._inner_ids[order][slot, :size]

            # lam * (count / sum(counts)) for the recorded continuations.
            # Same operations, in the same order, as normalising and
            # scaling a dense vector, so the values are unchanged.
            weights = self._inner_counts[order][slot, :size].astype(np.float64)
            weights /= weights.sum()
            weights *= lam

            # Blend: probs = lam * order_k + (1 - lam) * probs.  Token ids
            # within a slot are unique, so the fancy-indexed += adds each
            # weight exactly once.
            probs *= (1.0 - lam)
            probs[ids] += weights

        return probs

    def _alloc_slot(self, order: int) -> int:
        """Get a free slot index, recycling evicted slots first."""
        free = self._free_slots[order]
        if free:
            return free.pop()
        slot = self._next_slot[order]
        self._next_slot[order] = slot + 1
        return slot

    def _bump_continuation(self, order: int, slot: int, token: int) -> None:
        """Record one more observation of *token* in an existing slot."""
        size = int(self._inner_sizes[order][slot])
        # Row views: writes below go straight into the backing arrays.
        ids = self._inner_ids[order][slot]
        counts = self._inner_counts[order][slot]

        found = ids[:size] == token
        if found.any():
            # Token already recorded: increment its count.
            counts[int(np.argmax(found))] += 1
            self._ctx_totals[order][slot] += 1
            return

        if size < self.MAX_INNER_ENTRIES:
            # New token, space available: append.
            ids[size] = token
            counts[size] = 1
            self._inner_sizes[order][slot] = size + 1
            self._ctx_totals[order][slot] += 1
            return

        # Slot is full.  Simulate the original add-then-evict: the new
        # entry would enter with count 1 and the minimum-count entry
        # would leave.  If the current minimum exceeds 1 the newcomer is
        # the sole minimum and would be evicted at once, so nothing
        # changes.  Otherwise the oldest count-1 entry is replaced, and
        # the context total is unchanged either way.
        oldest_min = int(np.argmin(counts[:size]))  # first minimum = oldest
        if int(counts[oldest_min]) != 1:
            return
        last = size - 1
        if oldest_min < last:
            # Shift left to keep insertion order, so later argmin
            # tie-breaking still matches the dict-based behavior.
            ids[oldest_min:last] = ids[oldest_min + 1:size]
            counts[oldest_min:last] = counts[oldest_min + 1:size]
        ids[last] = token
        counts[last] = 1

    def update(self, context_tokens: list[int], actual_token: int):
        """Update counts after observing a token.

        Must be called identically during compression and decompression
        to maintain codec symmetry.

        Args:
            context_tokens: Context that preceded the token.
            actual_token: The token that was actually observed.
        """
        # Update unigram.
        self.unigram_counts[actual_token] += 1
        self.total_unigram += 1

        # Update higher orders.
        for order in range(1, self.max_order + 1):
            if len(context_tokens) < order:
                break

            ctx = _context_hash(context_tokens, order)
            slot_map = self._slot_map[order]

            slot = slot_map.get(ctx)
            if slot is not None:
                self._bump_continuation(order, slot, actual_token)
                continue

            # New context.  If the table is full, evict the oldest context
            # (first key in insertion order) and recycle its slot.
            # NOTE(review): on CPython, repeatedly deleting the first key
            # may make next(iter(dict)) scan over deleted entries; worth
            # profiling once a table stays full.
            if len(slot_map) >= self.MAX_TABLE_ENTRIES:
                oldest_ctx = next(iter(slot_map))
                self._free_slots[order].append(slot_map.pop(oldest_ctx))

            slot = self._alloc_slot(order)
            slot_map[ctx] = slot
            self._inner_ids[order][slot, 0] = actual_token
            self._inner_counts[order][slot, 0] = 1
            self._inner_sizes[order][slot] = 1
            self._ctx_totals[order][slot] = 1