|
|
import onnxruntime as ort |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import pandas as pd |
|
|
import requests |
|
|
import io |
|
|
import sys |
|
|
|
|
|
|
|
|
# --- Input artifacts (paths are relative to the current working directory) ---
ONNX_MODEL_PATH = "./lsnet_xl_artist-dynamo-opset18_merged.onnx"
CSV_PATH = "./class_mapping.csv"

# --- Image to classify (downloaded over HTTP by main()) ---
IMAGE_URL = (
    "https://cdn.donmai.us/sample/9f/bb/"
    "__vampire_s_sister_original_drawn_by_gogalking__"
    "sample-9fbb30aa76bdc8242a1c122d3d6b41d9.jpg"
)

# Target (width, height) handed to PIL's Image.resize during preprocessing.
IMAGE_SIZE = 224, 224

# --- Reporting ---
# How many of the highest-probability classes main() reports.
TOP_K = 5
# Minimum softmax probability (a 0-1 fraction) a top-k entry needs in order to
# be printed; 0.0 means every one of the TOP_K entries is shown.
PREDICTION_THRESHOLD = 0.0
|
|
|
|
|
|
|
|
def preprocess_image_from_url(image_url, size=(224, 224), timeout=30):
    """Download an image and turn it into a normalized NCHW float32 batch.

    The image is fetched over HTTP, decoded with Pillow and forced to RGB.
    It is then resized straight to ``size`` with Lanczos resampling, so the
    aspect ratio is NOT preserved. Pixels are scaled to [0, 1] and normalized
    per channel with the ImageNet mean/std. Finally the array is reordered
    from HWC to CHW and given a leading batch axis.

    Args:
        image_url: URL of the image to download.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
        timeout: Seconds ``requests`` may wait for the server before giving
            up. Pass ``None`` to wait indefinitely.

    Returns:
        A C-contiguous ``np.ndarray`` of dtype float32 and shape
        ``(1, 3, height, width)``.

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` if the download
        fails or the payload cannot be decoded as an image.
    """
    try:
        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        payload = response.content
    except requests.exceptions.RequestException as e:
        print(f"Error: Failed to download image from URL '{image_url}'.\nDetails: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        # Image.open is lazy. convert() forces a full decode, so corrupt or
        # non-image payloads fail here instead of during resizing.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception as e:  # CLI boundary: any decode failure is fatal.
        print(f"Error: Could not process the downloaded image. It may not be a valid image file.\nDetails: {e}", file=sys.stderr)
        sys.exit(1)

    # Resize directly to `size` (width, height); no crop or letterbox.
    image = image.resize(size, Image.Resampling.LANCZOS)

    # HWC uint8 -> HWC float32 in [0, 1].
    pixels = np.array(image, dtype=np.float32) / 255.0

    # Standard ImageNet RGB channel statistics.
    # NOTE(review): assumed to match the model's training preprocessing; confirm.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (pixels - mean) / std

    # HWC -> CHW, prepend the batch axis, and return a dense (C-contiguous)
    # buffer rather than a strided transpose view.
    return np.ascontiguousarray(normalized.transpose(2, 0, 1)[np.newaxis, ...])
|
|
|
|
|
def load_labels(csv_path):
    """Load the class-id -> class-name mapping from a CSV file.

    The CSV must have a header row containing ``class_id`` and
    ``class_name`` columns. Class names are read verbatim: pandas' default
    missing-value markers (``NA``, ``null``, ``nan``, ...) are not applied to
    them, because such strings can be legitimate class names. Names are then
    coerced to ``str``, and leading/trailing single-quote characters are
    stripped. Only an empty ``class_id`` cell is treated as missing.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        dict mapping each ``class_id`` value to its cleaned ``class_name``.

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` if the file is
        missing, cannot be parsed, or lacks the required columns.
    """
    try:
        df = pd.read_csv(
            csv_path,
            keep_default_na=False,          # keep names like "null"/"NA" as text
            na_values={"class_id": [""]},   # only a blank id counts as missing
        )
    except FileNotFoundError:
        print(f"Error: CSV file not found at '{csv_path}'", file=sys.stderr)
        sys.exit(1)
    except (OSError, ValueError) as e:
        # pandas' ParserError / EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses.
        print(f"Error reading CSV file: {e}", file=sys.stderr)
        sys.exit(1)

    if "class_id" not in df.columns or "class_name" not in df.columns:
        print("Error: CSV file must have 'class_id' and 'class_name' columns.", file=sys.stderr)
        sys.exit(1)

    # astype(str) makes .str usable even when every name parsed as a number.
    names = df["class_name"].astype(str).str.strip("'")
    return dict(zip(df["class_id"], names))
|
|
|
|
|
def softmax(x, axis=0):
    """Compute a numerically stable softmax of ``x`` along ``axis``.

    The maximum of each slice along ``axis`` is subtracted before
    exponentiating. This prevents overflow for large scores. Because the
    maximum is taken per slice rather than over the whole array, slices whose
    values sit far below the global maximum do not underflow to 0/0 = NaN.

    Args:
        x: Array-like of scores.
        axis: Axis to normalize over. Defaults to 0, the only axis of a 1-D
            input. Pass ``axis=-1`` for ``(batch, classes)`` arrays.

    Returns:
        ``np.ndarray`` with the same shape as ``x``; its entries sum to 1
        along ``axis``.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=axis, keepdims=True)
|
|
|
|
|
def main():
    """Run the end-to-end ONNX classification demo.

    The steps are:
      1. Load the id -> name label map from ``CSV_PATH``.
      2. Download and preprocess ``IMAGE_URL``.
      3. Create an ONNX Runtime session for ``ONNX_MODEL_PATH``.
      4. Run one forward pass and softmax the first output.
      5. Print up to ``TOP_K`` classes whose probability is at least
         ``PREDICTION_THRESHOLD``.

    Failures while loading the model or running inference are reported on
    stderr and end the process with exit status 1.
    """
    print("1. Loading class labels...")
    labels = load_labels(CSV_PATH)
    print(f" Loaded {len(labels)} labels.")

    print("\n2. Downloading and preprocessing image from URL...")
    input_tensor = preprocess_image_from_url(IMAGE_URL, IMAGE_SIZE)
    print(f" Image shape: {input_tensor.shape}, Data type: {input_tensor.dtype}")

    print("\n3. Initializing ONNX runtime session...")
    # Request CUDA only if this onnxruntime build provides it, so CPU-only
    # installs do not warn about an unavailable provider.
    preferred = ("CUDAExecutionProvider", "CPUExecutionProvider")
    available = set(ort.get_available_providers())
    providers = [p for p in preferred if p in available] or ["CPUExecutionProvider"]
    try:
        session = ort.InferenceSession(ONNX_MODEL_PATH, providers=providers)
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
    except Exception as e:
        print(f"Error loading ONNX model: {e}", file=sys.stderr)
        sys.exit(1)
    print(" ONNX session created successfully.")

    print("\n4. Running inference...")
    try:
        results = session.run([output_name], {input_name: input_tensor})
    except Exception as e:
        print(f"Error during inference: {e}", file=sys.stderr)
        sys.exit(1)
    # First (only requested) output, first (only) item of the batch.
    logits = results[0][0]
    print(" Inference complete.")

    print("\n5. Processing results...")
    # NOTE(review): assumes the model emits unnormalized logits. If it already
    # outputs probabilities, this second softmax flattens them; confirm.
    probabilities = softmax(logits)
    # Indices from highest to lowest probability. Reversing first and then
    # taking a prefix means TOP_K == 0 yields nothing; the old `[-TOP_K:]`
    # slice became `[-0:]` and selected every class.
    top_k_indices = np.argsort(probabilities)[::-1][:max(TOP_K, 0)]

    print(f"\n--- Predictions for image URL (Top K: {TOP_K}, Threshold: {PREDICTION_THRESHOLD:.1%}) ---")

    predictions_found = 0
    for rank, index in enumerate(top_k_indices, start=1):
        score = probabilities[index]
        if score >= PREDICTION_THRESHOLD:
            # argsort yields numpy integers; use a plain int as the dict key.
            class_name = labels.get(int(index), f"Unknown Class #{index}")
            print(f"Rank {rank}: {class_name} (Score: {score:.2%})")
            predictions_found += 1

    if predictions_found == 0:
        print("No predictions met the specified threshold.")
|
|
|
|
|
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()