aliasthebone committed on
Commit 3eba160 · 1 Parent(s): 3601115

feat: automatic registering of newer models and validate models + improve ui a bit

Files changed (2):
  1. app.py +40 -18
  2. hf_repo_utils.py +133 -0
app.py CHANGED
@@ -7,6 +7,8 @@ import numpy as np
 import pandas as pd
 import torch
 
+from hf_repo_utils import auto_register_kaloscope_variants, validate_models
+
 # Import inference implementations
 from inference_onnx import ONNXInference, softmax
 from inference_pytorch import PyTorchInference
@@ -25,27 +27,36 @@ except ImportError as e:
 # Define available models with their configuration
 # Format: "display_name": {
 #     "type": "onnx" or "pytorch",
-#     "path": "local path or repo:filename",  # if left empty, will use repo id
-#     "repo_id": "huggingface repo id" (optional),
-#     "arch": "model architecture name"
+#     "path": "local path or repo:filename",  # for local use - if left empty, will use repo id
+#     "repo_id": "huggingface repo id" (optional, but path cant be empty if repo_id empty),
+#     "subfolder": "subfolder in repo" (optional, if applicable),
+#     "arch": "model architecture name" - lsnet_xl_artist is expected for usual Kaloscope releases b, l, s exist but unused
 # }
 
 MODELS = {
-    "Kaloscope-onnx-opset18": {
+    "Kaloscope v1.1": {
+        "type": "pytorch",
+        "path": "",
+        "repo_id": "heathcliff01/Kaloscope",
+        "subfolder": "224-85.65",
+        "filename": "best_checkpoint.pth",
+        "arch": "lsnet_xl_artist",
+    },
+    "Kaloscope v1.0 ONNX": {
         "type": "onnx",
         "path": "",
         "repo_id": "DraconicDragon/Kaloscope-onnx-ema",
-        "filename": "lsnet_xl_artist-dynamo-opset18_merged.onnx",
+        "filename": "kaloscope_1-0.onnx",
         "arch": "lsnet_xl_artist",
     },
-    "Kaloscope-release ('best_checkpoint.pth')": {
+    "Kaloscope v1.0": {
         "type": "pytorch",
         "path": "",
         "repo_id": "heathcliff01/Kaloscope",
         "filename": "best_checkpoint.pth",
         "arch": "lsnet_xl_artist",
     },
-    "Kaloscope-release-ema": {
+    "Kaloscope v1.0 ema": {
         "type": "pytorch",
         "path": "",
         "repo_id": "DraconicDragon/Kaloscope-onnx-ema",
@@ -53,6 +64,8 @@ MODELS = {
         "arch": "lsnet_xl_artist",
     },
 }
+MODELS = validate_models(MODELS)
+auto_register_kaloscope_variants(MODELS)
 
 # Class mapping CSV configuration
 CSV_PATH = ""  # if left empty, will use repo's
@@ -111,7 +124,7 @@ def get_model_inference(model_name):
 
     if model_name not in model_cache:
         config = MODELS[model_name]
-        model_path = config["path"]
+        model_path = config.get("path") or ""
 
         print(f"Loading model: {model_name} ({config['filename']})")
 
@@ -119,13 +132,24 @@ def get_model_inference(model_name):
            raise gr.Error(f"Unknown model: {model_name}")
 
        # Check if local file exists, otherwise try to download
-        if not os.path.exists(model_path):
+        if not model_path or not os.path.exists(model_path):
            if "repo_id" in config and "filename" in config:
-                print(f"Model not found locally at {model_path}, attempting to download from Hugging Face...")
+                target_display = model_path or f"repo {config['repo_id']}"
+                print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...")
                try:
                    from huggingface_hub import hf_hub_download
 
-                    model_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])
+                    download_kwargs = {
+                        "repo_id": config["repo_id"],
+                        "filename": config["filename"],
+                    }
+                    if config.get("subfolder"):
+                        download_kwargs["subfolder"] = config["subfolder"]
+                    if config.get("revision"):
+                        download_kwargs["revision"] = config["revision"]
+
+                    model_path = hf_hub_download(**download_kwargs)
+                    config["path"] = model_path
                    print(f"Downloaded model to: {model_path}")
                except Exception as e:
                    raise gr.Error(f"Could not load model from local path or Hugging Face: {e}")
@@ -394,7 +418,8 @@ with gr.Blocks(
            model_selection = gr.Dropdown(
                choices=[
                    (
-                        f"{name} | Repo: {MODELS[name].get('repo_id') or 'local'}",
+                        f"{name}",
+                        # f"{name} | Repo: {MODELS[name].get('repo_id') or 'local'}",
                        name,
                    )
                    for name in MODELS
@@ -408,13 +433,10 @@
            time_display = gr.Markdown()  # populated after prediction
 
            gr.Markdown(
-                "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) (original release) "
-                + "and [DraconicDragon/Kaloscope-onnx-ema](https://huggingface.co/DraconicDragon/Kaloscope-onnx-ema) (ONNX converted and EMA weights)."
+                "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) (Original PyTorch releases) "
+                + "and [DraconicDragon/Kaloscope-onnx-ema](https://huggingface.co/DraconicDragon/Kaloscope-onnx-ema) (ONNX converted and EMA weights). \n"
+                + "OpenVINO™ will be used to accelerate ONNX CPU inference with ONNX CPUExecutionProvider as fallback."
            )
-            gr.Markdown("OpenVINO™ will be used to accelerate CPU inference with ONNX CPUExecutionProvider as fallback.")
-            # gr.Markdown(
-            #     "ONNX models might output different scores compared to PyTorch possibly due to differences in numerical precision and operations (but they are still likely just as accurate overall)"
-            # )
 
    submit_btn.click(
        fn=predict,
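
For context on the new download path above: hf_hub_download only receives subfolder/revision when they are actually set, so flat-repo entries keep working unchanged. A minimal standalone sketch of that conditional-kwargs pattern, assuming huggingface_hub is installed; the repo_id, subfolder, and filename values are copied from the "Kaloscope v1.1" entry in the diff for illustration:

# Sketch of the conditional-kwargs download pattern used in the hunk above;
# a sketch under stated assumptions, not a definitive implementation.
from huggingface_hub import hf_hub_download

config = {
    "repo_id": "heathcliff01/Kaloscope",
    "subfolder": "224-85.65",
    "filename": "best_checkpoint.pth",
}

download_kwargs = {"repo_id": config["repo_id"], "filename": config["filename"]}
if config.get("subfolder"):
    # hf_hub_download joins subfolder and filename into the path inside the repo
    download_kwargs["subfolder"] = config["subfolder"]
if config.get("revision"):
    download_kwargs["revision"] = config["revision"]

model_path = hf_hub_download(**download_kwargs)  # cached local path to the checkpoint
print(model_path)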
hf_repo_utils.py ADDED
@@ -0,0 +1,133 @@
+"""Utility helpers for Kaloscope-related model management."""
+
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, Iterable, Tuple
+
+
+def _iter_repo_files(repo_info: Any) -> Iterable[Any]:
+    """Yield file metadata objects from a repository info response."""
+    siblings = getattr(repo_info, "siblings", None)
+    if not siblings:
+        return []
+    return siblings
+
+
+def auto_register_kaloscope_variants(models: Dict[str, Dict[str, Any]]) -> None:
+    """Discover Kaloscope checkpoints in subfolders and append them to the model list."""
+    try:
+        from huggingface_hub import HfApi
+    except ImportError:
+        print("huggingface_hub is not available; skipping Kaloscope auto-discovery.")
+        return
+
+    repo_id = "heathcliff01/Kaloscope"
+    api = HfApi()
+
+    try:
+        repo_info = api.model_info(repo_id=repo_id)
+    except Exception as exc:
+        print(f"Failed to query Hugging Face repo info: {exc}")
+        return
+
+    existing_entries: set[Tuple[str, str, str]] = {
+        (
+            config.get("repo_id", ""),
+            config.get("subfolder", ""),
+            config.get("filename", ""),
+        )
+        for config in models.values()
+        if config.get("repo_id") and config.get("filename")
+    }
+
+    discovered: list[Tuple[str, Dict[str, Any], Tuple[str, str, str]]] = []
+
+    for sibling in _iter_repo_files(repo_info):
+        path = getattr(sibling, "rfilename", None) or getattr(sibling, "path", None)
+        if not path or not path.endswith(".pth"):
+            continue
+
+        subfolder = os.path.dirname(path).strip("/")
+        filename = os.path.basename(path)
+
+        key = (repo_id, subfolder, filename)
+        if key in existing_entries:
+            continue
+
+        # Use repo_id/folder format if in subfolder, otherwise repo_id/filename
+        repo_name = repo_id.split("/")[-1]
+        if subfolder:
+            display_name = f"{repo_name}/{subfolder.split('/')[-1]}"
+        else:
+            display_name = f"{repo_name}/{os.path.splitext(filename)[0]}"
+
+        config: Dict[str, Any] = {
+            "type": "pytorch",
+            "path": "",
+            "repo_id": repo_id,
+            "filename": filename,
+            "arch": "lsnet_xl_artist",
+        }
+
+        if subfolder:
+            config["subfolder"] = subfolder
+
+        discovered.append((display_name, config, key))
+
+    # Sort alphabetically by display name
+    discovered.sort(key=lambda item: item[0])
+
+    for display_name, config, key in discovered:
+        candidate_name = display_name
+        suffix = 2
+        while candidate_name in models:
+            candidate_name = f"{display_name} #{suffix}"
+            suffix += 1
+        models[candidate_name] = config
+        existing_entries.add(key)
+        print(f"Auto-registered: {candidate_name}")
+
+
+def validate_models(models: Dict[str, Dict[str, Any]]):
+    """
+    Validate models at startup and remove invalid entries.
+    Returns the dict of valid model entries.
+    """
+    valid_models = {}
+
+    for model_name, config in models.items():
+        model_path = config.get("path") or ""
+
+        # If local path exists, it's valid
+        if model_path and os.path.exists(model_path):
+            valid_models[model_name] = config
+            continue
+
+        # Check if it can be downloaded from HF
+        if "repo_id" in config and "filename" in config:
+            try:
+                from huggingface_hub import file_exists
+
+                # Build kwargs for checking file existence
+                check_kwargs = {
+                    "repo_id": config["repo_id"],
+                    "filename": config["filename"],
+                    "repo_type": "model",
+                }
+                if config.get("subfolder"):
+                    check_kwargs["filename"] = f"{config['subfolder']}/{config['filename']}"
+                if config.get("revision"):
+                    check_kwargs["revision"] = config["revision"]
+
+                # Check if file exists on HF
+                if file_exists(**check_kwargs):
+                    valid_models[model_name] = config
+                else:
+                    print(f"Skipping {model_name}: file not found on Hugging Face")
+            except Exception as e:
+                print(f"Skipping {model_name}: validation error - {e}")
+        else:
+            print(f"Skipping {model_name}: no valid path or repo configuration")
+
+    return valid_models
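
Taken together, the two helpers run once at startup, as the app.py hunk above shows: validate first, then auto-register. A minimal usage sketch, assuming hf_repo_utils.py sits next to the script and network access is available; the single MODELS entry is copied from the diff for illustration:

# Usage sketch for the new helpers; requires huggingface_hub and network access.
from hf_repo_utils import auto_register_kaloscope_variants, validate_models

MODELS = {
    "Kaloscope v1.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
}

MODELS = validate_models(MODELS)          # drops entries missing both locally and on the Hub
auto_register_kaloscope_variants(MODELS)  # appends any .pth checkpoints found in the repo
print(list(MODELS))                       # registered display names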