Ziyuan111 committed on
Commit
9942f0f
·
verified ·
1 Parent(s): 3bbc97b

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +11 -3
utils.py CHANGED
@@ -11,6 +11,8 @@ from transformers import (
11
  BlipProcessor,
12
  BlipForConditionalGeneration
13
  )
 
 
14
 
15
  MEAN = [0.48145466, 0.4578275, 0.40821073]
16
  STD = [0.26862954, 0.26130258, 0.27577711]
@@ -22,9 +24,15 @@ def load_models(device):
22
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
23
  return clip_model, clip_processor, blip_model, blip_processor
24
 
25
def load_data(parquet_path="src/food101_embeddings_10000.parquet"):
    """Read the precomputed Food-101 embedding table from parquet.

    Returns the dataframe together with the embeddings stacked into a
    single 2-D matrix (one row per sample).
    """
    # pyarrow round-trip: parquet table -> pandas frame.
    df = pq.read_table(parquet_path).to_pandas()
    # Each "embedding" cell holds a per-row vector; stack into (n, dim).
    embeddings = np.vstack(df["embedding"].to_numpy())
    return df, embeddings
30
 
 
11
  BlipProcessor,
12
  BlipForConditionalGeneration
13
  )
14
+ import zipfile
15
+ import os
16
 
17
  MEAN = [0.48145466, 0.4578275, 0.40821073]
18
  STD = [0.26862954, 0.26130258, 0.27577711]
 
24
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
25
  return clip_model, clip_processor, blip_model, blip_processor
26
 
27
+
28
def load_data(parquet_path="food101_embeddings_10000.parquet"):
    """Load the Food-101 embedding table and unpack the bundled images.

    On first call this extracts ``food_images_10000.zip`` into
    ``./food_images``; subsequent calls reuse the extracted directory.
    The parquet table is then read and every row's ``image_path`` is
    rewritten to point at the extracted local copy.

    Parameters
    ----------
    parquet_path : str
        Path to a parquet file containing at least the columns
        ``"embedding"`` (array-like per row) and ``"image_path"``.

    Returns
    -------
    tuple(pandas.DataFrame, numpy.ndarray)
        The dataframe and the embeddings stacked into an (n_rows, dim)
        matrix.
    """
    # Use isdir, not exists: a stray regular file named "food_images"
    # would satisfy exists() and silently skip extraction, leaving every
    # rewritten image_path broken.
    # NOTE(review): an interrupted extraction leaves a partial directory
    # that is never re-extracted — delete ./food_images to force a redo.
    if not os.path.isdir("food_images"):
        with zipfile.ZipFile("food_images_10000.zip", "r") as zip_ref:
            zip_ref.extractall("food_images")

    df = pd.read_parquet(parquet_path)
    # Repoint image paths at the locally extracted copies; assumes the
    # archive stores files flat, so basenames are unique — TODO confirm.
    df["image_path"] = df["image_path"].apply(
        lambda p: os.path.join("food_images", os.path.basename(p))
    )
    embeddings = np.vstack(df["embedding"].to_numpy())
    return df, embeddings
38