SpyC0der77 committed · Commit b50b6f2 · verified · 1 Parent(s): 1168206

Update README.md

Files changed (1):
  1. README.md +16 -145
README.md CHANGED
@@ -15,6 +15,12 @@ tags:
 
 This directory contains the improved v2 artifact classification model with state-of-the-art performance for classifying museum artifacts by both object type and material.
 
 ## Model Overview
 
 The v2 model is an advanced multi-output neural network that predicts two attributes simultaneously:
@@ -27,72 +33,19 @@ The v2 model is an advanced multi-output neural network that predicts two attrib
 - **Advanced Training**: Incorporates CutMix augmentation, Focal Loss, and mixed precision training
 - **Better Regularization**: Uses dropout and batch normalization for improved generalization
 
- ## Quick Start
-
- ### Prerequisites
-
- Ensure you have the required dependencies installed:
-
- ```bash
- pip install "torch>=2.0.0" "torchvision>=0.15.0" "datasets>=2.0.0" "pillow>=9.0.0" "timm>=1.0.22" "huggingface-hub>=0.15.0"
- ```
-
- ### Basic Inference
 
- ```python
- import torch
- from PIL import Image
- from torchvision import transforms
- import sys
- import os
-
- # Add the project root to the Python path so main.py is importable
- sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
-
- from main import load_model, run_inference
-
- # Load the model checkpoint together with its label mappings
- model_path = "model/v2/best_model.pth"
- model, label_mappings = load_model(model_path)
-
- # Prepare image
- image_path = "path/to/your/artifact.jpg"
- image = Image.open(image_path).convert('RGB')
-
- # Preprocessing transform (ImageNet statistics)
- transform = transforms.Compose([
-     transforms.Resize(256),
-     transforms.CenterCrop(224),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ])
-
- pixel_values = transform(image).unsqueeze(0)  # Add batch dimension
-
- # Run inference
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)
-
- # Get predictions
- object_pred_id = preds_obj[0].item()
- material_pred_id = preds_mat[0].item()
- object_conf = confs_obj[0].item()
- material_conf = confs_mat[0].item()
-
- # Convert IDs to labels
- object_name = label_mappings['object_name'].get(object_pred_id, f"class_{object_pred_id}")
- material_name = label_mappings['material'].get(material_pred_id, f"class_{material_pred_id}")
-
- print(f"Predicted Object: {object_name} (confidence: {object_conf:.3f})")
- print(f"Predicted Material: {material_name} (confidence: {material_conf:.3f})")
- ```
 
- ## Model Files
 
- - **`best_model.pth`**: The best performing model checkpoint with trained weights and label mappings
- - **`model_improved.pth`**: Final model after complete training
- - **`checkpoint_epoch_*.pth`**: Intermediate checkpoints saved during training
- - **`train.py`**: Training script used to create this model
 
 ## Model Architecture
 
@@ -115,38 +68,6 @@ Returns a dictionary with:
 - `'object_name'`: Logits for object classification
 - `'material'`: Logits for material classification
 
- ## Evaluation
-
- ### Using the Main Evaluation Script
-
- To evaluate the model on the Oriental Museum dataset:
-
- ```bash
- # Evaluate on the validation set
- python main.py --model_file model/v2/best_model.pth --output eval_results_v2.json
-
- # Evaluate with a custom batch size
- python main.py --model_file model/v2/best_model.pth --batch_size 16 --output eval_results_v2.json
- ```
-
- ### Evaluation Metrics
-
- The evaluation script provides:
- - **Object Classification Accuracy**: Accuracy for object name prediction
- - **Material Classification Accuracy**: Accuracy for material prediction
- - **Overall Accuracy**: Fraction of samples where both predictions are correct
- - **Confidence Analysis**: Average confidence for correct vs. incorrect predictions
- - **Per-sample Predictions**: Detailed results for each test sample
-
- ### Expected Performance
-
- Based on validation during training:
- - Object Classification: ~85-90% accuracy
- - Material Classification: ~80-85% accuracy
- - Overall Accuracy: ~75-80%
-
- *Note: Actual performance may vary depending on the evaluation dataset and preprocessing.*
-
 ## Training Details
 
 The model was trained with the following configuration:
@@ -171,56 +92,6 @@ The model was trained with the following configuration:
 - **Gradient Scaling**: Prevents gradient underflow
 - **Early Stopping**: Saves best model based on validation accuracy
 
- ## Usage Examples
-
- ### Batch Inference
-
- ```python
- import torch
- from PIL import Image
- from torchvision import transforms
- import sys
- import os
-
- # Make main.py importable from the project root
- sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
- from main import load_model, run_inference
-
- # Load model
- model, label_mappings = load_model("model/v2/best_model.pth")
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Load multiple images
- image_paths = ["artifact1.jpg", "artifact2.jpg", "artifact3.jpg"]
- images = []
-
- transform = transforms.Compose([
-     transforms.Resize(256),
-     transforms.CenterCrop(224),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ])
-
- for path in image_paths:
-     img = Image.open(path).convert('RGB')
-     images.append(transform(img))
-
- # Stack into a single batch tensor of shape (N, 3, 224, 224)
- batch = torch.stack(images)
-
- # Run inference
- preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, batch, device)
-
- # Process results
- for i, (obj_pred, obj_conf, mat_pred, mat_conf) in enumerate(zip(preds_obj, confs_obj, preds_mat, confs_mat)):
-     obj_name = label_mappings['object_name'].get(obj_pred.item(), f"class_{obj_pred.item()}")
-     mat_name = label_mappings['material'].get(mat_pred.item(), f"class_{mat_pred.item()}")
-
-     print(f"Image {i+1}:")
-     print(f"  Object: {obj_name} ({obj_conf:.3f})")
-     print(f"  Material: {mat_name} ({mat_conf:.3f})")
- ```
 
 ## Troubleshooting
 
 ### Common Issues
 
 This directory contains the improved v2 artifact classification model with state-of-the-art performance for classifying museum artifacts by both object type and material.
 
+ ## Hosted Model
+
+ The best model is available on Hugging Face at: **[SpyC0der77/artifact-efficientnet](https://huggingface.co/SpyC0der77/artifact-efficientnet)**
+
+ You can load the model directly from Hugging Face without manually downloading files.
+
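+ For example, the checkpoint can be pulled straight from the Hub with `huggingface_hub` (a minimal sketch; the `best_model.pth` filename is an assumption, so check the repo's file list):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+ import torch
+
+ # Download (and cache) the checkpoint from the Hub
+ checkpoint_path = hf_hub_download(
+     repo_id="SpyC0der77/artifact-efficientnet",
+     filename="best_model.pth",  # assumed filename
+ )
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ ```
+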
 ## Model Overview
 
 The v2 model is an advanced multi-output neural network that predicts two attributes simultaneously:
 
 - **Advanced Training**: Incorporates CutMix augmentation, Focal Loss, and mixed precision training (a Focal Loss sketch follows this list)
 - **Better Regularization**: Uses dropout and batch normalization for improved generalization
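 
 For reference, the Focal Loss named above is typically implemented like this minimal sketch (the standard formulation with focusing parameter `gamma`; not necessarily the exact variant in train.py):
 
 ```python
 import torch
 import torch.nn.functional as F
 
 def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
     # Standard cross entropy per sample, kept unreduced
     ce = F.cross_entropy(logits, targets, reduction="none")
     pt = torch.exp(-ce)  # model's probability for the true class
     # Down-weight easy examples (pt near 1) so training focuses on hard ones
     return ((1.0 - pt) ** gamma * ce).mean()
 ```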
 
+ ## Architecture & Usage
+
+ The v2 model uses an EfficientNet-B0 backbone with an attention mechanism for multi-output classification. It processes RGB images of artifacts and outputs predictions for both object type and material composition.
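+
+ As a rough illustration, the multi-output head on an EfficientNet-B0 backbone might look like the following minimal sketch built with `timm` (the class name and head layout are illustrative assumptions, and the attention block is omitted; this is not the repo's exact model code):
+
+ ```python
+ import timm
+ import torch
+ import torch.nn as nn
+
+ class ArtifactClassifier(nn.Module):  # hypothetical name, for illustration only
+     def __init__(self, num_objects: int, num_materials: int):
+         super().__init__()
+         # EfficientNet-B0 feature extractor; num_classes=0 removes the stock head
+         self.backbone = timm.create_model("efficientnet_b0", pretrained=True, num_classes=0)
+         feat_dim = self.backbone.num_features  # 1280 for EfficientNet-B0
+         self.object_head = nn.Linear(feat_dim, num_objects)
+         self.material_head = nn.Linear(feat_dim, num_materials)
+
+     def forward(self, pixel_values: torch.Tensor) -> dict:
+         features = self.backbone(pixel_values)  # (batch, 1280)
+         # Matches the documented output format: a dict of logits per task
+         return {
+             "object_name": self.object_head(features),
+             "material": self.material_head(features),
+         }
+ ```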
+
+ ### Input
+ - **Format**: RGB images (224×224 pixels after preprocessing)
+ - **Preprocessing**: Resize to 256×256, center crop to 224×224, normalize with ImageNet statistics (shown as a snippet below)
+
+ ### Output
+ - **Object Classification**: Predicts artifact type (e.g., "vase", "statue", "pottery")
+ - **Material Classification**: Predicts material composition (e.g., "ceramic", "bronze", "stone")
+ - **Confidence Scores**: Probability scores for each prediction
+ - **Format**: Dictionary with 'object_name' and 'material' logits (see the decoding sketch below)
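+
+ A minimal decoding sketch (softmax over each head, then argmax; `model`, `pixel_values`, and `label_mappings` stand in for the loaded model, a preprocessed batch, and whatever id-to-name dictionaries ship with the checkpoint, so treat those names as assumptions):
+
+ ```python
+ import torch.nn.functional as F
+
+ # outputs: dict with 'object_name' and 'material' logits of shape (batch, num_classes)
+ outputs = model(pixel_values)
+
+ obj_conf, obj_pred = F.softmax(outputs["object_name"], dim=1).max(dim=1)
+ mat_conf, mat_pred = F.softmax(outputs["material"], dim=1).max(dim=1)
+
+ # Map class ids to human-readable labels (assumed mapping structure)
+ object_name = label_mappings["object_name"].get(obj_pred[0].item(), "unknown")
+ material_name = label_mappings["material"].get(mat_pred[0].item(), "unknown")
+ print(f"{object_name} ({obj_conf[0]:.3f}) / {material_name} ({mat_conf[0]:.3f})")
+ ```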
 
 ## Model Architecture
 
 - `'object_name'`: Logits for object classification
 - `'material'`: Logits for material classification
 
 ## Training Details
 
 The model was trained with the following configuration:
 
 - **Gradient Scaling**: Prevents gradient underflow (see the sketch after this list)
 - **Early Stopping**: Saves best model based on validation accuracy
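 
 A minimal sketch of a gradient-scaled mixed precision loop (standard `torch.cuda.amp` usage; `loader`, `criterion`, and `optimizer` are assumed to exist, and this is not the repo's exact train.py):
 
 ```python
 import torch
 
 scaler = torch.cuda.amp.GradScaler()  # rescales the loss so fp16 gradients don't underflow
 
 for images, targets in loader:
     optimizer.zero_grad()
     with torch.cuda.amp.autocast():
         outputs = model(images)
         loss = criterion(outputs, targets)  # combined loss over both heads (assumed)
     scaler.scale(loss).backward()  # backprop on the scaled loss
     scaler.step(optimizer)         # unscale gradients, then optimizer step
     scaler.update()                # adjust the scale factor for the next step
 ```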
 
 ## Troubleshooting
 
 ### Common Issues