Spaces:
Build error
Build error
Commit
·
78866a7
1
Parent(s):
14d54e2
Update backend/utils.py
Browse files- backend/utils.py +24 -21
backend/utils.py
CHANGED
|
@@ -17,14 +17,15 @@ from tqdm import trange
|
|
| 17 |
import torch
|
| 18 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 19 |
|
| 20 |
-
|
| 21 |
-
@st.cache_resource
|
| 22 |
def load_dataset(data_index):
|
| 23 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
| 24 |
dataset = pickle.load(file)
|
| 25 |
return dataset
|
| 26 |
|
| 27 |
-
@st.
|
|
|
|
| 28 |
def load_dataset_dict():
|
| 29 |
dataset_dict = {}
|
| 30 |
progress_empty = st.empty()
|
|
@@ -39,13 +40,15 @@ def load_dataset_dict():
|
|
| 39 |
return dataset_dict
|
| 40 |
|
| 41 |
|
| 42 |
-
@st.cache_data
|
|
|
|
| 43 |
def load_image(image_id):
|
| 44 |
dataset = load_dataset(image_id//10000)
|
| 45 |
image = dataset[image_id%10000]
|
| 46 |
return image
|
| 47 |
|
| 48 |
-
@st.cache_data
|
|
|
|
| 49 |
def load_images(image_ids):
|
| 50 |
images = []
|
| 51 |
for image_id in image_ids:
|
|
@@ -54,8 +57,8 @@ def load_images(image_ids):
|
|
| 54 |
return images
|
| 55 |
|
| 56 |
|
| 57 |
-
|
| 58 |
-
@st.cache_resource
|
| 59 |
def load_model(model_name):
|
| 60 |
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
| 61 |
if model_name == 'ResNet':
|
|
@@ -356,21 +359,21 @@ def _set_block_container_style(
|
|
| 356 |
)
|
| 357 |
|
| 358 |
|
| 359 |
-
@st.cache
|
| 360 |
-
def get_dataframe() -> pd.DataFrame():
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
|
| 369 |
|
| 370 |
-
def get_plotly_fig():
|
| 371 |
-
|
| 372 |
-
|
| 373 |
|
| 374 |
|
| 375 |
-
def get_matplotlib_plt():
|
| 376 |
-
|
|
|
|
| 17 |
import torch
|
| 18 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 19 |
|
| 20 |
+
@st.cache(allow_output_mutation=True)
|
| 21 |
+
# @st.cache_resource
|
| 22 |
def load_dataset(data_index):
|
| 23 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
| 24 |
dataset = pickle.load(file)
|
| 25 |
return dataset
|
| 26 |
|
| 27 |
+
@st.cache(allow_output_mutation=True)
|
| 28 |
+
# @st.cache_resource
|
| 29 |
def load_dataset_dict():
|
| 30 |
dataset_dict = {}
|
| 31 |
progress_empty = st.empty()
|
|
|
|
| 40 |
return dataset_dict
|
| 41 |
|
| 42 |
|
| 43 |
+
# @st.cache_data
|
| 44 |
+
@st.cache(allow_output_mutation=True)
|
| 45 |
def load_image(image_id):
|
| 46 |
dataset = load_dataset(image_id//10000)
|
| 47 |
image = dataset[image_id%10000]
|
| 48 |
return image
|
| 49 |
|
| 50 |
+
# @st.cache_data
|
| 51 |
+
@st.cache(allow_output_mutation=True)
|
| 52 |
def load_images(image_ids):
|
| 53 |
images = []
|
| 54 |
for image_id in image_ids:
|
|
|
|
| 57 |
return images
|
| 58 |
|
| 59 |
|
| 60 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
| 61 |
+
# @st.cache_resource
|
| 62 |
def load_model(model_name):
|
| 63 |
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
| 64 |
if model_name == 'ResNet':
|
|
|
|
| 359 |
)
|
| 360 |
|
| 361 |
|
| 362 |
+
# @st.cache
|
| 363 |
+
# def get_dataframe() -> pd.DataFrame():
|
| 364 |
+
# """Dummy DataFrame"""
|
| 365 |
+
# data = [
|
| 366 |
+
# {"quantity": 1, "price": 2},
|
| 367 |
+
# {"quantity": 3, "price": 5},
|
| 368 |
+
# {"quantity": 4, "price": 8},
|
| 369 |
+
# ]
|
| 370 |
+
# return pd.DataFrame(data)
|
| 371 |
|
| 372 |
|
| 373 |
+
# def get_plotly_fig():
|
| 374 |
+
# """Dummy Plotly Plot"""
|
| 375 |
+
# return px.line(data_frame=get_dataframe(), x="quantity", y="price")
|
| 376 |
|
| 377 |
|
| 378 |
+
# def get_matplotlib_plt():
|
| 379 |
+
# get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
|