Commit 3eba160
Parent(s): 3601115

feat: automatic registering of newer models and validate models + improve ui a bit

- app.py +40 -18
- hf_repo_utils.py +133 -0
    	
app.py
CHANGED

@@ -7,6 +7,8 @@ import numpy as np
 import pandas as pd
 import torch
 
+from hf_repo_utils import auto_register_kaloscope_variants, validate_models
+
 # Import inference implementations
 from inference_onnx import ONNXInference, softmax
 from inference_pytorch import PyTorchInference
@@ -25,27 +27,36 @@ except ImportError as e:
 # Define available models with their configuration
 # Format: "display_name": {
 #     "type": "onnx" or "pytorch",
-#     "path": "local path or repo:filename", # if left empty, will use repo id
-#     "repo_id": "huggingface repo id" (optional),
-#     "…
+#     "path": "local path or repo:filename",  # for local use - if left empty, will use repo_id
+#     "repo_id": "huggingface repo id" (optional, but "path" can't be empty if "repo_id" is empty),
+#     "subfolder": "subfolder in repo" (optional, if applicable),
+#     "arch": "model architecture name" - lsnet_xl_artist is expected for usual Kaloscope releases; b, l, s variants exist but are unused
 # }
 
 MODELS = {
-    "Kaloscope …
+    "Kaloscope v1.1": {
+        "type": "pytorch",
+        "path": "",
+        "repo_id": "heathcliff01/Kaloscope",
+        "subfolder": "224-85.65",
+        "filename": "best_checkpoint.pth",
+        "arch": "lsnet_xl_artist",
+    },
+    "Kaloscope v1.0 ONNX": {
         "type": "onnx",
         "path": "",
         "repo_id": "DraconicDragon/Kaloscope-onnx-ema",
-        "filename": "…
+        "filename": "kaloscope_1-0.onnx",
         "arch": "lsnet_xl_artist",
     },
-    "Kaloscope …
+    "Kaloscope v1.0": {
         "type": "pytorch",
         "path": "",
         "repo_id": "heathcliff01/Kaloscope",
         "filename": "best_checkpoint.pth",
         "arch": "lsnet_xl_artist",
     },
-    "Kaloscope …
+    "Kaloscope v1.0 ema": {
         "type": "pytorch",
         "path": "",
         "repo_id": "DraconicDragon/Kaloscope-onnx-ema",
@@ -53,6 +64,8 @@ MODELS = {
         "arch": "lsnet_xl_artist",
     },
 }
+MODELS = validate_models(MODELS)
+auto_register_kaloscope_variants(MODELS)
 
 # Class mapping CSV configuration
 CSV_PATH = ""  # if left empty, will use repo's
@@ -111,7 +124,7 @@ def get_model_inference(model_name):
 
     if model_name not in model_cache:
         config = MODELS[model_name]
-        model_path = config…
+        model_path = config.get("path") or ""
 
         print(f"Loading model: {model_name} ({config['filename']})")
 
@@ -119,13 +132,24 @@ def get_model_inference(model_name):
             raise gr.Error(f"Unknown model: {model_name}")
 
         # Check if local file exists, otherwise try to download
-        if not os.path.exists(model_path):
+        if not model_path or not os.path.exists(model_path):
             if "repo_id" in config and "filename" in config:
-                …
+                target_display = model_path or f"repo {config['repo_id']}"
+                print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...")
                 try:
                     from huggingface_hub import hf_hub_download
 
-                    …
+                    download_kwargs = {
+                        "repo_id": config["repo_id"],
+                        "filename": config["filename"],
+                    }
+                    if config.get("subfolder"):
+                        download_kwargs["subfolder"] = config["subfolder"]
+                    if config.get("revision"):
+                        download_kwargs["revision"] = config["revision"]
+
+                    model_path = hf_hub_download(**download_kwargs)
+                    config["path"] = model_path
                     print(f"Downloaded model to: {model_path}")
                 except Exception as e:
                     raise gr.Error(f"Could not load model from local path or Hugging Face: {e}")
@@ -394,7 +418,8 @@ with gr.Blocks(
                 model_selection = gr.Dropdown(
                     choices=[
                         (
-                            f"{name} …
+                            f"{name}",
+                            # f"{name} | Repo: {MODELS[name].get('repo_id') or 'local'}",
                             name,
                         )
                         for name in MODELS
@@ -408,13 +433,10 @@ with gr.Blocks(
                 time_display = gr.Markdown()  # populated after prediction
 
     gr.Markdown(
-        "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) (…
-        + "and [DraconicDragon/Kaloscope-onnx-ema](https://huggingface.co/DraconicDragon/Kaloscope-onnx-ema) (ONNX converted and EMA weights)."
+        "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) (Original PyTorch releases) "
+        + "and [DraconicDragon/Kaloscope-onnx-ema](https://huggingface.co/DraconicDragon/Kaloscope-onnx-ema) (ONNX converted and EMA weights).  \n"
+        + "OpenVINO™ will be used to accelerate ONNX CPU inference with ONNX CPUExecutionProvider as fallback."
     )
-    gr.Markdown("OpenVINO™ will be used to accelerate CPU inference with ONNX CPUExecutionProvider as fallback.")
-    # gr.Markdown(
-    #    "ONNX models might output different scores compared to PyTorch possibly due to differences in numerical precision and operations (but they are still likely just as accurate overall)"
-    # )
 
     submit_btn.click(
         fn=predict,
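For orientation, a minimal sketch of how the two new helpers compose at startup, mirroring the diff above. The "224-85.65" subfolder is the real one from this commit; the auto-registered entry shown in the trailing comment assumes a hypothetical future checkpoint folder ("448-87.00") purely for illustration.

# Sketch of the startup flow app.py now runs at import time.
from hf_repo_utils import auto_register_kaloscope_variants, validate_models

MODELS = {
    "Kaloscope v1.1": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "subfolder": "224-85.65",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
}

MODELS = validate_models(MODELS)          # keeps only entries that exist locally or on the Hub
auto_register_kaloscope_variants(MODELS)  # appends newly discovered .pth checkpoints in place

# If the repo later gained a hypothetical subfolder "448-87.00" holding a
# best_checkpoint.pth, auto-discovery would add an entry like:
# MODELS["Kaloscope/448-87.00"] == {
#     "type": "pytorch", "path": "", "repo_id": "heathcliff01/Kaloscope",
#     "subfolder": "448-87.00", "filename": "best_checkpoint.pth",
#     "arch": "lsnet_xl_artist",
# }
print(sorted(MODELS))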
    	
hf_repo_utils.py
ADDED

@@ -0,0 +1,133 @@
+"""Utility helpers for Kaloscope-related model management."""
+
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, Iterable, Tuple
+
+
+def _iter_repo_files(repo_info: Any) -> Iterable[Any]:
+    """Yield file metadata objects from a repository info response."""
+    siblings = getattr(repo_info, "siblings", None)
+    if not siblings:
+        return []
+    return siblings
+
+
+def auto_register_kaloscope_variants(models: Dict[str, Dict[str, Any]]) -> None:
+    """Discover Kaloscope checkpoints in subfolders and append them to the model list."""
+    try:
+        from huggingface_hub import HfApi
+    except ImportError:
+        print("huggingface_hub is not available; skipping Kaloscope auto-discovery.")
+        return
+
+    repo_id = "heathcliff01/Kaloscope"
+    api = HfApi()
+
+    try:
+        repo_info = api.model_info(repo_id=repo_id)
+    except Exception as exc:
+        print(f"Failed to query Hugging Face repo info: {exc}")
+        return
+
+    existing_entries: set[Tuple[str, str, str]] = {
+        (
+            config.get("repo_id", ""),
+            config.get("subfolder", ""),
+            config.get("filename", ""),
+        )
+        for config in models.values()
+        if config.get("repo_id") and config.get("filename")
+    }
+
+    discovered: list[Tuple[str, Dict[str, Any], Tuple[str, str, str]]] = []
+
+    for sibling in _iter_repo_files(repo_info):
+        path = getattr(sibling, "rfilename", None) or getattr(sibling, "path", None)
+        if not path or not path.endswith(".pth"):
+            continue
+
+        subfolder = os.path.dirname(path).strip("/")
+        filename = os.path.basename(path)
+
+        key = (repo_id, subfolder, filename)
+        if key in existing_entries:
+            continue
+
+        # Use repo_id/folder format if in a subfolder, otherwise repo_id/filename
+        repo_name = repo_id.split("/")[-1]
+        if subfolder:
+            display_name = f"{repo_name}/{subfolder.split('/')[-1]}"
+        else:
+            display_name = f"{repo_name}/{os.path.splitext(filename)[0]}"
+
+        config: Dict[str, Any] = {
+            "type": "pytorch",
+            "path": "",
+            "repo_id": repo_id,
+            "filename": filename,
+            "arch": "lsnet_xl_artist",
+        }
+
+        if subfolder:
+            config["subfolder"] = subfolder
+
+        discovered.append((display_name, config, key))
+
+    # Sort alphabetically by display name
+    discovered.sort(key=lambda item: item[0])
+
+    for display_name, config, key in discovered:
+        candidate_name = display_name
+        suffix = 2
+        while candidate_name in models:
+            candidate_name = f"{display_name} #{suffix}"
+            suffix += 1
+        models[candidate_name] = config
+        existing_entries.add(key)
+        print(f"Auto-registered: {candidate_name}")
+
+
+def validate_models(models: Dict[str, Dict[str, Any]]):
+    """
+    Validate models at startup and remove invalid entries.
+    Returns a dict containing only the valid model entries.
+    """
+    valid_models = {}
+
+    for model_name, config in models.items():
+        model_path = config.get("path") or ""
+
+        # If the local path exists, the entry is valid
+        if model_path and os.path.exists(model_path):
+            valid_models[model_name] = config
+            continue
+
+        # Otherwise, check whether the file can be downloaded from the Hub
+        if "repo_id" in config and "filename" in config:
+            try:
+                from huggingface_hub import file_exists
+
+                # Build kwargs for checking file existence
+                check_kwargs = {
+                    "repo_id": config["repo_id"],
+                    "filename": config["filename"],
+                    "repo_type": "model",
+                }
+                if config.get("subfolder"):
+                    check_kwargs["filename"] = f"{config['subfolder']}/{config['filename']}"
+                if config.get("revision"):
+                    check_kwargs["revision"] = config["revision"]
+
+                # Check if the file exists on Hugging Face
+                if file_exists(**check_kwargs):
+                    valid_models[model_name] = config
+                else:
+                    print(f"Skipping {model_name}: file not found on Hugging Face")
+            except Exception as e:
+                print(f"Skipping {model_name}: validation error - {e}")
+        else:
+            print(f"Skipping {model_name}: no valid path or repo configuration")
+
+    return valid_models
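One detail worth noting: the "subfolder" key travels differently through the two code paths. validate_models folds it into the filename for huggingface_hub.file_exists, while app.py passes it as a separate kwarg to hf_hub_download; both are standard huggingface_hub usage. A minimal sketch of the equivalent direct calls for the "Kaloscope v1.1" entry (network access required):

# Equivalent direct Hub calls for the "Kaloscope v1.1" entry.
from huggingface_hub import file_exists, hf_hub_download

# validate_models(): the subfolder is prefixed onto the filename for the existence check.
ok = file_exists(
    repo_id="heathcliff01/Kaloscope",
    filename="224-85.65/best_checkpoint.pth",
    repo_type="model",
)
print("exists on Hub:", ok)

# app.py download path: the subfolder is passed as its own kwarg instead.
if ok:
    model_path = hf_hub_download(
        repo_id="heathcliff01/Kaloscope",
        filename="best_checkpoint.pth",
        subfolder="224-85.65",
    )
    print("cached at:", model_path)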
