import json
import os
import time
import gradio as gr
import numpy as np
import pandas as pd
import torch
from hf_repo_utils import auto_register_kaloscope_variants, validate_models
# Import inference implementations
from inference_onnx import ONNXInference, softmax
from inference_pytorch import PyTorchInference
# This import is crucial to register the lsnet models with timm
try:
    import lsnet.lsnet_artist  # noqa: F401
except ImportError as e:
    print(f"Error: {e}")
    raise gr.Error("Could not import lsnet.lsnet_artist. Please ensure the lsnet folder is in your workspace.")
# ------------------------------------------------------------
# CONFIG SECTION - Model Configuration
# ------------------------------------------------------------
# Define available models with their configuration
# Format: "display_name": {
#     "type": "onnx" or "pytorch",
#     "path": "local path or repo:filename", # for local use - if left empty, will use repo id
#     "repo_id": "huggingface repo id" (optional, but path cant be empty if repo_id empty),
#     "subfolder": "subfolder in repo" (optional, if applicable),
#     "arch": "model architecture name" - lsnet_xl_artist is expected for usual Kaloscope releases b, l, s exist but unused
# }
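# Illustrative (hypothetical) entry for a checkpoint already stored on disk; the
# display name, path, and filename below are placeholders, not a real release:
# "Kaloscope local": {
#     "type": "pytorch",
#     "path": "checkpoints/my_kaloscope.pth",  # existing local file, no download
#     "filename": "my_kaloscope.pth",
#     "arch": "lsnet_xl_artist",
# },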
MODELS = {
    "Kaloscope v1.1 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-1.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.1": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "subfolder": "224-85.65",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ONNX": {
        "type": "onnx",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "kaloscope_1-0.onnx",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0": {
        "type": "pytorch",
        "path": "",
        "repo_id": "heathcliff01/Kaloscope",
        "filename": "best_checkpoint.pth",
        "arch": "lsnet_xl_artist",
    },
    "Kaloscope v1.0 ema": {
        "type": "pytorch",
        "path": "",
        "repo_id": "DraconicDragon/Kaloscope-onnx",
        "filename": "best_checkpoint_ema.pth",
        "arch": "lsnet_xl_artist",
    },
}
MODELS = validate_models(MODELS)
auto_register_kaloscope_variants(MODELS)
# Class mapping CSV configuration
CSV_PATH = ""  # if left empty, will use repo's
CSV_REPO_ID = "heathcliff01/Kaloscope"
CSV_FILENAME = "class_mapping.csv"
# Device configuration
try:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except Exception:
    DEVICE = "cpu"
# ------------------------------------------------------------
def load_labels(csv_path):
    """
    Loads the class labels from the provided CSV file into a dictionary.
    """
    try:
        df = pd.read_csv(csv_path)
        if "class_id" not in df.columns or "class_name" not in df.columns:
            raise gr.Error("CSV file must have 'class_id' and 'class_name' columns.")
        df["class_name"] = df["class_name"].str.strip("'")
        return dict(zip(df["class_id"], df["class_name"]))
    except FileNotFoundError:
        raise gr.Error(f"CSV file not found at '{csv_path}'")
    except Exception as e:
        raise gr.Error(f"Error reading CSV file: {e}")
# Load labels once at startup
# Check if local file exists, otherwise download from HF
if not CSV_PATH or not os.path.exists(CSV_PATH):
    print(f"CSV not found locally at '{CSV_PATH}', attempting to download from Hugging Face...")
    try:
        from huggingface_hub import hf_hub_download
        CSV_PATH = hf_hub_download(repo_id=CSV_REPO_ID, filename=CSV_FILENAME)
        print(f"Downloaded CSV to: {CSV_PATH}")
    except Exception as e:
        print(f"Failed to download CSV: {e}")
        raise gr.Error(f"Could not load class mapping CSV: {e}")
labels = load_labels(CSV_PATH)
# Initialize model cache
model_cache = {}
def get_model_inference(model_name):
    """
    Get or create inference object for the specified model.
    Uses caching to avoid reloading models.
    """
    if model_name not in model_cache:
        if model_name not in MODELS:
            raise gr.Error(f"Unknown model: {model_name}")
        config = MODELS[model_name]
        model_path = config.get("path") or ""
        print(f"Loading model: {model_name} ({config['filename']})")
        # Check if local file exists, otherwise try to download
        if not model_path or not os.path.exists(model_path):
            if "repo_id" in config and "filename" in config:
                target_display = model_path or f"repo {config['repo_id']}"
                print(f"Model not found locally at {target_display}, attempting to download from Hugging Face...")
                try:
                    from huggingface_hub import hf_hub_download
                    download_kwargs = {
                        "repo_id": config["repo_id"],
                        "filename": config["filename"],
                    }
                    if config.get("subfolder"):
                        download_kwargs["subfolder"] = config["subfolder"]
                    if config.get("revision"):
                        download_kwargs["revision"] = config["revision"]
                    model_path = hf_hub_download(**download_kwargs)
                    config["path"] = model_path
                    print(f"Downloaded model to: {model_path}")
                except Exception as e:
                    raise gr.Error(f"Could not load model from local path or Hugging Face: {e}")
            else:
                raise gr.Error(f"Model file not found at: {model_path}")
        # Create inference object based on type
        if config["type"] == "onnx":
            model_cache[model_name] = ONNXInference(model_path=model_path, model_arch=config["arch"], device=DEVICE)
        elif config["type"] == "pytorch":
            model_cache[model_name] = PyTorchInference(
                checkpoint_path=model_path, model_arch=config["arch"], device=DEVICE
            )
        else:
            raise gr.Error(f"Unknown model type: {config['type']}")
        print(f"Model {model_name} loaded successfully ({config['filename']})")
    return model_cache[model_name]
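# Example (hypothetical) warm-up call to pre-load the default model at startup
# instead of on the first request:
# get_model_inference(list(MODELS.keys())[0])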
def predict(image, model_selection, top_k, threshold):
    """
    Main prediction function that takes UI inputs.
    """
    # Abort early if no image was provided
    if image is None:
        raise gr.Error("No image provided for prediction.")
    # Ensure top_k is an integer for slicing
    top_k = int(top_k)
    # Get inference object for selected model
    inference = get_model_inference(model_selection)
    # Start timing
    start_time = time.time()
    # Run inference
    logits = inference.predict(image, top_k=top_k, threshold=threshold)
    # End timing
    inference_time = time.time() - start_time
    # Compute probabilities
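    # For reference: softmax(x)_i = exp(x_i) / sum_j exp(x_j), so the resulting
    # scores lie in [0, 1] and sum to 1, which is what the threshold filters on.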
    probabilities = softmax(logits)
    # Sort class indices by descending probability
    all_indices = np.argsort(probabilities)[::-1]
    tags = []
    json_output = {}
    table_data = []
    predictions_found = 0
    for index in all_indices:
        score = probabilities[index]
        # Indices are sorted by descending score, so stop once the score drops
        # below the threshold or enough predictions have been collected
        if score < threshold or predictions_found >= top_k:
            break
        class_name = labels.get(index, f"Unknown Class #{index}")
        tags.append(class_name)
        json_output[class_name] = float(score)
        # Create a Danbooru search URL and wrap the artist name in a markdown link
        danbooru_url = f"https://danbooru.donmai.us/posts?tags={class_name.replace(' ', '_')}"
        artist_link = f"[{class_name}]({danbooru_url})"
        # Copy control rendered as a clickable span; the .copy-btn class and
        # data-copy attribute are handled by the CSS/JS blocks defined below
        copy_button = f'<span class="copy-btn" data-copy="{class_name}">📋</span>'
        # Add row to table: [Rank, Artist (markdown link), Copy Button, Score]
        table_data.append([predictions_found + 1, artist_link, copy_button, f"{score:.2%}"])
        predictions_found += 1
    tags_output = ", ".join(tags)
    # Create DataFrame for display
    if table_data:
        df = pd.DataFrame(table_data, columns=["Rank", "Artist", "", "Score"])
    else:
        df = pd.DataFrame(columns=["Rank", "Artist", "", "Score"])
    # Get actual device/provider info from inference object
    if hasattr(inference, "execution_provider"):
        device_info = inference.execution_provider
    else:
        device_info = inference.device
    # Report backend type, device/provider, and inference time (4 decimal places)
    time_taken_str = f"- **{MODELS[model_selection]['type']}:** {device_info} | **Time taken:** {inference_time:.4f}s"
    return tags_output, df, json.dumps(json_output, indent=4), time_taken_str
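# Illustrative (hypothetical) programmatic call that bypasses the UI:
# from PIL import Image
# tags, table, json_str, timing = predict(Image.open("example.png"), list(MODELS.keys())[0], 5, 0.0)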
# --- Gradio Interface ---
with gr.Blocks(
    css="""
    * { box-sizing: border-box; }
    @media (max-width: 1022px) {
        #slider-row-container {
            flex-direction: column !important;
        }
        #slider-row-container > * {
            width: 100% !important;
        }
        #slider-row-container .block {
            width: 100% !important;
        }
    }
    #image-upload {
        max-height: 80vh !important;
        overflow: hidden !important;
        display: flex !important;
        flex-direction: column !important;
    }
    #image-upload .image-container {
        flex: 1 1 auto !important;
        min-height: 0 !important;
        overflow: hidden !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    #image-upload img {
        max-height: 75vh !important;
        max-width: 100% !important;
        width: auto !important;
        height: auto !important;
        object-fit: contain !important;
    }
    #results-table-wrapper {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper .table-wrap,
    #results-table-wrapper .dataframe-wrap {
        overflow: hidden !important;
        width: 100% !important;
    }
    #results-table-wrapper table {
        width: 100% !important;
        table-layout: fixed !important;
        border-collapse: collapse !important;
    }
    #results-table-wrapper td,
    #results-table-wrapper th {
        overflow: hidden !important;
        text-overflow: ellipsis !important;
        white-space: nowrap !important;
    }
    #results-table-wrapper td:nth-child(1),
    #results-table-wrapper th:nth-child(1) {
        width: 55px !important;
    }
    #results-table-wrapper td:nth-child(2),
    #results-table-wrapper th:nth-child(2) {
        width: auto !important;
    }
    #results-table-wrapper td:nth-child(3),
    #results-table-wrapper th:nth-child(3) {
        width: 50px !important;
        border-left: none !important;
        text-align: center !important;
        padding: 0 !important;
    }
    #results-table-wrapper td:nth-child(4),
    #results-table-wrapper th:nth-child(4) {
        width: 69px !important;
    }
    #results-table-wrapper th:nth-child(3) {
        background: transparent !important;
    }
    #results-table-wrapper .copy-btn {
        cursor: pointer;
        width: 100%;
        height: 100%;
        display: flex;
        align-items: center;
        justify-content: center;
        user-select: none;
    }
    #results-table-wrapper .copy-btn:hover {
        background: rgba(128, 128, 128, 0.1);
    }
""",
    js="""
    function() {
        document.addEventListener('click', function(e) {
            if (e.target.classList.contains('copy-btn')) {
                const text = e.target.getAttribute('data-copy');
                navigator.clipboard.writeText(text).then(() => {
                    const original = e.target.textContent;
                    e.target.textContent = 'β';
                    setTimeout(() => { e.target.textContent = original; }, 1000);
                });
            }
        });
        
        // Fix right-click on links in DataFrame
        document.addEventListener('contextmenu', function(e) {
            // Check if the clicked element or its parent is a link inside the results table
            const link = e.target.closest('#results-table-wrapper a');
            if (link) {
                e.stopPropagation();
                // Let the browser's native context menu show for the link
                return true;
            }
        }, true);
        
        // Prevent Gradio's default click handling on links
        document.addEventListener('click', function(e) {
            const link = e.target.closest('#results-table-wrapper a');
            if (link && e.button === 0) {
                e.stopPropagation();
                // Allow normal link behavior (open in new tab due to markdown)
                return true;
            }
        }, true);
    }
    """,
) as demo:
    gr.Markdown("# Kaloscope Artist Style Classification")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload")
        with gr.Column():
            submit_btn = gr.Button("Predict")
            tags_output = gr.Textbox(label="Predicted Tags", show_copy_button=True)
            prettier_output = gr.DataFrame(
                elem_id="results-table-wrapper",
                interactive=False,
                datatype=["number", "markdown", "html", "str"],
                headers=["Rank", "Artist", "", "Score"],
            )
            json_accordion = gr.Accordion("JSON Output", open=False)
            with json_accordion:
                json_output = gr.Code(language="json", show_label=False, lines=7)
            with gr.Group():
                model_selection = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Select Model",
                )
                with gr.Row(elem_id="slider-row-container"):
                    top_k_slider = gr.Slider(minimum=1, maximum=25, value=5, step=1, label="Top K")
                    threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Threshold")
                time_display = gr.Markdown()  # populated after prediction
    gr.Markdown(
        "Models sourced from [heathcliff01/Kaloscope](https://huggingface.co/heathcliff01/Kaloscope) (Original PyTorch releases) "
        + "and [DraconicDragon/Kaloscope-onnx](https://huggingface.co/DraconicDragon/Kaloscope-onnx) (ONNX converted and EMA weights).  \n"
        + "OpenVINOβ’ will be used to accelerate ONNX CPU inference with ONNX CPUExecutionProvider as fallback."
    )
    submit_btn.click(
        fn=predict,
        inputs=[image_input, model_selection, top_k_slider, threshold_slider],
        outputs=[tags_output, prettier_output, json_output, time_display],
    )
if __name__ == "__main__":
    demo.launch()
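    # Hypothetical launch variants (not part of the original script):
    # demo.launch(server_name="0.0.0.0", server_port=7860)  # expose on the local network
    # demo.launch(share=True)  # create a temporary public Gradio link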