{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q4KpnNL4lY6q"
   },
   "source": [
    "### Getting Ready\n",
    "\n",
    "Install `diffusers` (the LoRA training script lives in its `examples/text_to_image` folder), mount Google Drive, and export the dataset in the \"image folder\" layout the training script reads."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Keep a local clone for the example training scripts; the guard makes this cell safe to re-run.\n",
    "![ -d /content/diffusers ] || git clone https://github.com/huggingface/diffusers.git /content/diffusers\n",
    "%pip install -q git+https://github.com/huggingface/diffusers.git\n",
    "# train_text_to_image_lora.py imports peft to build its LoRA adapters.\n",
    "%pip install -q peft"
   ],
   "metadata": {
    "id": "yOvCmByVINi7",
    "collapsed": true
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n"
   ],
   "metadata": {
    "id": "I4vsjgK2AbgI"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Training configuration: everything below (data export, training command,\n",
    "# training cell, LoRA loading) reads these values.\n",
    "\n",
    "import os\n",
    "import json\n",
    "import shlex\n",
    "from datasets import load_dataset\n",
    "from accelerate.utils import write_basic_config\n",
    "from huggingface_hub import create_repo, upload_folder\n",
    "\n",
    "# --- 2. Configuration ---\n",
    "# This is where you set all the important parameters for the training job.\n",
    "\n",
    "# Model and Dataset Parameters\n",
    "# runwayml/stable-diffusion-v1-5 was removed from the Hugging Face Hub;\n",
    "# this repo hosts the same Stable Diffusion 1.5 weights.\n",
    "base_model_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "dataset_name = \"iresidentevil/pepe_the_frog\"  # The original dataset\n",
    "text_column = \"prompt\"\n",
    "image_column = \"image\"\n",
    "trigger_word = \"pepe_style_frog\"  # The trigger word we decided on\n",
    "\n",
    "# Training Parameters\n",
    "# NOTE: the folder name mentions \"sdxl-turbo\" but the base model is SD 1.5;\n",
    "# the path is kept unchanged so previously saved LoRA weights are still found.\n",
    "output_dir = \"/content/drive/MyDrive/pepe-lora-sdxl-turbo_2\"  # Where the trained LoRA will be saved\n",
    "resolution = 512  # SD 1.5's native resolution. Higher resolutions need more VRAM.\n",
    "learning_rate = 1e-4\n",
    "train_batch_size = 1  # Keep this at 1 for a small dataset to see each image.\n",
    "gradient_accumulation_steps = 4\n",
    "max_train_steps = 500  # A good starting point for a small dataset. Adjust as needed.\n",
    "checkpointing_steps = 100  # Save a checkpoint every 100 steps.\n",
    "\n",
    "# LoRA Specific Parameters\n",
    "lora_rank = 16  # Rank (dimension) of the LoRA. 16 is a good balance.\n",
    "\n",
    "# Hugging Face Hub Parameters\n",
    "hf_hub_repo_id = \"your-username/pepe-lora-sdxl-turbo\"  # Change to your Hub username and desired repo name\n",
    "# Off by default: pushing fails while hf_hub_repo_id is still the placeholder above.\n",
    "push_to_hub = False  # Set to True to automatically upload your LoRA to the Hub"
   ],
   "metadata": {
    "id": "RPv7Gv5h--SO"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- 3. Prepare Dataset in \"Image Folder\" format ---\n",
    "# Writes every image plus a metadata.jsonl file (one JSON object per line with\n",
    "# \"file_name\" and the caption column), the layout --train_data_dir expects.\n",
    "\n",
    "print(\"Loading original dataset...\")\n",
    "dataset = load_dataset(dataset_name, split=\"train\")\n",
    "\n",
    "image_folder_path = \"/content/drive/MyDrive/pepe-data\"\n",
    "os.makedirs(image_folder_path, exist_ok=True)\n",
    "print(f\"Created directory for prepared data: {image_folder_path}\")\n",
    "\n",
    "metadata_file_path = os.path.join(image_folder_path, \"metadata.jsonl\")\n",
    "\n",
    "with open(metadata_file_path, \"w\") as f:\n",
    "    for i, example in enumerate(dataset):\n",
    "        # Get image and caption\n",
    "        image = example[image_column]\n",
    "        caption = example[text_column]\n",
    "\n",
    "        # Prefix the trigger word so the LoRA learns to associate it with this style.\n",
    "        full_caption = f\"{trigger_word} {caption}\"\n",
    "\n",
    "        # Save the image\n",
    "        image_filename = f\"image_{i}.png\"\n",
    "        image.save(os.path.join(image_folder_path, image_filename))\n",
    "\n",
    "        # Write the metadata entry\n",
    "        metadata_entry = {\n",
    "            \"file_name\": image_filename,\n",
    "            text_column: full_caption\n",
    "        }\n",
    "        f.write(json.dumps(metadata_entry) + \"\\n\")\n",
    "\n",
    "print(f\"Dataset prepared and saved in 'image folder' format at: {image_folder_path}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- 4. Set up the Training Command ---\n",
    "# This command points to the image folder prepared above.\n",
    "write_basic_config()\n",
    "\n",
    "command = [\n",
    "    \"accelerate\", \"launch\",\n",
    "    \"train_text_to_image_lora.py\",\n",
    "    f\"--pretrained_model_name_or_path={base_model_id}\",\n",
    "    f\"--train_data_dir={image_folder_path}\",\n",
    "    f\"--caption_column={text_column}\",\n",
    "    f\"--image_column={image_column}\",\n",
    "    \"--dataloader_num_workers=8\",\n",
    "    f\"--resolution={resolution}\", \"--center_crop\", \"--random_flip\",\n",
    "    f\"--train_batch_size={train_batch_size}\",\n",
    "    f\"--gradient_accumulation_steps={gradient_accumulation_steps}\",\n",
    "    f\"--max_train_steps={max_train_steps}\",\n",
    "    f\"--learning_rate={learning_rate}\",\n",
    "    \"--lr_scheduler=constant\",\n",
    "    \"--lr_warmup_steps=0\",\n",
    "    f\"--output_dir={output_dir}\",\n",
    "    f\"--rank={lora_rank}\",\n",
    "    f\"--validation_prompt={trigger_word} a sad frog in a blue hoodie, cartoon style\",\n",
    "    f\"--checkpointing_steps={checkpointing_steps}\",\n",
    "    \"--checkpoints_total_limit=3\",\n",
    "]\n",
    "\n",
    "if push_to_hub:\n",
    "    command.extend([\"--push_to_hub\", f\"--hub_model_id={hf_hub_repo_id}\"])\n",
    "\n",
    "# shlex.join shell-quotes arguments containing spaces (e.g. the validation\n",
    "# prompt), so the printed command can be pasted into a terminal as-is.\n",
    "training_command_str = shlex.join(command)\n",
    "\n",
    "\n",
    "# --- 5. Execute the Training ---\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\" TRAINING COMMAND\")\n",
    "print(\"=\"*80)\n",
    "print(\"The following command will be executed in your terminal:\")\n",
    "print(training_command_str)\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"To start training, copy the command above and paste it into your terminal.\")\n",
    "print(\"Make sure you are in the correct environment where the diffusers examples are located.\")\n",
    "print(\"You may need to clone the diffusers repo first: git clone https://github.com/huggingface/diffusers.git\")\n",
    "print(\"CORRECTED PATH: Then navigate to: cd diffusers/examples/text_to_image\")\n",
    "print(\"=\"*80)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yGDgzchblY6s"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import datasets\n",
    "import diffusers\n",
    "import huggingface_hub\n",
    "import requests\n",
    "import torch\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import HfApi\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6hoZLPDalY6t"
   },
   "source": [
    "We'll print out the version numbers of the critical packages to help with future reproducibility."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CaRvn_celY6t"
   },
   "outputs": [],
   "source": [
    "print(\"Platform:\", sys.platform)\n",
    "print(\"Python version:\", sys.version)\n",
    "print(\"---\")\n",
    "print(\"datasets version: \", datasets.__version__)\n",
    "print(\"diffusers version: \", diffusers.__version__)\n",
    "print(\"huggingface_hub version: \", huggingface_hub.__version__)\n",
    "print(\"torch version:\", torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VLBQ_2A0lY6u"
   },
   "source": [
    "Let's check if a GPU is available. If not, this notebook will take a long time to run!"
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jWTKdjUDlY6u"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "    dtype = torch.float16\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "    dtype = torch.float32\n",
    "\n",
    "print(f\"Using {device} device with {dtype} data type.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RCI8s5uylY6u"
   },
   "source": [
    "### Load Stable Diffusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2RU4U5mulY6w"
   },
   "outputs": [],
   "source": [
    "# Stable Diffusion 1.5 (Hub mirror; the runwayml/ repo was removed from the Hub).\n",
    "MODEL_NAME = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "\n",
    "pipeline = diffusers.AutoPipelineForText2Image.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype=dtype\n",
    ")\n",
    "pipeline.to(device)\n",
    "\n",
    "print(type(pipeline))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BMvqxn99lY6w"
   },
   "source": [
    "### Test the base model\n",
    "\n",
    "Generate an image with the base model (no LoRA loaded yet) as a reference point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-kBJqj9xlY6w"
   },
   "outputs": [],
   "source": [
    "# Fixed seed so re-running the cell reproduces the same image.\n",
    "generator = torch.Generator(device=device).manual_seed(42)\n",
    "images = pipeline([\"pepe the frog rolling eyes\"], generator=generator).images\n",
    "\n",
    "for im in images:\n",
    "    display(im)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HqZRLoajlY6x"
   },
   "outputs": [],
   "source": [
    "DATASET_NAME = \"iresidentevil/pepe_the_frog\"\n",
    "data_builder = datasets.load_dataset_builder(DATASET_NAME)\n",
    "\n",
    "print(data_builder.dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4EeHRlBmlY6x"
   },
   "outputs": [],
   "source": [
    "print(data_builder.info.features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rgXvHJJVlY6y"
   },
   "outputs": [],
   "source": [
    "print(data_builder.info.splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-L2YvGMnlY6y"
   },
   "outputs": [],
   "source": [
    "data = datasets.load_dataset(DATASET_NAME)\n",
    "\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count":
null,
   "metadata": {
    "id": "k2iL94ILlY6z"
   },
   "outputs": [],
   "source": [
    "# Preview a few images; select() avoids materialising the whole image column.\n",
    "display(*data[\"train\"].select(range(3))[\"image\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6vBJgSPnlY6z"
   },
   "outputs": [],
   "source": [
    "# Index the row first so only this one image is decoded. The value is a\n",
    "# PIL image, so it will be displayed automatically by Jupyter.\n",
    "data[\"train\"][3][\"image\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kbj0aOW9lY6z"
   },
   "outputs": [],
   "source": [
    "# Slice the first rows, then use dictionary indexing to look up the text values.\n",
    "data[\"train\"][:10][\"prompt\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q0RrkjXVlY60"
   },
   "source": [
    "### LoRA Fine-tuning\n",
    "\n",
    "Run the diffusers LoRA training script on the image folder exported in the setup section. Paths, model and columns come from the configuration cell. The hyperparameters here (2000 steps, cosine schedule, a checkpoint every 150 steps) intentionally differ from the command printed there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "36Jc_ijlwD75"
   },
   "outputs": [],
   "source": [
    "# Absolute path so the cell still works when re-run (a relative %cd fails the second time).\n",
    "%cd /content/diffusers/examples/text_to_image\n",
    "\n",
    "# IPython expands {name} with the Python values defined in the configuration cell.\n",
    "!accelerate launch train_text_to_image_lora.py \\\n",
    "  --pretrained_model_name_or_path=\"{base_model_id}\" \\\n",
    "  --train_data_dir=\"{image_folder_path}\" \\\n",
    "  --caption_column=\"{text_column}\" \\\n",
    "  --image_column=\"{image_column}\" \\\n",
    "  --resolution={resolution} --center_crop --random_flip \\\n",
    "  --train_batch_size=1 \\\n",
    "  --gradient_accumulation_steps=4 \\\n",
    "  --max_train_steps=2000 \\\n",
    "  --learning_rate=1e-4 \\\n",
    "  --lr_scheduler=\"cosine\" \\\n",
    "  --lr_warmup_steps=0 \\\n",
    "  --output_dir=\"{output_dir}\" \\\n",
    "  --rank={lora_rank} \\\n",
    "  --validation_prompt=\"{trigger_word}, a high-quality, detailed image of pepe the frog smiling and holding a cup of coffee at sunrise\" \\\n",
    "  --seed=42 \\\n",
    "  --mixed_precision=\"fp16\" \\\n",
    "  --checkpointing_steps=150"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VKOcWmJ9lY62"
   },
   "source": [
    "### Load LoRA Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SBGjOCmTlY63"
   },
   "outputs": [],
   "source": [
    "# Load the LoRA weights the training run wrote to output_dir.\n",
    "pipeline.load_lora_weights(\n",
    "    output_dir,\n",
    "    weight_name=\"pytorch_lora_weights.safetensors\",\n",
    ")\n",
    "# Disable the safety checker for the generations below.\n",
    "pipeline.safety_checker = None"
   ]
  },
  {
"cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RYRckHGLlY63"
   },
   "outputs": [],
   "source": [
    "# The prompt starts with the trigger word the LoRA was trained on; the fixed\n",
    "# seed makes the result reproducible across re-runs.\n",
    "generator = torch.Generator(device=device).manual_seed(42)\n",
    "images = pipeline(\n",
    "    [f\"{trigger_word} making fun of a rabbit that is racing a tortoise\"],\n",
    "    generator=generator,\n",
    ").images\n",
    "\n",
    "for im in images:\n",
    "    display(im)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}