lorien-danger committed on
Commit
2609847
·
verified ·
1 Parent(s): 956b427

Update index.html

Browse files
Files changed (1) hide show
  1. index.html +62 -58
index.html CHANGED
@@ -77,6 +77,7 @@
77
  // ===== 1) Config =====
78
  const MODEL_ID = "onnx-community/ijepa_vith14_1k"; // <-- I-JEPA ViT-H/14, ImageNet-1k
79
  const EXAMPLE_IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png";
 
80
 
81
  // DOM
82
  const imageLoader = document.getElementById("imageLoader");
@@ -104,6 +105,7 @@
104
  let animationFrameId = null;
105
  let lastMouseEvent = null;
106
  let maxPixels = null;
 
107
 
108
  function updateStatus(text, isLoading=false){
109
  statusText.textContent = text;
@@ -124,9 +126,7 @@
124
 
125
  try{
126
  extractor = await pipeline("image-feature-extraction", MODEL_ID, { device, dtype });
127
- // Try to fetch config-provided patch size if present
128
  patchSize = extractor?.model?.config?.patch_size ?? patchSize;
129
- // Avoid internal resizes — we control canvas dims
130
  if (extractor?.processor?.image_processor) extractor.processor.image_processor.do_resize = false;
131
  updateStatus("Ready. Please select an image.");
132
  }catch(e){
@@ -154,15 +154,9 @@
154
  const res = await fetch(EXAMPLE_IMAGE_URL);
155
  const blob = await res.blob();
156
  loadImageOntoCanvas(URL.createObjectURL(blob));
157
- }catch(e){
158
- console.error(e);
159
- updateStatus("Failed to load example image.");
160
- }
161
- }
162
- function handleImageUpload(e){
163
- const f = e.target.files?.[0];
164
- if (f) loadImageOntoCanvas(URL.createObjectURL(f));
165
  }
 
166
  function handleDragOver(e){ e.preventDefault(); dropZone.classList.add("border-blue-500","bg-gray-800"); }
167
  function handleDragLeave(e){ e.preventDefault(); dropZone.classList.remove("border-blue-500","bg-gray-800"); }
168
  function handleDrop(e){
@@ -180,6 +174,21 @@
180
  function handleSliderInput(e){ imageScale = parseFloat(e.target.value); scaleValue.textContent = `${imageScale.toFixed(2)}x`; }
181
  function handleSliderChange(){ if (currentImageUrl) loadImageOntoCanvas(currentImageUrl); }
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  function loadImageOntoCanvas(url){
184
  currentImageUrl = url;
185
  originalImage = new Image();
@@ -187,63 +196,68 @@
187
  if (!patchSize){ updateStatus("Error: patch size unknown."); return; }
188
  canvasPlaceholder.style.display = "none";
189
  imageCanvas.style.display = "block";
190
-
191
- let newW = originalImage.naturalWidth * imageScale;
192
- let newH = originalImage.naturalHeight * imageScale;
193
-
194
- const px = newW * newH;
195
- if (px > maxPixels){
196
- const r = Math.sqrt(maxPixels / px);
197
- newW *= r; newH *= r;
 
 
 
 
 
198
  }
199
-
200
- const croppedW = Math.floor(newW / patchSize) * patchSize;
201
- const croppedH = Math.floor(newH / patchSize) * patchSize;
202
- if (croppedW < patchSize || croppedH < patchSize){
203
- updateStatus("Scaled image is too small to process.");
204
- imageCanvas.style.display = "none";
205
- canvasPlaceholder.style.display = "block";
206
- canvasPlaceholder.textContent = "Scaled image is too small.";
207
- return;
208
  }
 
 
 
209
 
210
- imageCanvas.width = croppedW;
211
- imageCanvas.height = croppedH;
212
- ctx.drawImage(originalImage, 0, 0, croppedW, croppedH);
213
- await processImage();
214
  setTimeout(() => { canvasContainer.scrollIntoView({ behavior: "smooth", block: "center" }); }, 100);
215
  };
216
  originalImage.onerror = () => { updateStatus("Failed to load the selected image."); canvasPlaceholder.style.display = "block"; imageCanvas.style.display = "none"; };
217
  originalImage.src = url;
218
  }
219
 
220
- async function processImage(){
221
  if (!extractor) return;
222
  updateStatus("Analyzing with I‑JEPA... 🧠", true);
223
  similarityScores = null; lastHoverData = null;
224
  try{
225
  const imageData = await RawImage.fromCanvas(imageCanvas);
226
- // No pooling: we want per-token outputs
227
  const features = await extractor(imageData, { pooling: "none" }); // [1, T, D]
228
 
229
- // Compute how many tokens are patches vs special tokens robustly.
230
  const totalTokens = features.dims?.[1] ?? features.shape?.[1] ?? features.size?.[1];
231
  const nPatches = (imageCanvas.width / patchSize) * (imageCanvas.height / patchSize);
232
  const specialTokens = Math.max(0, totalTokens - nPatches);
233
 
234
- const patchFeatures = features.slice(null, [specialTokens, nPatches]); // [1, nPatches, D]
235
  const normalized = patchFeatures.normalize(2, -1);
236
- const sims = await matmul(normalized, normalized.permute(0,2,1)); // [1, nPatches, nPatches]
237
  similarityScores = (await sims.tolist())[0];
238
 
239
- updateStatus(`Image processed (${imageCanvas.width}×${imageCanvas.height}). Hover to explore. ✨`);
240
  }catch(err){
241
  console.error("I‑JEPA processing error:", err);
242
- updateStatus("An error occurred during image processing.");
243
  }
244
  }
245
 
246
- function handleTouchMove(e){ e.preventDefault(); if (e.touches.length>0) handleMouseMove(e.touches[0]); }
247
  function handleMouseMove(e){ lastMouseEvent = e; if (!animationFrameId) animationFrameId = requestAnimationFrame(drawLoop); }
248
 
249
  function drawLoop(){
@@ -253,7 +267,7 @@
253
  const scaleY = imageCanvas.height / rect.height;
254
  const x = (lastMouseEvent.clientX - rect.left) * scaleX;
255
  const y = (lastMouseEvent.clientY - rect.top) * scaleY;
256
- if (x<0||x>=imageCanvas.width||y<0||y>=imageCanvas.height){ animationFrameId = null; return; }
257
 
258
  const patchesPerRow = imageCanvas.width / patchSize;
259
  const patchX = Math.floor(x / patchSize);
@@ -267,36 +281,26 @@
267
  animationFrameId = null;
268
  }
269
 
270
- const INFERNO_COLORMAP = [
271
- [0.0,[0,0,4]],[0.1,[39,12,69]],[0.2,[84,15,104]],[0.3,[128,31,103]],[0.4,[170,48,88]],
272
- [0.5,[209,70,68]],[0.6,[240,97,47]],[0.7,[253,138,28]],[0.8,[252,185,26]],[0.9,[240,231,56]],[1.0,[252,255,160]]
273
- ];
274
- function getInfernoColor(t){
275
- for (let i=1;i<INFERNO_COLORMAP.length;i++){
276
- const [tp,cp]=INFERNO_COLORMAP[i-1]; const [tc,cc]=INFERNO_COLORMAP[i];
277
- if (t<=tc){ const a=(t-tp)/(tc-tp); const r=cp[0]+a*(cc[0]-cp[0]); const g=cp[1]+a*(cc[1]-cp[1]); const b=cp[2]+a*(cc[2]-cp[2]); return `rgb(${Math.round(r)}, ${Math.round(g)}, ${Math.round(b)})`; }
278
- }
279
- const last=INFERNO_COLORMAP[INFERNO_COLORMAP.length-1][1];
280
- return `rgb(${last.join(",")})`;
281
- }
282
 
283
  function drawHighlights(queryIndex, allPatches){
284
  const patchesPerRow = imageCanvas.width / patchSize;
285
  if (isOverlayMode){
286
- ctx.drawImage(originalImage, 0, 0, imageCanvas.width, imageCanvas.height);
287
  ctx.fillStyle = "rgba(0,0,0,0.6)"; ctx.fillRect(0,0,imageCanvas.width,imageCanvas.height);
288
  } else {
289
  ctx.fillStyle = getInfernoColor(0); ctx.fillRect(0,0,imageCanvas.width,imageCanvas.height);
290
  }
291
- if (allPatches.length>0){
292
- const scores = allPatches.map(p=>p.score);
293
  const minS = Math.min(...scores); const maxS = Math.max(...scores); const rng = maxS - minS;
294
  for (const p of allPatches){
295
  if (p.index === queryIndex) continue;
296
- const t = rng > 1e-4 ? (p.score - minS)/rng : 1;
297
  const py = Math.floor(p.index / patchesPerRow);
298
  const px = p.index % patchesPerRow;
299
- if (isOverlayMode){ const a = Math.pow(t,2)*0.8; ctx.fillStyle = `rgba(255,255,255,${a})`; }
300
  else { ctx.fillStyle = getInfernoColor(t); }
301
  ctx.fillRect(px*patchSize, py*patchSize, patchSize, patchSize);
302
  }
@@ -310,7 +314,7 @@
310
  function clearHighlights(){
311
  if (animationFrameId){ cancelAnimationFrame(animationFrameId); animationFrameId = null; }
312
  lastMouseEvent = null; lastHoverData = null;
313
- if (originalImage){ ctx.drawImage(originalImage, 0, 0, imageCanvas.width, imageCanvas.height); }
314
  }
315
 
316
  initialize();
 
77
  // ===== 1) Config =====
78
  const MODEL_ID = "onnx-community/ijepa_vith14_1k"; // <-- I-JEPA ViT-H/14, ImageNet-1k
79
  const EXAMPLE_IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png";
80
+ const SUPPORTED_RESOLUTIONS = [224, 336, 448]; // Specific resolutions I-JEPA ONNX might expect
81
 
82
  // DOM
83
  const imageLoader = document.getElementById("imageLoader");
 
105
  let animationFrameId = null;
106
  let lastMouseEvent = null;
107
  let maxPixels = null;
108
+ let imageCropParams = null; // To store crop parameters for consistent redraws
109
 
110
  function updateStatus(text, isLoading=false){
111
  statusText.textContent = text;
 
126
 
127
  try{
128
  extractor = await pipeline("image-feature-extraction", MODEL_ID, { device, dtype });
 
129
  patchSize = extractor?.model?.config?.patch_size ?? patchSize;
 
130
  if (extractor?.processor?.image_processor) extractor.processor.image_processor.do_resize = false;
131
  updateStatus("Ready. Please select an image.");
132
  }catch(e){
 
154
  const res = await fetch(EXAMPLE_IMAGE_URL);
155
  const blob = await res.blob();
156
  loadImageOntoCanvas(URL.createObjectURL(blob));
157
+ }catch(e){ console.error(e); updateStatus("Failed to load example image."); }
 
 
 
 
 
 
 
158
  }
159
+ function handleImageUpload(e){ if (e.target.files?.[0]) loadImageOntoCanvas(URL.createObjectURL(e.target.files[0])); }
160
  function handleDragOver(e){ e.preventDefault(); dropZone.classList.add("border-blue-500","bg-gray-800"); }
161
  function handleDragLeave(e){ e.preventDefault(); dropZone.classList.remove("border-blue-500","bg-gray-800"); }
162
  function handleDrop(e){
 
174
  function handleSliderInput(e){ imageScale = parseFloat(e.target.value); scaleValue.textContent = `${imageScale.toFixed(2)}x`; }
175
  function handleSliderChange(){ if (currentImageUrl) loadImageOntoCanvas(currentImageUrl); }
176
 
177
+ function findClosestSupportedResolution(targetDim) {
178
+ return SUPPORTED_RESOLUTIONS.reduce((prev, curr) =>
179
+ Math.abs(curr - targetDim) < Math.abs(prev - targetDim) ? curr : prev
180
+ );
181
+ }
182
+
183
+ function redrawOriginalImage() {
184
+ if (!originalImage || !imageCropParams) return;
185
+ ctx.drawImage(
186
+ originalImage,
187
+ imageCropParams.sx, imageCropParams.sy, imageCropParams.sWidth, imageCropParams.sHeight,
188
+ 0, 0, imageCanvas.width, imageCanvas.height
189
+ );
190
+ }
191
+
192
  function loadImageOntoCanvas(url){
193
  currentImageUrl = url;
194
  originalImage = new Image();
 
196
  if (!patchSize){ updateStatus("Error: patch size unknown."); return; }
197
  canvasPlaceholder.style.display = "none";
198
  imageCanvas.style.display = "block";
199
+
200
+ const { naturalWidth: w, naturalHeight: h } = originalImage;
201
+
202
+ // --- NEW CENTER-CROP LOGIC ---
203
+ const cropSize = Math.min(w, h);
204
+ const sx = (w - cropSize) / 2;
205
+ const sy = (h - cropSize) / 2;
206
+ imageCropParams = { sx, sy, sWidth: cropSize, sHeight: cropSize };
207
+
208
+ // Determine target canvas resolution
209
+ let scaledCropSize = cropSize * imageScale;
210
+ if (scaledCropSize * scaledCropSize > maxPixels) {
211
+ scaledCropSize = Math.sqrt(maxPixels);
212
  }
213
+ let chosenResolution = findClosestSupportedResolution(scaledCropSize);
214
+
215
+ // Ensure chosen resolution is at least one patch size
216
+ if (chosenResolution < patchSize) {
217
+ updateStatus("Scaled image is too small to process.");
218
+ imageCanvas.style.display = "none";
219
+ canvasPlaceholder.style.display = "block";
220
+ canvasPlaceholder.textContent = "Scaled image is too small.";
221
+ return;
222
  }
223
+
224
+ imageCanvas.width = chosenResolution;
225
+ imageCanvas.height = chosenResolution;
226
 
227
+ redrawOriginalImage(); // Initial draw with center-crop and resize
228
+
229
+ await processImage(chosenResolution);
 
230
  setTimeout(() => { canvasContainer.scrollIntoView({ behavior: "smooth", block: "center" }); }, 100);
231
  };
232
  originalImage.onerror = () => { updateStatus("Failed to load the selected image."); canvasPlaceholder.style.display = "block"; imageCanvas.style.display = "none"; };
233
  originalImage.src = url;
234
  }
235
 
236
+ async function processImage(chosenResolution){
237
  if (!extractor) return;
238
  updateStatus("Analyzing with I‑JEPA... 🧠", true);
239
  similarityScores = null; lastHoverData = null;
240
  try{
241
  const imageData = await RawImage.fromCanvas(imageCanvas);
 
242
  const features = await extractor(imageData, { pooling: "none" }); // [1, T, D]
243
 
 
244
  const totalTokens = features.dims?.[1] ?? features.shape?.[1] ?? features.size?.[1];
245
  const nPatches = (imageCanvas.width / patchSize) * (imageCanvas.height / patchSize);
246
  const specialTokens = Math.max(0, totalTokens - nPatches);
247
 
248
+ const patchFeatures = features.slice(null, [specialTokens, nPatches]);
249
  const normalized = patchFeatures.normalize(2, -1);
250
+ const sims = await matmul(normalized, normalized.permute(0,2,1));
251
  similarityScores = (await sims.tolist())[0];
252
 
253
+ updateStatus(`Image processed at ${chosenResolution}×${chosenResolution}. Hover to explore. ✨`);
254
  }catch(err){
255
  console.error("I‑JEPA processing error:", err);
256
+ updateStatus("An error occurred during processing. The image size might be unsupported.");
257
  }
258
  }
259
 
260
+ function handleTouchMove(e){ e.preventDefault(); if (e.touches.length > 0) handleMouseMove(e.touches[0]); }
261
  function handleMouseMove(e){ lastMouseEvent = e; if (!animationFrameId) animationFrameId = requestAnimationFrame(drawLoop); }
262
 
263
  function drawLoop(){
 
267
  const scaleY = imageCanvas.height / rect.height;
268
  const x = (lastMouseEvent.clientX - rect.left) * scaleX;
269
  const y = (lastMouseEvent.clientY - rect.top) * scaleY;
270
+ if (x < 0 || x >= imageCanvas.width || y < 0 || y >= imageCanvas.height){ animationFrameId = null; return; }
271
 
272
  const patchesPerRow = imageCanvas.width / patchSize;
273
  const patchX = Math.floor(x / patchSize);
 
281
  animationFrameId = null;
282
  }
283
 
284
+ const INFERNO_COLORMAP = [ [0.0,[0,0,4]],[0.1,[39,12,69]],[0.2,[84,15,104]],[0.3,[128,31,103]],[0.4,[170,48,88]], [0.5,[209,70,68]],[0.6,[240,97,47]],[0.7,[253,138,28]],[0.8,[252,185,26]],[0.9,[240,231,56]],[1.0,[252,255,160]] ];
285
+ function getInfernoColor(t){ for (let i=1;i<INFERNO_COLORMAP.length;i++){ const [tp,cp]=INFERNO_COLORMAP[i-1]; const [tc,cc]=INFERNO_COLORMAP[i]; if (t<=tc){ const a=(t-tp)/(tc-tp); const r=cp[0]+a*(cc[0]-cp[0]); const g=cp[1]+a*(cc[1]-cp[1]); const b=cp[2]+a*(cc[2]-cp[2]); return `rgb(${Math.round(r)}, ${Math.round(g)}, ${Math.round(b)})`; } } const last=INFERNO_COLORMAP[INFERNO_COLORMAP.length-1][1]; return `rgb(${last.join(",")})`; }
 
 
 
 
 
 
 
 
 
 
286
 
287
  function drawHighlights(queryIndex, allPatches){
288
  const patchesPerRow = imageCanvas.width / patchSize;
289
  if (isOverlayMode){
290
+ redrawOriginalImage();
291
  ctx.fillStyle = "rgba(0,0,0,0.6)"; ctx.fillRect(0,0,imageCanvas.width,imageCanvas.height);
292
  } else {
293
  ctx.fillStyle = getInfernoColor(0); ctx.fillRect(0,0,imageCanvas.width,imageCanvas.height);
294
  }
295
+ if (allPatches.length > 0){
296
+ const scores = allPatches.map(p => p.score);
297
  const minS = Math.min(...scores); const maxS = Math.max(...scores); const rng = maxS - minS;
298
  for (const p of allPatches){
299
  if (p.index === queryIndex) continue;
300
+ const t = rng > 1e-4 ? (p.score - minS) / rng : 1;
301
  const py = Math.floor(p.index / patchesPerRow);
302
  const px = p.index % patchesPerRow;
303
+ if (isOverlayMode){ const a = Math.pow(t, 2) * 0.8; ctx.fillStyle = `rgba(255,255,255,${a})`; }
304
  else { ctx.fillStyle = getInfernoColor(t); }
305
  ctx.fillRect(px*patchSize, py*patchSize, patchSize, patchSize);
306
  }
 
314
  function clearHighlights(){
315
  if (animationFrameId){ cancelAnimationFrame(animationFrameId); animationFrameId = null; }
316
  lastMouseEvent = null; lastHoverData = null;
317
+ if (originalImage) redrawOriginalImage();
318
  }
319
 
320
  initialize();