import json
import os
import time

import gradio as gr
import numpy as np
import pandas as pd
import torch

from hf_repo_utils import auto_register_kaloscope_variants, validate_models

# Import inference implementations
from inference_onnx import ONNXInference, softmax
from inference_pytorch import PyTorchInference

# This import is crucial to register the lsnet models with timm
try:
    import lsnet.lsnet_artist  # noqa: F401
except ImportError as e:
    print(f"Error: {e}")
    raise gr.Error("Could not import lsnet.lsnet_artist. Please ensure the lsnet folder is in your workspace.")

# ------------------------------------------------------------
# CONFIG SECTION - Model Configuration
# ------------------------------------------------------------
# Define available models with their configuration.
# Format: "display_name": {
#     "type": "onnx" or "pytorch",
#     "path": local path to the model file; if left empty, the file is downloaded from "repo_id",
#     "repo_id": Hugging Face repo id (optional, but "path" can't be empty if "repo_id" is empty),
#     "subfolder": subfolder within the repo (optional, if applicable),
#     "filename": file name within the repo (required when "repo_id" is used),
#     "arch": model architecture name; "lsnet_xl_artist" is expected for the usual Kaloscope
#             releases (b, l and s variants exist but are unused),
# }
# An illustrative local-only entry is sketched right after this dict.
MODELS = {
    "Kaloscope v1.1 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-1.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.1": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "subfolder": "224-85.65",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-0.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ema": {
        "type": "pytorch",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "best_checkpoint_ema.pth",
        "arch": "lsnet_xl_artist",
    },
}
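# Illustrative only (not a released checkpoint): a purely local entry, which skips the Hub
# download path entirely, could look like the commented sketch below. The display name and
# file path are hypothetical placeholders; "filename" is kept because the loader uses it in
# its log messages.
#
# MODELS["My local Kaloscope"] = {
#     "type": "pytorch",
#     "path": "/path/to/best_checkpoint.pth",  # hypothetical local file
#     "filename": "best_checkpoint.pth",
#     "arch": "lsnet_xl_artist",
# }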
""" try: df = pd.read_csv(csv_path) if "class_id" not in df.columns or "class_name" not in df.columns: raise gr.Error("CSV file must have 'class_id' and 'class_name' columns.") df["class_name"] = df["class_name"].str.strip("'") return dict(zip(df["class_id"], df["class_name"])) except FileNotFoundError: raise gr.Error(f"CSV file not found at '{csv_path}'") except Exception as e: raise gr.Error(f"Error reading CSV file: {e}") # Load labels once at startup # Check if local file exists, otherwise download from HF if not os.path.exists(CSV_PATH): print(f"CSV not found locally at {CSV_PATH}, attempting to download from Hugging Face...") try: from huggingface_hub import hf_hub_download CSV_PATH = hf_hub_download(repo_id=CSV_REPO_ID, filename=CSV_FILENAME) print(f"Downloaded CSV to: {CSV_PATH}") except Exception as e: print(f"Failed to download CSV: {e}") raise gr.Error(f"Could not load class mapping CSV: {e}") labels = load_labels(CSV_PATH) # Initialize model cache model_cache = {} def get_model_inference(model_name): """ Get or create inference object for the specified model. Uses caching to avoid reloading models. """ if model_name not in model_cache: config = MODELS[model_name] model_path = config.get("path") or "" print(f"Loading model: {model_name} ({config['filename']})") if model_name not in MODELS: raise gr.Error(f"Unknown model: {model_name}") # Check if local file exists, otherwise try to download if not model_path or not os.path.exists(model_path): if "repo_id" in config and "filename" in config: target_display = model_path or f"repo {config['repo_id']}" print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...") try: from huggingface_hub import hf_hub_download download_kwargs = { "repo_id": config["repo_id"], "filename": config["filename"], } if config.get("subfolder"): download_kwargs["subfolder"] = config["subfolder"] if config.get("revision"): download_kwargs["revision"] = config["revision"] model_path = hf_hub_download(**download_kwargs) config["path"] = model_path print(f"Downloaded model to: {model_path}") except Exception as e: raise gr.Error(f"Could not load model from local path or Hugging Face: {e}") else: raise gr.Error(f"Model file not found at: {model_path}") # Create inference object based on type if config["type"] == "onnx": model_cache[model_name] = ONNXInference(model_path=model_path, model_arch=config["arch"], device=DEVICE) elif config["type"] == "pytorch": model_cache[model_name] = PyTorchInference( checkpoint_path=model_path, model_arch=config["arch"], device=DEVICE ) else: raise gr.Error(f"Unknown model type: {config['type']}") print(f"Model {model_name} loaded successfully ({config['filename']})") return model_cache[model_name] def predict(image, model_selection, top_k, threshold): """ Main prediction function that takes UI inputs. 
""" # check if there even is image and throw error if image is none and dont continue if image is None: raise gr.Error("No image provided for prediction.") # Ensure top_k is an integer for slicing top_k = int(top_k) # Get inference object for selected model inference = get_model_inference(model_selection) # Start timing start_time = time.time() # Run inference logits = inference.predict(image, top_k=top_k, threshold=threshold) # End timing inference_time = time.time() - start_time # Compute probabilities probabilities = softmax(logits) # Get all indices and their scores all_indices = np.argsort(probabilities)[::-1] tags = [] json_output = {} table_data = [] predictions_found = 0 for index in all_indices: score = probabilities[index] if score >= threshold and predictions_found < top_k: class_name = labels.get(index, f"Unknown Class #{index}") tags.append(class_name) json_output[class_name] = float(score) # Create Danbooru search URL with markdown link danbooru_url = f"https://danbooru.donmai.us/posts?tags={class_name.replace(' ', '_')}" artist_link = f"[{class_name}]({danbooru_url})" # Add copy button HTML with span instead of button copy_button = f"πŸ“‹" # Add row to table: [Rank, Artist (markdown link), Copy Button, Score] table_data.append([predictions_found + 1, artist_link, copy_button, f"{score:.2%}"]) predictions_found += 1 # Stop early if we have enough predictions if predictions_found >= top_k: break tags_output = ", ".join(tags) # Create DataFrame for display if table_data: df = pd.DataFrame(table_data, columns=["Rank", "Artist", "", "Score"]) else: df = pd.DataFrame(columns=["Rank", "Artist", "", "Score"]) # Get actual device/provider info from inference object if hasattr(inference, "execution_provider"): device_info = inference.execution_provider else: device_info = inference.device # Format time taken with 3-4 decimal places time_taken_str = f"- **{MODELS[model_selection]['type']}:** {device_info} | **Time taken:** {inference_time:.4f}s" return tags_output, df, json.dumps(json_output, indent=4), time_taken_str # --- Gradio Interface --- with gr.Blocks( css=""" * { box-sizing: border-box; } @media (max-width: 1022px) { #slider-row-container { flex-direction: column !important; } #slider-row-container > * { width: 100% !important; } #slider-row-container .block { width: 100% !important; } } #image-upload { max-height: 80vh !important; overflow: hidden !important; display: flex !important; flex-direction: column !important; } #image-upload .image-container { flex: 1 1 auto !important; min-height: 0 !important; overflow: hidden !important; display: flex !important; align-items: center !important; justify-content: center !important; } #image-upload img { max-height: 75vh !important; max-width: 100% !important; width: auto !important; height: auto !important; object-fit: contain !important; } #results-table-wrapper { overflow: hidden !important; width: 100% !important; } #results-table-wrapper .table-wrap, #results-table-wrapper .dataframe-wrap { overflow: hidden !important; width: 100% !important; } #results-table-wrapper table { width: 100% !important; table-layout: fixed !important; border-collapse: collapse !important; } #results-table-wrapper td, #results-table-wrapper th { overflow: hidden !important; text-overflow: ellipsis !important; white-space: nowrap !important; } #results-table-wrapper td:nth-child(1), #results-table-wrapper th:nth-child(1) { width: 55px !important; } #results-table-wrapper td:nth-child(2), #results-table-wrapper th:nth-child(2) { width: auto !important; 
# --- Gradio Interface ---
with gr.Blocks(
    css="""
    * {
        box-sizing: border-box;
    }
    @media (max-width: 1022px) {
        #slider-row-container {
            flex-direction: column !important;
        }
        #slider-row-container > * {
            width: 100% !important;
        }
        #slider-row-container .block {
            width: 100% !important;
        }
    }
    #image-upload {
        max-height: 80vh !important;
        overflow: hidden !important;
        display: flex !important;
        flex-direction: column !important;
    }
    #image-upload .image-container {
        flex: 1 1 auto !important;
        min-height: 0 !important;
        overflow: hidden !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    #image-upload img {
        max-height: 75vh !important;
        max-width: 100% !important;
        width: auto !important;
        height: auto !important;
        object-fit: contain !important;
    }
    #results-table-wrapper {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper .table-wrap,
    #results-table-wrapper .dataframe-wrap {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper table {
        width: 100% !important;
        table-layout: fixed !important;
        border-collapse: collapse !important;
    }
    #results-table-wrapper td,
    #results-table-wrapper th {
        overflow: hidden !important;
        text-overflow: ellipsis !important;
        white-space: nowrap !important;
    }
    #results-table-wrapper td:nth-child(1),
    #results-table-wrapper th:nth-child(1) {
        width: 55px !important;
    }
    #results-table-wrapper td:nth-child(2),
    #results-table-wrapper th:nth-child(2) {
        width: auto !important;
    }
    #results-table-wrapper td:nth-child(3),
    #results-table-wrapper th:nth-child(3) {
        width: 50px !important;
        border-left: none !important;
        text-align: center !important;
        padding: 0 !important;
    }
    #results-table-wrapper td:nth-child(4),
    #results-table-wrapper th:nth-child(4) {
        width: 69px !important;
    }
    #results-table-wrapper th:nth-child(3) {
        background: transparent !important;
    }
    #results-table-wrapper .copy-btn {
        cursor: pointer;
        width: 100%;
        height: 100%;
        display: flex;
        align-items: center;
        justify-content: center;
        user-select: none;
    }
    #results-table-wrapper .copy-btn:hover {
        background: rgba(128, 128, 128, 0.1);
    }
    """,
    js="""
    function() {
        // Copy-to-clipboard handler for the copy buttons in the results table
        document.addEventListener('click', function(e) {
            if (e.target.classList.contains('copy-btn')) {
                const text = e.target.getAttribute('data-copy');
                navigator.clipboard.writeText(text).then(() => {
                    const original = e.target.textContent;
                    e.target.textContent = '✓';
                    setTimeout(() => { e.target.textContent = original; }, 1000);
                });
            }
        });

        // Fix right-click on links in the DataFrame
        document.addEventListener('contextmenu', function(e) {
            // Check if the clicked element or its parent is a link inside the results table
            const link = e.target.closest('#results-table-wrapper a');
            if (link) {
                e.stopPropagation();
                // Let the browser's native context menu show for the link
                return true;
            }
        }, true);

        // Prevent Gradio's default click handling on links
        document.addEventListener('click', function(e) {
            const link = e.target.closest('#results-table-wrapper a');
            if (link && e.button === 0) {
                e.stopPropagation();
                // Allow normal link behavior (open in new tab due to markdown)
                return true;
            }
        }, true);
    }
    """,
) as demo:
    gr.Markdown("# Kaloscope Artist Style Classification")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload")
        with gr.Column():
            submit_btn = gr.Button("Predict")
            tags_output = gr.Textbox(label="Predicted Tags", show_copy_button=True)
            prettier_output = gr.DataFrame(
                elem_id="results-table-wrapper",
                # value=[
                #     [1, "[Sample Artist](https://example.com)", "📋", "95.00%"],
                #     [2, "[Another Artist](https://example.com)", "📋", "90.00%"],
                #     [3, "[Third Artist](https://example.com)", "📋", "85.00%"],
                # ],
                interactive=False,
                datatype=["number", "markdown", "html", "str"],
                headers=["Rank", "Artist", "", "Score"],
            )
            json_accordion = gr.Accordion("JSON Output", open=False)
            with json_accordion:
                json_output = gr.Code(language="json", show_label=False, lines=7)
            with gr.Group():
                model_selection = gr.Dropdown(
                    choices=[
                        (
                            f"{name}",
                            # f"{name} | Repo: {MODELS[name].get('repo_id') or 'local'}",
                            name,
                        )
                        for name in MODELS
                    ],
                    value=list(MODELS.keys())[0],
                    label="Select Model",
                )
                with gr.Row(elem_id="slider-row-container"):
                    top_k_slider = gr.Slider(minimum=1, maximum=25, value=5, step=1, label="Top K")
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Threshold")
            time_display = gr.Markdown()  # populated after prediction
            gr.Markdown(
                "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) (original PyTorch releases) "
                + "and [DraconicDragon/Kaloscope-onnx](https://huggingface.co/DraconicDragon/Kaloscope-onnx) (ONNX-converted and EMA weights). \n"
                + "OpenVINO™ is used to accelerate ONNX CPU inference, with the ONNX CPUExecutionProvider as fallback."
            )

    submit_btn.click(
        fn=predict,
        inputs=[image_input, model_selection, top_k_slider, threshold_slider],
        outputs=[tags_output, prettier_output, json_output, time_display],
    )

if __name__ == "__main__":
    demo.launch()
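# Local-run note (generic Gradio launch() options, not specific to this app), e.g.:
# demo.launch(share=True)        # temporary public link
# demo.launch(server_port=7860)  # pin the local port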