DraconicDragon committed (verified)
Commit ec19f5c · Parent(s): a8804a0

Update inference_onnx.py

Files changed (1):
  inference_onnx.py +69 -14
inference_onnx.py CHANGED
@@ -1,9 +1,8 @@
 """
-ONNX Inference implementation for LSNet models
+ONNX Inference implementation for Kaloscope LSNet model
 """
 
 import numpy as np
-import onnxruntime as ort
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 from timm.models import create_model
@@ -20,6 +19,7 @@ class ONNXInference:
     def __init__(self, model_path, model_arch="lsnet_xl_artist", device="cpu"):
         """
         Initialize ONNX inference session
+        Tries CUDA execution when selected and available; otherwise prefers OpenVINO on CPU, with ONNX Runtime's CPUExecutionProvider as the last fallback
 
         Args:
             model_path: Path to ONNX model file
@@ -29,18 +29,65 @@ class ONNXInference:
         self.model_path = model_path
         self.model_arch = model_arch
         self.device = device
+        self.use_openvino = False
 
-        # Set providers based on device | barebones, theres a lot more https://onnxruntime.ai/docs/execution-providers/
         if device == "cuda":
-            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-        else:
-            providers = ["CPUExecutionProvider"]
+            # Try CUDA first for GPU
+            try:
+                import onnxruntime as ort
+
+                # Set session options to suppress warnings
+                sess_options = ort.SessionOptions()
+                sess_options.log_severity_level = 3  # 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
+
+                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+                self.session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
+                self.execution_provider = self.session.get_providers()[0]
+
+                # Check if CUDA is actually being used
+                if self.execution_provider == "CUDAExecutionProvider":
+                    print(f"Using ONNX Runtime with {self.execution_provider}")
+                    # Get transform from timm model
+                    self.transform = self._get_transform()
+                    return
+                else:
+                    # CUDA failed, fall through to CPU logic
+                    print("CUDA not available in ONNX Runtime, falling back to CPU options")
+            except Exception as e:
+                print(f"ONNX Runtime CUDA initialization failed: {e}, falling back to CPU options")
 
-        # Load ONNX session
-        self.session = ort.InferenceSession(model_path, providers=providers)
+        # For CPU or if CUDA failed, prefer OpenVINO
+        try:
+            import openvino as ov
+
+            core = ov.Core()
+            self.model = core.read_model(model_path)
+            self.session = core.compile_model(self.model, "CPU")
+            self.execution_provider = "CPU – OpenVINO™"
+            self.use_openvino = True
+            print("Using OpenVINO runtime for inference on CPU")
+        except ImportError:
+            print("OpenVINO not available, falling back to ONNX Runtime CPU")
+            self._init_onnx_runtime_cpu(model_path)
+        except Exception as e:
+            print(f"OpenVINO initialization failed: {e}, falling back to ONNX Runtime CPU")
+            self._init_onnx_runtime_cpu(model_path)
 
-        # Store the actual provider being used
+        # Get transform from timm model
+        self.transform = self._get_transform()
+
+    def _init_onnx_runtime_cpu(self, model_path):
+        """Initialize ONNX Runtime with CPU as fallback"""
+        import onnxruntime as ort
+
+        # Set session options to suppress warnings
+        sess_options = ort.SessionOptions()
+        sess_options.log_severity_level = 3  # Only show errors and fatal messages
+
+        providers = ["CPUExecutionProvider"]
+        self.session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
         self.execution_provider = self.session.get_providers()[0]
+        print(f"Using ONNX Runtime with {self.execution_provider}")
 
         # Get transform from timm model
         self.transform = self._get_transform()
@@ -81,11 +128,16 @@ class ONNXInference:
         """
        input_tensor = self.preprocess(image)
 
-        input_name = self.session.get_inputs()[0].name
-        output_name = self.session.get_outputs()[0].name
-
-        results = self.session.run([output_name], {input_name: input_tensor})
-        logits = results[0][0]
+        if self.use_openvino:
+            # OpenVINO inference
+            results = self.session(input_tensor)
+            logits = list(results.values())[0][0]
+        else:
+            # ONNX Runtime inference
+            input_name = self.session.get_inputs()[0].name
+            output_name = self.session.get_outputs()[0].name
+            results = self.session.run([output_name], {input_name: input_tensor})
+            logits = results[0][0]
 
         return logits
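Note on the CUDA branch above: ONNX Runtime still constructs a session when the CUDA provider cannot load, silently falling back to CPU, which is why the code re-checks get_providers() after creating the session. A quick way to see what the installed runtime actually offers (standard onnxruntime API; a sketch, not part of this commit):

import onnxruntime as ort

# Providers compiled into the installed package, in priority order.
# "CUDAExecutionProvider" is listed only with onnxruntime-gpu and a usable CUDA/cuDNN setup.
print(ort.get_available_providers())
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] or ['CPUExecutionProvider']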
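The OpenVINO branch reads logits with list(results.values())[0], since calling a compiled model returns a mapping keyed by output ports. An equivalent, more explicit form using the first output port (a sketch; the model path and input shape are placeholders, not taken from this commit):

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model(core.read_model("model.onnx"), "CPU")  # placeholder path

dummy = np.zeros((1, 3, 224, 224), dtype=np.float32)  # NCHW input shape assumed
results = compiled(dummy)  # CompiledModel is callable, as in the diff

# Index by the first output port; equivalent to list(results.values())[0]
logits = results[compiled.output(0)][0]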
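_get_transform() itself is outside this diff, but the timm imports at the top of the file suggest the usual pattern of deriving preprocessing from the model's data config. A plausible sketch (assumes the lsnet_xl_artist architecture is registered with timm by this repo):

from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.models import create_model

# Build the timm model only to resolve its preprocessing config
model = create_model("lsnet_xl_artist", pretrained=False)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)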
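Putting the commit together, a minimal usage sketch (file names are hypothetical, and the method containing the logits code in the last hunk is assumed to be a public predict(image) taking a PIL image):

from PIL import Image

from inference_onnx import ONNXInference

# device="cuda" walks the fallback chain: ORT CUDA -> OpenVINO CPU -> ORT CPU
model = ONNXInference("kaloscope.onnx", model_arch="lsnet_xl_artist", device="cuda")
print(model.execution_provider)  # whichever backend actually initialized

image = Image.open("example.jpg").convert("RGB")  # hypothetical input
logits = model.predict(image)  # raw logits; apply sigmoid/softmax downstream as needed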