pangkaicheng committed
Commit f8a73ec · 0 Parent(s)
first commit
Browse files
- .gitignore +15 -0
- .python-version +1 -0
- README.md +13 -0
- chainlit.md +14 -0
- chainlit_app.py +159 -0
- mcp_client.py +110 -0
- mcp_server_config.json +16 -0
- mcp_servers/fashion_vlm/__init__.py +0 -0
- mcp_servers/fashion_vlm/fashion_vlm_infer.py +157 -0
- mcp_servers/fashion_vlm/main.py +592 -0
- mcp_servers/fashion_vlm/models/__init__.py +4 -0
- mcp_servers/fashion_vlm/models/clip_encoder.py +140 -0
- mcp_servers/fashion_vlm/models/common_modules.py +357 -0
- mcp_servers/fashion_vlm/models/misc.py +53 -0
- mcp_servers/fashion_vlm/models/modeling_magvitv2.py +440 -0
- mcp_servers/fashion_vlm/models/modeling_showo.py +237 -0
- mcp_servers/fashion_vlm/models/modeling_utils.py +1207 -0
- mcp_servers/fashion_vlm/models/phi.py +1489 -0
- mcp_servers/fashion_vlm/models/sampling.py +118 -0
- mcp_servers/fashion_vlm/prompting_utils.py +628 -0
- mcp_servers/product_user_database.py +463 -0
- mcp_servers/virtual_try_on.py +94 -0
- requirements.txt +228 -0
- system_message.py +9 -0
- utils.py +74 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv/
.env

generated_images/
.chainlit/
.idea/
.python-version
ADDED
@@ -0,0 +1 @@
3.11
README.md
ADDED
@@ -0,0 +1,13 @@
The FashionRec dataset download is required.

A .env file is required, including:

CHAINLIT_PORT=8888

PROXY=http://127.0.0.1:10809

OPENAI_API_KEY=
GEMINI_API_KEY=

FASHION_DATA_ROOT=path_to_FashionRec
GEN_IMG_DIR="./generated_images"
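Before launching, the .env above can be sanity-checked with python-dotenv. The sketch below is illustrative only: the script name, the required/optional split, and the meta-folder check are assumptions, not part of this commit.

# check_env.py -- illustrative sanity check for the .env described in README.md
import os
from dotenv import load_dotenv

REQUIRED = ["OPENAI_API_KEY", "FASHION_DATA_ROOT", "GEN_IMG_DIR"]   # assumed mandatory
OPTIONAL = ["CHAINLIT_PORT", "PROXY", "GEMINI_API_KEY"]             # assumed optional

load_dotenv()  # reads .env from the project root

missing = [name for name in REQUIRED if not os.getenv(name)]
if missing:
    raise SystemExit(f"Missing required .env entries: {', '.join(missing)}")

# FASHION_DATA_ROOT should point at the downloaded FashionRec dataset (it must contain meta/)
data_root = os.getenv("FASHION_DATA_ROOT")
if not os.path.isdir(os.path.join(data_root, "meta")):
    raise SystemExit(f"FASHION_DATA_ROOT does not look like a FashionRec download: {data_root}")

print("Environment looks OK")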
chainlit.md
ADDED
@@ -0,0 +1,14 @@
# Welcome to Chainlit! 🚀🤖

Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.

## Useful Links 🔗

- **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
- **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬

We can't wait to see what you create with Chainlit! Happy coding! 💻😊

## Welcome screen

To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
chainlit_app.py
ADDED
@@ -0,0 +1,159 @@
import json
import os
import pandas as pd
import uuid

from openai import AsyncOpenAI
from dotenv import load_dotenv
import chainlit as cl

from system_message import SYSTEM_MESSAGE
from mcp_client import MCPClient
from utils import create_image_grid
import httpx


# Load environment variables from .env
load_dotenv()

# Initialize OpenAI client
CHAINLIT_PORT = os.getenv("CHAINLIT_PORT", "8888")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT")
PROXY = os.getenv("PROXY")
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet")
item_id_set = set(items_df.item_id)

# AsyncOpenAI needs an async HTTP client in both branches
http_client = httpx.AsyncClient(proxy=PROXY) if PROXY else httpx.AsyncClient()


class FashionAgent:
    def __init__(self, user_id=None):
        self.mcp_client = MCPClient("mcp_server_config.json", user_id)
        self.openai = AsyncOpenAI(api_key=OPENAI_API_KEY, http_client=http_client)
        self.user_id = user_id


# Global FashionAgent instance
agent = FashionAgent(user_id=None)


@cl.on_chat_start
async def on_chat_start():
    await agent.mcp_client.connect_to_servers()
    cl.user_session.set("agent", agent)
    await cl.Message(content="Hello Sophia! Welcome to FashionM3. How can I assist you today?").send()


@cl.on_message
async def on_message(message: cl.Message):
    agent = cl.user_session.get("agent")
    user_id = cl.user_session.get("user_id")
    chat_history = cl.user_session.get("chat_history", [])

    user_message = message.content

    upload_image = [x.path for x in message.elements if isinstance(x, cl.Image)]
    if len(upload_image) == 1:
        user_message += f"\nThe uploaded image path is: {os.path.abspath(upload_image[0])}"
    elif len(upload_image) > 1:
        merged_image_path = f".files/{uuid.uuid4().hex}.jpg"
        create_image_grid(upload_image[:4], merged_image_path)

        user_message += f"\nThe uploaded image path is: {os.path.abspath(merged_image_path)}"

    image_in_database = []
    for image in message.elements:
        if isinstance(image, cl.Image):
            item_id = image.name.split(".")[0]
            if item_id in item_id_set:
                image_in_database.append(item_id)
    if len(image_in_database) > 0:
        user_message += f"\nUser id is: {user_id}"
        user_message += f"\nlist_of_items are: {image_in_database}"
    elif user_id:
        user_message += f"\nUser id is: {user_id}"

    # Prepare messages for OpenAI API (roles are recovered from the stored author)
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        *[{"role": "user" if msg.author == "user" else "assistant", "content": msg.content} for msg in chat_history],
        {"role": "user", "content": user_message}
    ]

    # Fetch available tools
    available_tools = await agent.mcp_client.get_tools()

    # Initial OpenAI API call
    response = await agent.openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        max_tokens=1000,
        tools=available_tools if available_tools else None,
        tool_choice="auto" if available_tools else None
    )

    # Process the response
    response_message = response.choices[0].message
    if response_message.tool_calls:
        # Handle tool calls
        for tool_call in response_message.tool_calls:
            tool_name = tool_call.function.name
            params = json.loads(tool_call.function.arguments)
            try:
                print(f"Agent executing {tool_name} with params: {params}")
                result = await agent.mcp_client.execute_tool(tool_name, params)
                if tool_name == "retrieve_image":
                    image_path = json.loads(result['result'][0].text)['image_path']
                    similarity = json.loads(result['result'][0].text)['similarity']
                    output = f"I found a matching fashion item with a similarity score of {similarity:.2f}"

                    images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                elif tool_name == "image_generate":
                    image_path = result['result'][0].text
                    images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
                    output = "Here is the generated image."
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                elif tool_name == "fashion_recommend_without_image":
                    output = result['result'][0].text
                    await cl.Message(content=output, author="Fashion Agent").send()
                elif tool_name == "fashion_recommend":
                    output = json.loads(result['result'][0].text)['recommendation']
                    # user_preference = json.loads(result['result'][0].text)['user_preference']
                    # await cl.Message(content=user_preference, author="Fashion Agent").send()
                    await cl.Message(content=output, author="Fashion Agent").send()
                elif tool_name == "try_on":
                    image_path = result['result'][0].text
                    images = [cl.Image(path=image_path, name="Try-on image", display="inline", size="large")]
                    output = "Here is the virtual try-on image."
                    await cl.Message(content=output, elements=images, author="Fashion Agent").send()
                else:
                    output = result
            except Exception as e:
                output = f"Error executing tool {tool_name}: {str(e)}"

            # Update chat history
            chat_history.append(cl.Message(content=message.content, author="user"))
            chat_history.append(cl.Message(content=output, author="assistant"))
            cl.user_session.set("chat_history", chat_history)
    else:
        # Direct response from the model
        output = response_message.content
        chat_history.append(cl.Message(content=message.content, author="user"))
        chat_history.append(cl.Message(content=output, author="assistant"))
        cl.user_session.set("chat_history", chat_history)
        await cl.Message(content=output, author="Fashion Agent").send()


@cl.on_chat_end
def on_chat_end():
    print("Goodbye", cl.user_session.get("id"))


if __name__ == "__main__":
    from chainlit.cli import run_chainlit

    os.environ["CHAINLIT_PORT"] = CHAINLIT_PORT
    run_chainlit(__file__)
mcp_client.py
ADDED
@@ -0,0 +1,110 @@
from typing import Dict, Any, List, Optional
from contextlib import AsyncExitStack
import json
import aiohttp
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client


class MCPClient:
    def __init__(self, config_path: str, user_id: Optional[str] = None):
        """
        Initialize MCPClient from a JSON config file.
        Each entry under 'mcpServers' either specifies a stdio command ('command', 'args', 'env')
        or an SSE endpoint ('transport': 'sse', 'url').
        """
        self.user_id = user_id
        with open(config_path, 'r') as f:
            self.server_configs = json.load(f)['mcpServers']
        self.sessions: Dict[str, Any] = {}  # one ClientSession per connected server (stdio or SSE)
        self.exit_stack = AsyncExitStack()

    async def connect_to_servers(self):
        """Connect to all configured MCP servers based on their transport type."""
        for server_name, config in self.server_configs.items():
            transport = config.get("transport", "stdio")  # default to stdio
            print(f"Connecting to {server_name} ({transport})...")
            if transport == "stdio":
                command = config.get("command")
                args = config.get("args", [])
                env = config.get("env", None)
                if not command:
                    raise ValueError(f"No command specified for server {server_name}")

                server_params = StdioServerParameters(command=command, args=args, env=env)
                stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
                stdio, write = stdio_transport
                session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
                await session.initialize()
                self.sessions[server_name] = session
                # self.stdio_transports[server_name] = (stdio, write)
            elif transport == "sse":
                server_url = config.get("url", "")
                if not server_url:
                    raise ValueError(f"No base_url specified for server {server_name}")

                # Establish the SSE connection
                streams_context = sse_client(url=f"{server_url}/sse")
                streams = await self.exit_stack.enter_async_context(streams_context)
                session_context = ClientSession(*streams)
                session = await self.exit_stack.enter_async_context(session_context)

                # Initialize the session
                await session.initialize()
                self.sessions[server_name] = session
                # self.sse_contexts[server_name] = (streams_context, session_context)

                # Verify the connection
                print(f"Initialized SSE client for {server_name}...")
                print("Listing tools...")
                response = await session.list_tools()
                tools = response.tools
                print(f"Connected to {server_name} with tools:", [tool.name for tool in tools])
            else:
                raise ValueError(f"Unsupported transport type '{transport}' for {server_name}")

    async def get_tools(self) -> List[Dict[str, Any]]:
        """
        Fetch the list of available tools from all connected MCP servers.
        Returns a list of tool definitions with name, description, and inputSchema.
        """
        all_tools = []
        for server_name, session in self.sessions.items():
            response = await session.list_tools()
            for tool in response.tools:
                if not self.user_id and tool.name == 'personalized_fashion_recommend':
                    continue
                all_tools.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.inputSchema
                        }
                    }
                )
        return all_tools

    async def execute_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute a tool with the given parameters on the appropriate server.
        """
        # Find which server has this tool
        for server_name, session in self.sessions.items():
            response = await session.list_tools()
            for tool in response.tools:
                if tool.name == tool_name:
                    # Execute the tool on the correct server
                    result = await session.call_tool(tool_name, params)
                    return {
                        "result": result.content,
                        "server": server_name
                    }
        raise Exception(f"Tool {tool_name} not found on any connected server")

    async def close(self):
        """Close all server connections."""
        await self.exit_stack.aclose()
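For reference, a minimal sketch of driving MCPClient from a standalone asyncio script (illustrative only; it assumes the mcp_server_config.json below and that the configured servers are running, and the example call uses the fashion_vlm server's retrieve_image tool):

# illustrative usage sketch for MCPClient
import asyncio

from mcp_client import MCPClient


async def main():
    client = MCPClient("mcp_server_config.json", user_id=None)
    try:
        await client.connect_to_servers()

        # OpenAI-style tool definitions aggregated from every connected server
        tools = await client.get_tools()
        print("available tools:", [t["function"]["name"] for t in tools])

        # Call a tool by name; the client finds which server exposes it
        result = await client.execute_tool("retrieve_image", {"text": "a red summer dress"})
        print("served by:", result["server"])
        print(result["result"][0].text)
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())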
mcp_server_config.json
ADDED
@@ -0,0 +1,16 @@
{
  "mcpServers": {
    "fashion_vlm": {
      "transport": "sse",
      "url": "http://localhost:8000"
    },
    "virtual_try_on": {
      "command": "python",
      "args": ["mcp_servers/virtual_try_on.py"],
      "env": {
        "HTTP_PROXY": "http://127.0.0.1:10809",
        "HTTPS_PROXY": "http://127.0.0.1:10809"
      }
    }
  }
}
mcp_servers/fashion_vlm/__init__.py
ADDED
File without changes
mcp_servers/fashion_vlm/fashion_vlm_infer.py
ADDED
@@ -0,0 +1,157 @@
from typing import List
import os
import datetime
from omegaconf import OmegaConf
from tqdm import tqdm
import numpy as np
import torch
from torch import tensor
from torchvision import transforms
from PIL import Image
from models import Showo, MAGVITv2, CLIPVisionTower, get_mask_chedule
from prompting_utils import (UniversalPrompting,
                             create_attention_mask_for_mmu,
                             create_attention_mask_predict_next)
from transformers import AutoTokenizer, CLIPImageProcessor


def image_transform(image, resolution=256, normalize=True):
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
    image = transforms.CenterCrop((resolution, resolution))(image)
    image = transforms.ToTensor()(image)
    if normalize:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    return image


class FashionVLM:
    def __init__(self, temperature, top_k, max_new_tokens, fashion_vlm_name,
                 batch_size=3, guidance_scale=5, generation_temperature=1.0, generation_timesteps=50,
                 save_dir="generated_images"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.temperature = temperature  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
        self.top_k = top_k  # retain only the top_k most likely tokens, clamp others to have 0 probability
        self.max_new_tokens = max_new_tokens
        self.fashion_vlm_name = fashion_vlm_name

        # params for t2i
        self.save_dir = save_dir
        self.batch_size = batch_size
        self.guidance_scale = guidance_scale
        self.generation_temperature = generation_temperature
        self.generation_timesteps = generation_timesteps

        self._init_models()

    def _init_models(self):
        # Initialize universal prompting
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/phi-1_5",
            padding_side="left"
        )
        self.uni_prompting = UniversalPrompting(
            self.tokenizer,
            max_text_len=381,
            special_tokens=(
                "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"
            ),
            ignore_id=-100, cond_dropout_prob=0.1
        )

        # Initialize the VQ model
        self.vq_model = MAGVITv2.from_pretrained("showlab/magvitv2").to(self.device)
        self.vq_model.requires_grad_(False)
        self.vq_model.eval()

        self.model = Showo.from_pretrained(self.fashion_vlm_name).to(self.device)
        self.model.eval()

    def mmu_infer_tensor(self, image: tensor, prompt: tensor):
        """
        Image size: batch * 3 * 256 * 256
        """
        image = image.to(self.device)
        prompt = prompt.to(self.device)
        # pixel_values = self.clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0]
        image_tokens = self.vq_model.get_code(image) + len(self.uni_prompting.text_tokenizer)

        input_ids = torch.cat([
            (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|mmu|>']).to(self.device),
            (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|soi|>']).to(self.device),
            image_tokens,
            (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|eoi|>']).to(self.device),
            (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|sot|>']).to(self.device),
            prompt
        ], dim=1).long()

        attention_mask = create_attention_mask_for_mmu(
            input_ids.to(self.device),
            eoi_id=int(self.uni_prompting.sptids_dict['<|eoi|>'])
        )

        cont_toks_list = self.model.mmu_generate(
            input_ids, attention_mask=attention_mask,
            max_new_tokens=self.max_new_tokens, top_k=self.top_k,
            eot_token=self.uni_prompting.sptids_dict['<|eot|>']
        )

        cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]

        text = self.uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
        return text

    def t2i_infer(self, prompts: List[str]):
        output_path = []
        for step in tqdm(range(0, len(prompts), self.batch_size)):
            batch_prompt = prompts[step:step + self.batch_size]
            image_tokens = torch.ones((len(batch_prompt), self.model.config.num_vq_tokens),
                                      dtype=torch.long, device=self.device) * self.model.config.mask_token_id
            input_ids, _ = self.uni_prompting((batch_prompt, image_tokens), 't2i_gen')
            if self.guidance_scale > 0:
                uncond_input_ids, _ = self.uni_prompting(([''] * len(batch_prompt), image_tokens), 't2i_gen')
                attention_mask = create_attention_mask_predict_next(
                    torch.cat([input_ids, uncond_input_ids], dim=0),
                    pad_id=int(self.uni_prompting.sptids_dict['<|pad|>']),
                    soi_id=int(self.uni_prompting.sptids_dict['<|soi|>']),
                    eoi_id=int(self.uni_prompting.sptids_dict['<|eoi|>']),
                    rm_pad_in_image=True
                )
            else:
                uncond_input_ids = None
                attention_mask = create_attention_mask_predict_next(
                    input_ids,
                    pad_id=int(self.uni_prompting.sptids_dict['<|pad|>']),
                    soi_id=int(self.uni_prompting.sptids_dict['<|soi|>']),
                    eoi_id=int(self.uni_prompting.sptids_dict['<|eoi|>']),
                    rm_pad_in_image=True
                )

            mask_schedule = get_mask_chedule("cosine")
            with torch.no_grad():
                gen_token_ids = self.model.t2i_generate(
                    input_ids=input_ids,
                    uncond_input_ids=uncond_input_ids,
                    attention_mask=attention_mask,
                    guidance_scale=self.guidance_scale,
                    temperature=self.generation_temperature,
                    timesteps=self.generation_timesteps,
                    noise_schedule=mask_schedule,
                )

            gen_token_ids = torch.clamp(gen_token_ids, max=self.model.config.codebook_size - 1, min=0)
            images = self.vq_model.decode_code(gen_token_ids)
            images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
            images *= 255.0
            images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

            # Save the generated images
            for idx, image in enumerate(images, start=1):
                image = Image.fromarray(image)
                # Use a timestamp and index to build a unique filename
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                image_filename = f"{timestamp}_{step + idx}.jpg"
                image_path = os.path.join(self.save_dir, image_filename)
                image.save(image_path)
                output_path.append(image_path)

        return output_path
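A minimal sketch of the text-to-image path exposed by this class (illustrative only; the constructor values mirror the ones used in main.py below, and downloading the Anony100/FashionVLM and showlab/magvitv2 checkpoints plus a CUDA device are assumed):

# illustrative usage sketch for FashionVLM.t2i_infer
import os

from fashion_vlm_infer import FashionVLM

os.makedirs("generated_images", exist_ok=True)  # save_dir must exist before images are written

vlm = FashionVLM(
    temperature=0.8,        # sampling temperature for mmu text generation
    top_k=1,                # keep only the most likely token at each step
    max_new_tokens=1000,
    fashion_vlm_name="Anony100/FashionVLM",
    save_dir="generated_images",
)

# t2i_infer batches the prompts and returns one saved image path per prompt
paths = vlm.t2i_infer(["a minimalist white linen shirt, studio background"])
print(paths[0])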
mcp_servers/fashion_vlm/main.py
ADDED
@@ -0,0 +1,592 @@
import os
from dotenv import load_dotenv
import numpy as np
from itertools import combinations
from typing import Dict, Any, List, Optional
import pandas as pd

from scipy import sparse
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import uvicorn
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer, CLIPProcessor, CLIPModel
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.routing import Route, Mount
from starlette.applications import Starlette
from openai import AsyncOpenAI

from fashion_vlm_infer import FashionVLM
from prompting_utils import UniversalPrompting


# Load environment variables
load_dotenv()
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT", "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
openai = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
VALID_CATEGORIES = [
    'Pants', 'Coats', 'Cross-body bags', 'Shirts', 'Hats & caps', 'Sneakers', 'Jeans', 'Boots', 'Dresses', 'Sandals',
    'T-shirts & vests', 'Knitwear', 'Skirts', 'Earrings', 'Hats', 'Sweaters & knitwear', 'Loafers', 'Ballet flats',
    'Espadrilles', 'Tote bags', 'Shoulder bags', 'Slides & flip flops', 'Pumps', 'Necklaces', 'Polo shirts', 'Suits',
    'Oxford shoes', 'Bracelets', 'Jackets', 'Tops', 'Rings', 'Mules', 'Luggage & holdalls', 'Brogues', 'Activewear',
    'Belts', 'Derby shoes', 'Mini bags', 'Watches', 'Backpacks', 'Denim', 'Laptop bags & briefcases', 'Clutch bags',
    'Clutches', 'Lingerie & Nightwear', 'Skiwear', 'Sunglasses', 'Ties & bow ties', 'Shorts', 'Scarves', 'Messenger bags'
]


###################################
######### Loading Data ############
###################################
# Load item metadata
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet").set_index("item_id")
outfits_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/outfits_lite.parquet").set_index("outfit_id")
users_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/users_lite.parquet").set_index("user_id")
image_paths = items_df["path"].to_dict()

###################################
######### Loading Model ###########
###################################
# Load CLIP model and processor
print("Loading CLIP Model")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
clip_model.eval()

print("Loading Fashion VLM params")

fashion_vlm = FashionVLM(
    max_new_tokens=1000,
    temperature=0.8,
    top_k=1,
    fashion_vlm_name='Anony100/FashionVLM',
    save_dir="/mnt/d/PostDoc/fifth paper/code/FashionM3/generated_images"
)

resolution = 512
image_transform = transforms.Compose([
    transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop((resolution, resolution)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-1_5', padding_side="left")
uni_prompting = UniversalPrompting(
    tokenizer,
    max_text_len=128,
    special_tokens=(
        "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"
    ),
    ignore_id=-100, cond_dropout_prob=0.1
)


class InteractionDataManager:
    def __init__(self, users_df, outfits_df, items_df):
        """
        Initialize the class, load the data, and set basic parameters.

        Args:
        - users_df: user data (parquet)
        - outfits_df: outfit data (parquet)
        - items_df: item data (parquet)
        """
        self.users_df = users_df
        self.outfits_df = outfits_df
        self.items_df = items_df

        # Build ID/index mappings
        self.item_id_to_index = {item_id: index for index, item_id in enumerate(self.items_df.index)}
        self.index_to_item_id = {index: item_id for index, item_id in enumerate(self.items_df.index)}
        self.user_id_to_index = {user_id: index for index, user_id in enumerate(self.users_df.index)}
        self.index_to_user_id = {index: user_id for index, user_id in enumerate(self.users_df.index)}
        self.outfit_ids_dict = self.outfits_df['item_ids'].to_dict()  # get outfit's item ids from outfit id
        self.item_category_dict = self.items_df['category'].to_dict()  # get item's category from item id
        self.item_subcategory_dict = self.items_df['subcategory'].to_dict()  # get item's subcategory from item id
        self.n_items = len(self.items_df)
        self.n_users = len(self.users_df)

        self.user_outfit_pairs = []
        outfit_set = set(self.outfits_df.index)
        for uid, user in self.users_df.iterrows():
            oids = user.outfit_ids.split(",")
            self.user_outfit_pairs.extend([(uid, oid) for oid in oids if oid in outfit_set])

        # Precompute the subcategory -> item IDs mapping (using groupby)
        self.subcategory_to_items = self.items_df.groupby('subcategory').apply(lambda x: set(x.index)).to_dict()

        # Precompute the subcategory -> item indices mapping (for faster lookups)
        self.subcategory_to_indices = {}
        for subcategory, item_ids in self.subcategory_to_items.items():
            self.subcategory_to_indices[subcategory] = set([self.item_id_to_index[item_id]
                                                            for item_id in item_ids
                                                            if item_id in self.item_id_to_index])

        item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/item_matrix.npz'
        try:
            self.load_matrix('item', item_interaction_matrix_path)
        except FileNotFoundError:
            self.build_item_interaction_matrix()
            self.save_matrix('item', item_interaction_matrix_path)

        user_item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/user_item_matrix.npz'
        try:
            self.load_matrix('user_item', user_item_interaction_matrix_path)
        except FileNotFoundError:
            self.build_user_item_interaction_matrix()
            self.save_matrix('user_item', user_item_interaction_matrix_path)

        # Load item CLIP features
        with open(f"{FASHION_DATA_ROOT}/meta/clip_features.pkl", "rb") as f:
            print("Loading Fashion Features...")
            self.clip_features = pickle.load(f)
            print("Loading Fashion Features Successfully")

        # Prepare embeddings and item IDs
        self.item_ids = list(self.clip_features.keys())
        self.image_embeddings = np.array([self.clip_features[item_id]["image_embeds"] for item_id in self.item_ids])

    def save_matrix(self, matrix_type, filepath):
        """
        Save a matrix to file.

        Args:
        - matrix_type: 'item' or 'user_item', the type of matrix to save
        - filepath: save path (e.g. 'temp/item_matrix.npz')
        """
        if matrix_type == 'item':
            matrix = self.item_interaction_matrix
        elif matrix_type == 'user_item':
            matrix = self.user_item_interaction_matrix
        else:
            raise ValueError("matrix_type must be 'item' or 'user_item'")

        if matrix is None:
            raise ValueError(f"{matrix_type} matrix has not been built yet.")

        sparse.save_npz(filepath, matrix)
        print(f"Saved {matrix_type} matrix to {filepath}")

    def load_matrix(self, matrix_type, filepath):
        """
        Load a matrix from file.

        Args:
        - matrix_type: 'item' or 'user_item', the type of matrix to load
        - filepath: load path (e.g. 'temp/item_matrix.npz')
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File {filepath} does not exist.")

        matrix = sparse.load_npz(filepath)
        if matrix_type == 'item':
            self.item_interaction_matrix = matrix
        elif matrix_type == 'user_item':
            self.user_item_interaction_matrix = matrix
        else:
            raise ValueError("matrix_type must be 'item' or 'user_item'")

        print(f"Loaded {matrix_type} matrix from {filepath}")
        return matrix

    def build_item_interaction_matrix(self):
        """Build the item-item interaction matrix."""
        # Initialize the item interaction matrix
        self.item_interaction_matrix = sparse.lil_matrix((self.n_items, self.n_items), dtype=int)

        for index, outfit in self.outfits_df.iterrows():
            item_ids = outfit['item_ids'].split(',')
            # Record co-occurrences of item pairs
            for item_id1, item_id2 in combinations(item_ids, r=2):
                if item_id1 in self.item_id_to_index and item_id2 in self.item_id_to_index:
                    idx1 = self.item_id_to_index[item_id1]
                    idx2 = self.item_id_to_index[item_id2]
                    self.item_interaction_matrix[idx1, idx2] += 1
                    self.item_interaction_matrix[idx2, idx1] += 1  # unordered pair, keep symmetric

        # Convert to CSR format
        self.item_interaction_matrix = self.item_interaction_matrix.tocsr()
        return self.item_interaction_matrix

    def build_user_item_interaction_matrix(self):
        """Build the user-item interaction matrix."""
        # Initialize the user-item interaction matrix
        self.user_item_interaction_matrix = sparse.lil_matrix((self.n_users, self.n_items), dtype=int)

        for uid, user in self.users_df.iterrows():
            oids = user["outfit_ids"].split(",")
            outfits = self.outfits_df.loc[self.outfits_df.index.isin(oids)]
            for oid, outfit in outfits.iterrows():
                item_ids = outfit['item_ids'].split(',')
                # Record user-item occurrences
                for iid in item_ids:
                    if iid in self.item_id_to_index:
                        uidx = self.user_id_to_index[uid]
                        iidx = self.item_id_to_index[iid]
                        self.user_item_interaction_matrix[uidx, iidx] += 1

        # Convert to CSR format
        self.user_item_interaction_matrix = self.user_item_interaction_matrix.tocsr()
        return self.user_item_interaction_matrix

    def _process_interactions_for_category(
            self,
            matrix,
            given_id,
            category_indices,
            id_to_index
    ):
        """
        Process the interactions between a single entity and the target category.

        Args:
        - matrix: interaction matrix
        - given_id: ID of the given entity (user or item)
        - category_indices: set of item indices belonging to the target category

        Returns:
        - a list of interactions, each a dict with item_id, interaction_count, and score
        """
        interactions = []

        given_index = id_to_index[given_id]
        row = matrix[given_index]

        # Extract the non-zero entries of this row
        row_start = row.indptr[0]
        row_end = row.indptr[1]
        col_indices = row.indices[row_start:row_end]
        data_values = row.data[row_start:row_end]

        # Keep only items that belong to the target category
        for col_idx, value in zip(col_indices, data_values):
            # Check whether this item belongs to the target category
            if col_idx in category_indices:
                # Get the item ID
                output_id = self.index_to_item_id[col_idx]
                interactions.append({
                    'item_id': output_id,
                    'interaction_count': int(value),
                    'score': 0.0
                })

        return interactions

    def get_item_category_interactions(
            self,
            target_category: str,
            given_ids: List[str],
            query_type='item',  # item or user
            top_k=None,
    ):
        """
        Get all interactions between the given entities (users or items) and the target category.

        Args:
        - target_category: the subcategory to query
        - given_ids: list of entity IDs (user or item) to query
        - query_type: query type, 'item' or 'user'
        - top_k: return the top-k items with the most interactions; if None, return all

        Returns:
        - a list of interaction statistics with the target category, sorted by interaction count
        """
        if query_type == 'item':
            matrix = self.item_interaction_matrix
            id_to_index = self.item_id_to_index
        elif query_type == 'user':
            matrix = self.user_item_interaction_matrix
            id_to_index = self.user_id_to_index
        else:
            print(f'query_type must be either item or user but got {query_type}')
            return []

        # Collect all interaction records
        all_interactions = []
        category = target_category
        category_indices = self.subcategory_to_indices.get(category, set())  # all item indices of this subcategory

        # Get all interactions for each given entity
        for given_id in given_ids:
            interactions = self._process_interactions_for_category(
                matrix, given_id, category_indices, id_to_index
            )
            # Add the interactions to the result list
            all_interactions.extend(interactions)

        # Merge interaction counts for the same item
        item_interactions = {}
        for interaction in all_interactions:
            item_id = interaction['item_id']
            count = interaction['interaction_count']

            if item_id in item_interactions:
                item_interactions[item_id] += count
            else:
                item_interactions[item_id] = count

        # Convert to the result format
        merged_interactions = [
            {'item_id': item_id, 'interaction_count': count, 'score': 0.0}
            for item_id, count in item_interactions.items()
        ]

        # Sort by interaction count
        if merged_interactions:
            merged_interactions.sort(key=lambda x: x['interaction_count'], reverse=True)

        # Keep only the top-k
        if top_k and merged_interactions:
            merged_interactions = merged_interactions[:top_k]

        # Return the result
        return merged_interactions

    def rank_by_similarity(self, item_interactions, user_interactions, beta=2.0):
        """
        Compute the similarity between the user's interacted items and the given items' aggregate feature, then sort.
        """
        def get_combined_features(feature_dict):
            return (feature_dict['image_embeds'] + feature_dict['text_embeds']) / 2

        if not item_interactions:
            return user_interactions

        item_feature_list = []
        for item in item_interactions:
            item_id = item['item_id']
            if item_id not in self.clip_features:
                raise ValueError(f"Didn't find clip feature of item with id: {item_id}")

            item_features = get_combined_features(self.clip_features[item_id])
            item_feature_list.append(item_features)

        weights = np.array([x['interaction_count'] for x in item_interactions], dtype=np.float32)
        weights = weights / np.sum(weights)
        item_feature = np.sum(np.stack(item_feature_list, axis=0) * weights[:, np.newaxis], axis=0).reshape(1, -1)

        max_count = max((user_item.get('interaction_count', 1) for user_item in user_interactions), default=1)
        for user_item in user_interactions:
            user_item_id = user_item['item_id']
            if user_item_id not in self.clip_features:
                raise ValueError(f"Didn't find clip feature of item with id: {user_item_id}")

            user_item_features = get_combined_features(self.clip_features[user_item_id]).reshape(1, -1)
            similarity = cosine_similarity(user_item_features, item_feature).item()
            interaction_count = user_item['interaction_count']
            count_factor = (interaction_count / max_count) * beta + 1
            user_item['score'] = float(similarity) * count_factor

        user_interactions.sort(key=lambda x: x.get('score', 0), reverse=True)
        return user_interactions


data_manager = InteractionDataManager(users_df, outfits_df, items_df)
mcp = FastMCP('fashion-vlm-server')


async def compute_text_embedding(text: str) -> np.ndarray:
    inputs = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_embedding = clip_model.get_text_features(**inputs).numpy()
    return text_embedding / np.linalg.norm(text_embedding, axis=1, keepdims=True)


async def find_most_similar_image(text_embedding: np.ndarray) -> Dict[str, Any]:
    similarities = np.dot(data_manager.image_embeddings, text_embedding.T).flatten()
    most_similar_idx = np.argmax(similarities)
    most_similar_item_id = data_manager.item_ids[most_similar_idx]
    return {
        "image_path": image_paths[most_similar_item_id],
        "similarity": float(similarities[most_similar_idx])
    }


@mcp.tool()
async def retrieve_image(text: str) -> Dict[str, Any]:
    """Search for the most similar fashion image based on a text description.

    Args:
        text (str): Text description of the fashion item to search.
    """
    print(f"Searching for {text}")
    text_embedding = await compute_text_embedding(text)
    return await find_most_similar_image(text_embedding)


def get_recommendation(query, image_path):
    image = Image.open(image_path).convert("RGB")
    image = image_transform(image).unsqueeze(0)
    prompt = uni_prompting.text_tokenizer(['USER: \n' + query + ' ASSISTANT:'])['input_ids'][0]
    prompt = torch.tensor(prompt).unsqueeze(0)
    results = fashion_vlm.mmu_infer_tensor(image, prompt)
    response = results[0]
    return response


@mcp.tool()
async def fashion_recommend(query: str, image_path: str, target_category: str, user_id: Optional[str], list_of_items: List[str]) -> Dict[str, str]:
    """Generate fashion recommendations based on a user's query and uploaded image.

    This function processes the recommendation in the following steps:
    1. Retrieves the user's interaction history for the specified target category using user_id, target_category, and list_of_items.
    2. Summarizes the user's preferences for the target category by analyzing descriptions of previously interacted fashion items via a language model.
    3. Appends the summarized preference (as a single sentence) to the query and processes it with the uploaded image using a Fashion Vision-Language Model (VLM).
    4. Returns the personalized recommendation along with the derived user preference.

    The target_category is inferred from the query (e.g., "I want a skirt ..." implies "Skirts") and must belong to a predefined list of valid categories.

    Args:
        query (str): A complete sentence explicitly stating the user's desired fashion item (e.g., "I want a skirt for summer"). Must be in English.
        image_path (str): File path to the user-uploaded image, provided via the prompt.
        target_category (str): The specific fashion category of interest, derived from the query (e.g., "Skirts"). Must be in valid categories.
        user_id (str): Unique identifier for the user, provided via the prompt.
        list_of_items (List[str]): List of item IDs used to filter the user's interaction history, provided via the prompt.

    Returns:
        Dict[str, str]: A dictionary containing:
            - "recommendation": The personalized fashion recommendation text.
            - "user_preference": The summarized user preference sentence.

    Valid Categories:
        ['Pants', 'Coats', 'Cross-body bags', 'Shirts', 'Hats & caps', 'Sneakers', 'Jeans', 'Boots', 'Dresses', 'Sandals',
        'T-shirts & vests', 'Knitwear', 'Skirts', 'Earrings', 'Hats', 'Sweaters & knitwear', 'Loafers', 'Ballet flats',
        'Espadrilles', 'Tote bags', 'Shoulder bags', 'Slides & flip flops', 'Pumps', 'Necklaces', 'Polo shirts', 'Suits',
        'Oxford shoes', 'Bracelets', 'Jackets', 'Tops', 'Rings', 'Mules', 'Luggage & holdalls', 'Brogues', 'Activewear',
        'Belts', 'Derby shoes', 'Mini bags', 'Watches', 'Backpacks', 'Denim', 'Laptop bags & briefcases', 'Clutch bags',
        'Clutches', 'Lingerie & Nightwear', 'Skiwear', 'Sunglasses', 'Ties & bow ties', 'Shorts', 'Scarves', 'Messenger bags']
    """
    def get_item(item_id: str) -> pd.Series:
        return data_manager.items_df.loc[item_id]

    # If no image uploaded, we should use fashion_recommend_without_image
    if image_path == "":
        recommendation = await fashion_recommend_without_image(query)
        return {
            "recommendation": recommendation,
            "user_preference": ""
        }

    # If no user_id provided or user_id not found in database
    if not user_id or user_id not in data_manager.user_id_to_index.keys():
        return {
            "recommendation": get_recommendation(query, image_path),
            "user_preference": ""
        }

    user_preference = ""
    if target_category in VALID_CATEGORIES:
        user_interaction_result = data_manager.get_item_category_interactions(
            target_category, [user_id], query_type='user'
        )

        if len(list_of_items) != 0:
            item_interaction_result = data_manager.get_item_category_interactions(
                target_category, list_of_items, query_type='item'
            )
        else:
            item_interaction_result = []

        descriptions_for_summary = []
        historical_image_path = []

        if len(user_interaction_result) >= 0:
            user_interaction_result = data_manager.rank_by_similarity(
                item_interaction_result,
                user_interaction_result
            )
            for x in user_interaction_result[:5]:
                item = get_item(x['item_id'])
                descriptions_for_summary.append(item['gen_description'])
                historical_image_path.append(os.path.abspath(item['path']))

        if descriptions_for_summary:
            user_message = f"Summarize the user's preference for {target_category} based on the following descriptions of fashion items that the user bought previously:"
            for x in descriptions_for_summary:
                user_message += f"\n{x}"
            # Get summary using OpenAI API call
            response = await openai.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a user preference summary assistant. Your response is limited to one sentence, starting with 'I prefer ...'"},
                    {"role": "user", "content": user_message}
                ],
                max_tokens=1000,
            )
            user_preference = response.choices[0].message.content
            query += user_preference

    return {
        "recommendation": get_recommendation(query, image_path),
        "user_preference": user_preference
    }


@mcp.tool()
async def fashion_recommend_without_image(query: str) -> str:
    """Recommend fashion items solely based on the user's query.
    Outputs the text of the fashion recommendation from the model.

    Args:
        query (str): User's fashion-related query, including their recommendation request.
    """
    response = await openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a fashion stylist. You should answer user's fashion-related question, especially about fashion recommendation."},
            {"role": "user", "content": query}
        ],
        max_tokens=500,
    )
    return response.choices[0].message.content


@mcp.tool()
async def image_generate(text: str) -> str:
    """Generate an image based on a description. Output is the path of the saved generated image.

    Args:
        text (str): Descriptive text from the user, used for fashion image generation. English ONLY!
    """
    output_path = fashion_vlm.t2i_infer([text])[0]
    output_path = os.path.abspath(output_path)
    print(f"Generated image saved at {output_path}")
    return output_path


# Get the internal Server object
mcp_server = mcp._mcp_server
sse_transport = SseServerTransport("/messages/")


# Handle SSE connections
async def handle_sse(request):
    print("Handling SSE connection")
    async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
        read_stream, write_stream = streams
        await mcp_server.run(
            read_stream,
            write_stream,
            mcp_server.create_initialization_options(),
        )

# Define routes
routes = [
    Route("/sse", endpoint=handle_sse),
    Mount("/messages/", app=sse_transport.handle_post_message),
]

# Create the Starlette application
starlette_app = Starlette(routes=routes)


if __name__ == "__main__":
    print("Starting Fashion VLM server with HTTP and SSE...")
    uvicorn.run(starlette_app, host="0.0.0.0", port=8000)
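A minimal sketch of calling this server directly over its /sse endpoint with the MCP Python client (illustrative only; it mirrors the SSE branch of mcp_client.py and assumes the server above is already running on localhost:8000):

# illustrative sketch: call the retrieve_image tool over SSE
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def main():
    async with sse_client(url="http://localhost:8000/sse") as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()

            tools = await session.list_tools()
            print("tools:", [tool.name for tool in tools.tools])

            result = await session.call_tool("retrieve_image", {"text": "black leather boots"})
            print(result.content[0].text)  # JSON with image_path and similarity


if __name__ == "__main__":
    asyncio.run(main())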
mcp_servers/fashion_vlm/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .modeling_showo import Showo
from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
from .sampling import *
from .clip_encoder import CLIPVisionTower
mcp_servers/fashion_vlm/models/clip_encoder.py
ADDED
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig

class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = -2
        self.select_feature = "patch"
        self.load_model()
        self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


class CLIPVisionTowerS2(CLIPVisionTower):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
        self.s2_scales = list(map(int, self.s2_scales.split(',')))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        try:
            from s2wrapper import forward as multiscale_forward
        except ImportError:
            raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
        self.multiscale_forward = multiscale_forward

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.image_processor.size['shortest_edge'] = self.s2_image_size
            self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

    def load_model(self, device_map=None):
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.image_processor.size['shortest_edge'] = self.s2_image_size
        self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size

        self.is_loaded = True

    @torch.no_grad()
    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
                image_features.append(image_feature)
        else:
            image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)

        return image_features

    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.s2_scales)
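A short usage sketch for the tower above. The checkpoint name is an assumption (any CLIP vision checkpoint with the same layout would work), the input is a blank stand-in image, and the snippet only illustrates the patch-feature shape, assuming the repo root is on `PYTHONPATH`.

```python
# Illustrative only: load a CLIP vision tower and extract patch features.
# "openai/clip-vit-large-patch14-336" is an assumed checkpoint choice.
import torch
from PIL import Image

from mcp_servers.fashion_vlm.models.clip_encoder import CLIPVisionTower

tower = CLIPVisionTower("openai/clip-vit-large-patch14-336")
image = Image.new("RGB", (336, 336), color="white")  # stand-in for a real photo
pixel_values = tower.image_processor(images=image, return_tensors="pt")["pixel_values"]

with torch.no_grad():
    features = tower(pixel_values)

# select_feature == "patch" drops the CLS token, so the sequence length is num_patches.
print(features.shape)  # (1, tower.num_patches, tower.hidden_size)
```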
mcp_servers/fashion_vlm/models/common_modules.py
ADDED
@@ -0,0 +1,357 @@
"""
Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
"""

import math
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class DepthToSpaceUpsample(nn.Module):
    def __init__(
        self,
        in_channels,
    ):
        super().__init__()
        conv = nn.Conv2d(in_channels, in_channels * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        return out


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def unpack_time(t, batch):
    _, c, w, h = t.size()
    out = torch.reshape(t, [batch, -1, c, w, h])
    out = rearrange(out, "b t c h w -> b c t h w")
    return out


def pack_time(t):
    out = rearrange(t, "b c t h w -> b t c h w")
    _, _, c, w, h = out.size()
    return torch.reshape(out, [-1, c, w, h])


class TimeDownsample2x(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        kernel_size=3,
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        self.time_causal_padding = (kernel_size - 1, 0)
        self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        b, h, w, c, t = x.size()
        x = torch.reshape(x, [-1, c, t])

        x = F.pad(x, self.time_causal_padding)
        out = self.conv(x)

        out = torch.reshape(out, [b, h, w, c, t])
        out = rearrange(out, "b h w c t -> b c t h w")
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


class TimeUpsample2x(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        conv = nn.Conv1d(dim, dim_out * 2, 1)

        self.net = nn.Sequential(
            nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, t = conv.weight.shape
        conv_weight = torch.empty(o // 2, i, t)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        b, h, w, c, t = x.size()
        x = torch.reshape(x, [-1, c, t])

        out = self.net(x)
        out = out[:, :, 1:].contiguous()

        out = torch.reshape(out, [b, h, w, c, t])
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class TimeAttention(AttnBlock):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, "b c t h w -> b h w t c")
        b, h, w, t, c = x.size()
        x = torch.reshape(x, (-1, t, c))

        x = super().forward(x, *args, **kwargs)

        x = torch.reshape(x, [b, h, w, t, c])
        return rearrange(x, "b h w t c -> b c t h w")


class Residual(nn.Module):
    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        dilation = kwargs.pop("dilation", 1)
        stride = kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (
            width_pad,
            width_pad,
            height_pad,
            height_pad,
            time_pad,
            0,
        )

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
        )

    def forward(self, x):
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"

        x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        return self.conv(x)


def ResnetBlockCausal3D(
    dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
):
    net = nn.Sequential(
        Normalize(dim),
        nn.SiLU(),
        CausalConv3d(dim, dim, kernel_size, pad_mode),
        Normalize(dim),
        nn.SiLU(),
        CausalConv3d(dim, dim, kernel_size, pad_mode),
    )
    return Residual(net)


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        else:
            self.temb_proj = None
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
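As a quick sanity check of the 2D building blocks above (an illustrative sketch, not part of the commit, assuming the repo root is on `PYTHONPATH`): the residual block can change channels, attention preserves shape, and the down/upsample blocks halve or double the spatial resolution.

```python
# Illustrative shape check for the blocks defined in common_modules.py.
import torch

from mcp_servers.fashion_vlm.models.common_modules import (
    AttnBlock,
    Downsample,
    ResnetBlock,
    Upsample,
)

x = torch.randn(1, 64, 32, 32)

res = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
attn = AttnBlock(128)
down = Downsample(128, with_conv=True)
up = Upsample(128, with_conv=True)

h = res(x, None)        # (1, 128, 32, 32): channels change, spatial size kept
h = attn(h)             # (1, 128, 32, 32): self-attention preserves the shape
print(down(h).shape)    # (1, 128, 16, 16): stride-2 conv halves H and W
print(up(h).shape)      # (1, 128, 64, 64): nearest-neighbor upsample doubles H and W
```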
mcp_servers/fashion_vlm/models/misc.py
ADDED
@@ -0,0 +1,53 @@
from omegaconf import OmegaConf
import torch
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    NewType,
    Optional,
    Sized,
    Tuple,
    Type,
    TypeVar,
    Union,
)
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

# Tensor dtype
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt

# Config type
from omegaconf import DictConfig

# PyTorch Tensor type
from torch import Tensor

# Runtime type checking decorator
from typeguard import typechecked as typechecker


def broadcast(tensor, src=0):
    if not _distributed_available():
        return tensor
    else:
        torch.distributed.broadcast(tensor, src=src)
        return tensor

def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()

def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
    if '--local-rank' in cfg:
        del cfg['--local-rank']
    # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
    scfg = OmegaConf.structured(fields(**cfg))
    return scfg
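A minimal sketch of `parse_structured`, assuming the typing dependencies from requirements.txt (omegaconf, jaxtyping, typeguard) are installed; the `EncoderConfig` dataclass here is purely hypothetical.

```python
# Illustrative use of parse_structured: merge a plain dict into a typed dataclass config.
from dataclasses import dataclass

from mcp_servers.fashion_vlm.models.misc import parse_structured


@dataclass
class EncoderConfig:  # hypothetical config class for illustration
    resolution: int = 256
    z_channels: int = 13


cfg = parse_structured(EncoderConfig, {"resolution": 512})
print(cfg.resolution, cfg.z_channels)  # 512 13
```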
mcp_servers/fashion_vlm/models/modeling_magvitv2.py
ADDED
@@ -0,0 +1,440 @@
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
from .common_modules import *
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .misc import *
import math

class Updateable:
    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except:
                continue  # ignore attributes like property, which can't be retrived using getattr?
            if isinstance(module, Updateable):
                module.do_update_step(
                    epoch, global_step, on_load_weights=on_load_weights
                )
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except:
                continue  # ignore attributes like property, which can't be retrived using getattr?
            if isinstance(module, Updateable):
                module.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # override this method to implement custom update logic
        # if on_load_weights is True, you should be careful doing things related to model evaluations,
        # as the models and tensors are not guarenteed to be on the same device
        pass

    def update_step_end(self, epoch: int, global_step: int):
        pass

class VQGANEncoder(ModelMixin, ConfigMixin):
    @dataclass
    class Config:
        ch: int = 128
        ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
        num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
        attn_resolutions: List[int] = field(default_factory=lambda: [5])
        dropout: float = 0.0
        in_ch: int = 3
        out_ch: int = 3
        resolution: int = 256
        z_channels: int = 13
        double_z: bool = False

    def __init__(self,
                 ch: int = 128,
                 ch_mult: List[int] = [1, 2, 2, 4, 4],
                 num_res_blocks: List[int] = [4, 3, 4, 3, 4],
                 attn_resolutions: List[int] = [5],
                 dropout: float = 0.0,
                 in_ch: int = 3,
                 out_ch: int = 3,
                 resolution: int = 256,
                 z_channels: int = 13,
                 double_z: bool = False):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_ch = in_ch
        # downsampling
        self.conv_in = torch.nn.Conv2d(
            self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = self.resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = self.ch * in_ch_mult[i_level]
            block_out = self.ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks[i_level]):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, True)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        # for param in self.parameters():
        #     broadcast(param, src=0)

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks[i_level]):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h


class LFQuantizer(nn.Module):
    def __init__(self, num_codebook_entry: int = -1,
                 codebook_dim: int = 13,
                 beta: float = 0.25,
                 entropy_multiplier: float = 0.1,
                 commit_loss_multiplier: float = 0.1, ):
        super().__init__()
        self.codebook_size = 2 ** codebook_dim
        print(
            f"Look-up free quantizer with codebook size: {self.codebook_size}"
        )
        self.e_dim = codebook_dim
        self.beta = beta

        indices = torch.arange(self.codebook_size)

        binary = (
            indices.unsqueeze(1)
            >> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
        ) & 1

        embedding = binary.float() * 2 - 1
        self.register_buffer("embedding", embedding)
        self.register_buffer(
            "power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
        )
        self.commit_loss_multiplier = commit_loss_multiplier
        self.entropy_multiplier = entropy_multiplier

    def get_indices(self, z_q):
        return (
            (self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
            .sum(1, keepdim=True)
            .long()
        )

    def get_codebook_entry(self, indices, shape=None):
        if shape is None:
            h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
        else:
            h, w = shape
        b, _ = indices.shape
        indices = indices.reshape(-1)
        z_q = self.embedding[indices]
        z_q = z_q.view(b, h, w, -1)

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

    def forward(self, z, get_code=False):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        if get_code:
            return self.get_codebook_entry(z)

        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        ge_zero = (z_flattened > 0).float()
        ones = torch.ones_like(z_flattened)
        z_q = ones * ge_zero + -ones * (1 - ge_zero)

        # preserve gradients
        z_q = z_flattened + (z_q - z_flattened).detach()

        # compute entropy loss
        CatDist = torch.distributions.categorical.Categorical
        logit = torch.stack(
            [
                -(z_flattened - torch.ones_like(z_q)).pow(2),
                -(z_flattened - torch.ones_like(z_q) * -1).pow(2),
            ],
            dim=-1,
        )
        cat_dist = CatDist(logits=logit)
        entropy = cat_dist.entropy().mean()
        mean_prob = cat_dist.probs.mean(0)
        mean_entropy = CatDist(probs=mean_prob).entropy().mean()

        # compute loss for embedding
        commit_loss = torch.mean(
            (z_q.detach() - z_flattened) ** 2
        ) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)

        # reshape back to match original input shape
        z_q = z_q.view(z.shape)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return {
            "z": z_q,
            "quantizer_loss": commit_loss * self.commit_loss_multiplier,
            "entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
            "indices": self.get_indices(z_q),
        }


class VQGANDecoder(ModelMixin, ConfigMixin):
    def __init__(self, ch: int = 128,
                 ch_mult: List[int] = [1, 1, 2, 2, 4],
                 num_res_blocks: List[int] = [4, 4, 3, 4, 3],
                 attn_resolutions: List[int] = [5],
                 dropout: float = 0.0,
                 in_ch: int = 3,
                 out_ch: int = 3,
                 resolution: int = 256,
                 z_channels: int = 13,
                 double_z: bool = False):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_ch = in_ch
        self.give_pre_end = False

        self.z_channels = z_channels
        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks[i_level]):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, True)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )
        self.post_quant_conv = torch.nn.Conv2d(
            z_channels, z_channels, 1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape
        # timestep embedding
        temb = None
        output = dict()
        z = self.post_quant_conv(z)

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks[i_level]):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        output["output"] = h
        if self.give_pre_end:
            return output

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        output["output"] = h
        return output


class MAGVITv2(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
    ):
        super().__init__()

        self.encoder = VQGANEncoder()
        self.decoder = VQGANDecoder()
        self.quantize = LFQuantizer()

    def forward(self, pixel_values, return_loss=False):
        pass

    def encode(self, pixel_values, return_loss=False):
        hidden_states = self.encoder(pixel_values)
        quantized_states = self.quantize(hidden_states)['z']
        codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
        output = (quantized_states, codebook_indices)
        return output

    def get_code(self, pixel_values):
        hidden_states = self.encoder(pixel_values)
        codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)

        return codebook_indices

    def decode_code(self, codebook_indices, shape=None):
        z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)

        reconstructed_pixel_values = self.decoder(z_q)["output"]
        return reconstructed_pixel_values


if __name__ == '__main__':
    encoder = VQGANEncoder()
    import ipdb
    ipdb.set_trace()
    print()
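A sketch of the encode/decode round trip exposed by `MAGVITv2` above. It uses randomly initialized weights purely to show the tensor shapes (one 13-bit look-up-free code per position on a 16x16 latent grid for a 256x256 input); real use would load the pretrained tokenizer checkpoint, which is outside this file.

```python
# Illustrative round trip through the tokenizer defined above (random weights,
# shapes only; real use would load a pretrained MAGVIT-v2 checkpoint).
import torch

from mcp_servers.fashion_vlm.models.modeling_magvitv2 import MAGVITv2

vq_model = MAGVITv2().eval()

pixel_values = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    quantized, indices = vq_model.encode(pixel_values)
    recon = vq_model.decode_code(indices)

print(quantized.shape)  # (1, 13, 16, 16): 13-bit codes on a 16x16 latent grid
print(indices.shape)    # (1, 256): one integer token per latent position
print(recon.shape)      # (1, 3, 256, 256)
```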
mcp_servers/fashion_vlm/models/modeling_showo.py
ADDED
@@ -0,0 +1,237 @@
# coding=utf-8
# Copyright 2024 NUS Show Lab, HuggingFace.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from transformers import AutoConfig
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM


class Showo(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        w_clip_vit,
        vocab_size,
        llm_vocab_size,
        llm_model_path='',
        codebook_size=8192,
        num_vq_tokens=256,
        load_from_showo=True,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.register_to_config(mask_token_id=vocab_size - 1)
        if load_from_showo:
            config = AutoConfig.from_pretrained(llm_model_path)
            self.showo = PhiForCausalLM(config)
        else:
            self.showo = PhiForCausalLM.from_pretrained(llm_model_path, attn_implementation='sdpa')
        self.showo.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        if self.w_clip_vit:
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048)
            )

    def forward(
        self,
        input_ids,
        input_embeddings=None,
        attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        batch_size_t2i=0,
        batch_size_lm=0,
        batch_size_mmu=0,
        max_seq_length=128,
        labels_mask_text=None,
        labels_mask_image=None,
        **kwargs,
    ):

        if input_embeddings is None:
            logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
        else:
            logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']

        if labels is not None:
            # 1. Mask token prediction (discrete diffusion) for image generation
            # Note that, max_seq_length indicates the maximum number of text tokens, maybe a bit confused.
            loss_t2i = F.cross_entropy(
                logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
                labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
            )

            # 2. Next token prediction for language modeling
            loss_lm = F.cross_entropy(
                logits[batch_size_t2i:batch_size_t2i + batch_size_lm, :-1].contiguous().view(-1, self.output_size),
                labels[batch_size_t2i:batch_size_t2i + batch_size_lm, 1:].contiguous().view(-1), ignore_index=-100,
            )
            # loss_lm = torch.tensor(0.0, device=logits.device)

            # 3. Next token prediction for captioning/multimodal understanding
            loss_mmu = F.cross_entropy(
                logits[-batch_size_mmu:, :-1].contiguous().view(-1, self.output_size),
                labels[-batch_size_mmu:, 1:].contiguous().view(-1), ignore_index=-100,
            )

            return logits, loss_t2i, loss_lm, loss_mmu

        return logits

    def t2i_generate(
        self,
        input_ids: torch.LongTensor = None,
        uncond_input_ids: torch.LongTensor = None,
        attention_mask=None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
    ):
        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        num_vq_tokens = self.config.num_vq_tokens
        num_new_special_tokens = 10
        llm_vocab_size = self.config.llm_vocab_size
        max_seq_length = 381

        input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
        input_ids_minus_lm_vocab_size = torch.where(
            input_ids_minus_lm_vocab_size == mask_token_id,
            mask_token_id,
            input_ids_minus_lm_vocab_size - llm_vocab_size - num_new_special_tokens
        )

        # for classifier-free guidance
        if uncond_input_ids is not None:
            uncond_prefix = uncond_input_ids[:, :max_seq_length + 1]

        for step in range(timesteps):
            if uncond_input_ids is not None and guidance_scale > 0:
                uncond_input_ids = torch.cat(
                    [uncond_prefix, input_ids[:, max_seq_length + 1:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
                # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                # it seems that muse has a different cfg setting
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
                logits = logits[:, -(num_vq_tokens + 1):-1, llm_vocab_size + num_new_special_tokens:-1]
            else:
                logits = self(input_ids, attention_mask=attention_mask)
                logits = logits[:, -(num_vq_tokens + 1):-1, llm_vocab_size + num_new_special_tokens:-1]

            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))
            # Computes the probabilities of each selected tokens.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.max(
                torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            )
            # Adds noise for randomness
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
            # Masks tokens with lower confidence.
            input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
                                                                sampled_ids + llm_vocab_size
                                                                + num_new_special_tokens)
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids

    @torch.no_grad()
    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        try:
            device = idx.device
        except:
            device = input_embeddings.device

        result = []
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            # logits, _ = self(idx_cond)
            logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)

            L = attention_mask.shape[-1]
            attention_mask = attention_mask.squeeze()
            attention_mask_a = torch.hstack(
                [
                    attention_mask,  # L, L
                    torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
                ]
            )
            attention_mask_b = torch.vstack(
                [
                    attention_mask_a,  # L, L+1
                    torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
                ]
            )
            attention_mask = attention_mask_b

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            result.append(idx_next[0][0])
            # append sampled index to the running sequence and continue
            if self.config.w_clip_vit:
                idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break

        return result
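The `t2i_generate` loop above re-masks the lowest-confidence image tokens at every step, with the number of re-masked tokens following a MaskGIT-style cosine schedule. The standalone sketch below re-derives that schedule inline (rather than importing the repo's `models/sampling.py`) just to show how many of the 256 VQ tokens remain masked after each of the 18 refinement steps; treat the exact formula as an illustration of the convention, not as this repo's function.

```python
# Illustrative: tokens kept vs. re-masked per step under a cosine mask schedule.
import math

num_vq_tokens = 256
timesteps = 18

for step in range(timesteps):
    ratio = (step + 1) / timesteps
    mask_ratio = math.cos(ratio * math.pi / 2)   # decays from ~1 to 0 over the steps
    mask_len = int(num_vq_tokens * mask_ratio)   # tokens re-masked for the next step
    print(f"step {step + 1:2d}: keep {num_vq_tokens - mask_len:3d}, re-mask {mask_len:3d}")
```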
mcp_servers/fashion_vlm/models/modeling_utils.py
ADDED
@@ -0,0 +1,1207 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
import itertools
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from functools import partial
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import safetensors
|
| 28 |
+
import torch
|
| 29 |
+
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
|
| 30 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
| 31 |
+
from torch import Tensor, nn
|
| 32 |
+
|
| 33 |
+
from diffusers import __version__
|
| 34 |
+
from diffusers.utils import (
|
| 35 |
+
FLAX_WEIGHTS_NAME,
|
| 36 |
+
SAFE_WEIGHTS_INDEX_NAME,
|
| 37 |
+
WEIGHTS_INDEX_NAME,
|
| 38 |
+
_add_variant,
|
| 39 |
+
_get_checkpoint_shard_files,
|
| 40 |
+
_get_model_file,
|
| 41 |
+
deprecate,
|
| 42 |
+
is_accelerate_available,
|
| 43 |
+
is_torch_version,
|
| 44 |
+
logging,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
CONFIG_NAME = "config.json"
|
| 48 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
| 49 |
+
SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
|
| 50 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
| 51 |
+
|
| 52 |
+
from diffusers.utils.hub_utils import (
|
| 53 |
+
PushToHubMixin,
|
| 54 |
+
load_or_create_model_card,
|
| 55 |
+
populate_model_card,
|
| 56 |
+
)
|
| 57 |
+
from diffusers.models.model_loading_utils import (
|
| 58 |
+
_determine_device_map,
|
| 59 |
+
_fetch_index_file,
|
| 60 |
+
_load_state_dict_into_model,
|
| 61 |
+
load_model_dict_into_meta,
|
| 62 |
+
load_state_dict,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 66 |
+
|
| 67 |
+
logger = logging.get_logger(__name__)
|
| 68 |
+
|
| 69 |
+
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if is_torch_version(">=", "1.9.0"):
|
| 73 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
| 74 |
+
else:
|
| 75 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if is_accelerate_available():
|
| 79 |
+
import accelerate
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
| 83 |
+
try:
|
| 84 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
| 85 |
+
return next(parameters_and_buffers).device
|
| 86 |
+
except StopIteration:
|
| 87 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
| 88 |
+
|
| 89 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
| 90 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
| 91 |
+
return tuples
|
| 92 |
+
|
| 93 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
| 94 |
+
first_tuple = next(gen)
|
| 95 |
+
return first_tuple[1].device
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
| 99 |
+
try:
|
| 100 |
+
params = tuple(parameter.parameters())
|
| 101 |
+
if len(params) > 0:
|
| 102 |
+
return params[0].dtype
|
| 103 |
+
|
| 104 |
+
buffers = tuple(parameter.buffers())
|
| 105 |
+
if len(buffers) > 0:
|
| 106 |
+
return buffers[0].dtype
|
| 107 |
+
|
| 108 |
+
except StopIteration:
|
| 109 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
| 110 |
+
|
| 111 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
| 112 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
| 113 |
+
return tuples
|
| 114 |
+
|
| 115 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
| 116 |
+
first_tuple = next(gen)
|
| 117 |
+
return first_tuple[1].dtype
|
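As a quick illustration, the two helpers above just walk a module's parameters/buffers and report where and in what precision they live; a minimal sketch assuming they are importable from this file (the `Linear` layer is only a stand-in):

```py
import torch

layer = torch.nn.Linear(4, 4).to(torch.float16)
print(get_parameter_device(layer))  # cpu
print(get_parameter_dtype(layer))   # torch.float16
```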
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
| 121 |
+
r"""
|
| 122 |
+
Base class for all models.
|
| 123 |
+
|
| 124 |
+
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
| 125 |
+
saving models.
|
| 126 |
+
|
| 127 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
config_name = CONFIG_NAME
|
| 131 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
| 132 |
+
_supports_gradient_checkpointing = False
|
| 133 |
+
_keys_to_ignore_on_load_unexpected = None
|
| 134 |
+
_no_split_modules = None
|
| 135 |
+
|
| 136 |
+
def __init__(self):
|
| 137 |
+
super().__init__()
|
| 138 |
+
|
| 139 |
+
def __getattr__(self, name: str) -> Any:
|
| 140 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
| 141 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
| 142 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
| 143 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
| 147 |
+
is_attribute = name in self.__dict__
|
| 148 |
+
|
| 149 |
+
if is_in_config and not is_attribute:
|
| 150 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
| 151 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
| 152 |
+
return self._internal_dict[name]
|
| 153 |
+
|
| 154 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
| 155 |
+
return super().__getattr__(name)
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def is_gradient_checkpointing(self) -> bool:
|
| 159 |
+
"""
|
| 160 |
+
Whether gradient checkpointing is activated for this model or not.
|
| 161 |
+
"""
|
| 162 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
| 163 |
+
|
| 164 |
+
def enable_gradient_checkpointing(self) -> None:
|
| 165 |
+
"""
|
| 166 |
+
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
| 167 |
+
*checkpoint activations* in other frameworks).
|
| 168 |
+
"""
|
| 169 |
+
if not self._supports_gradient_checkpointing:
|
| 170 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
| 171 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
| 172 |
+
|
| 173 |
+
def disable_gradient_checkpointing(self) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
| 176 |
+
*checkpoint activations* in other frameworks).
|
| 177 |
+
"""
|
| 178 |
+
if self._supports_gradient_checkpointing:
|
| 179 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
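`enable_gradient_checkpointing` above only has an effect when a subclass sets `_supports_gradient_checkpointing = True` and implements the `_set_gradient_checkpointing` hook that `self.apply(...)` invokes. A hedged sketch of that convention, assuming the `ModelMixin` defined in this file; the `Block`/`TinyModel` classes and their flag handling are hypothetical, not taken from this repository:

```py
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False          # toggled by the hook below
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # recompute activations during backward instead of storing them
            return checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)

class TinyModel(ModelMixin):                         # ModelMixin as defined in this file
    _supports_gradient_checkpointing = True

    def __init__(self):
        super().__init__()
        self.block = Block()

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(self, x):
        return self.block(x)
```

`TinyModel().enable_gradient_checkpointing()` then flips the flag on every `Block` through the `apply` call above.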
| 180 |
+
|
| 181 |
+
def set_use_npu_flash_attention(self, valid: bool) -> None:
|
| 182 |
+
r"""
|
| 183 |
+
Set the switch for the npu flash attention.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
|
| 187 |
+
if hasattr(module, "set_use_npu_flash_attention"):
|
| 188 |
+
module.set_use_npu_flash_attention(valid)
|
| 189 |
+
|
| 190 |
+
for child in module.children():
|
| 191 |
+
fn_recursive_set_npu_flash_attention(child)
|
| 192 |
+
|
| 193 |
+
for module in self.children():
|
| 194 |
+
if isinstance(module, torch.nn.Module):
|
| 195 |
+
fn_recursive_set_npu_flash_attention(module)
|
| 196 |
+
|
| 197 |
+
def enable_npu_flash_attention(self) -> None:
|
| 198 |
+
r"""
|
| 199 |
+
Enable npu flash attention from torch_npu
|
| 200 |
+
|
| 201 |
+
"""
|
| 202 |
+
self.set_use_npu_flash_attention(True)
|
| 203 |
+
|
| 204 |
+
def disable_npu_flash_attention(self) -> None:
|
| 205 |
+
r"""
|
| 206 |
+
disable npu flash attention from torch_npu
|
| 207 |
+
|
| 208 |
+
"""
|
| 209 |
+
self.set_use_npu_flash_attention(False)
|
| 210 |
+
|
| 211 |
+
def set_use_memory_efficient_attention_xformers(
|
| 212 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
| 213 |
+
) -> None:
|
| 214 |
+
# Recursively walk through all the children.
|
| 215 |
+
# Any child which exposes the set_use_memory_efficient_attention_xformers method
|
| 216 |
+
# gets the message
|
| 217 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
| 218 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
| 219 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
| 220 |
+
|
| 221 |
+
for child in module.children():
|
| 222 |
+
fn_recursive_set_mem_eff(child)
|
| 223 |
+
|
| 224 |
+
for module in self.children():
|
| 225 |
+
if isinstance(module, torch.nn.Module):
|
| 226 |
+
fn_recursive_set_mem_eff(module)
|
| 227 |
+
|
| 228 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
|
| 229 |
+
r"""
|
| 230 |
+
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
| 231 |
+
|
| 232 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
| 233 |
+
inference. Speed up during training is not guaranteed.
|
| 234 |
+
|
| 235 |
+
<Tip warning={true}>
|
| 236 |
+
|
| 237 |
+
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
| 238 |
+
precedence.
|
| 239 |
+
|
| 240 |
+
</Tip>
|
| 241 |
+
|
| 242 |
+
Parameters:
|
| 243 |
+
attention_op (`Callable`, *optional*):
|
| 244 |
+
Override the default `None` operator for use as `op` argument to the
|
| 245 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
| 246 |
+
function of xFormers.
|
| 247 |
+
|
| 248 |
+
Examples:
|
| 249 |
+
|
| 250 |
+
```py
|
| 251 |
+
>>> import torch
|
| 252 |
+
>>> from diffusers import UNet2DConditionModel
|
| 253 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
| 254 |
+
|
| 255 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
| 256 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
| 257 |
+
... )
|
| 258 |
+
>>> model = model.to("cuda")
|
| 259 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
| 260 |
+
```
|
| 261 |
+
"""
|
| 262 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
| 263 |
+
|
| 264 |
+
def disable_xformers_memory_efficient_attention(self) -> None:
|
| 265 |
+
r"""
|
| 266 |
+
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
| 267 |
+
"""
|
| 268 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
| 269 |
+
|
| 270 |
+
def save_pretrained(
|
| 271 |
+
self,
|
| 272 |
+
save_directory: Union[str, os.PathLike],
|
| 273 |
+
is_main_process: bool = True,
|
| 274 |
+
save_function: Optional[Callable] = None,
|
| 275 |
+
safe_serialization: bool = True,
|
| 276 |
+
variant: Optional[str] = None,
|
| 277 |
+
max_shard_size: Union[int, str] = "10GB",
|
| 278 |
+
push_to_hub: bool = False,
|
| 279 |
+
**kwargs,
|
| 280 |
+
):
|
| 281 |
+
"""
|
| 282 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
| 283 |
+
[`~models.ModelMixin.from_pretrained`] class method.
|
| 284 |
+
|
| 285 |
+
Arguments:
|
| 286 |
+
save_directory (`str` or `os.PathLike`):
|
| 287 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
| 288 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
| 289 |
+
Whether the process calling this is the main process or not. Useful during distributed training when you
|
| 290 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
| 291 |
+
process to avoid race conditions.
|
| 292 |
+
save_function (`Callable`):
|
| 293 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
| 294 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
| 295 |
+
`DIFFUSERS_SAVE_MODE`.
|
| 296 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
| 297 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
| 298 |
+
variant (`str`, *optional*):
|
| 299 |
+
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
| 300 |
+
max_shard_size (`int` or `str`, defaults to `"10GB"`):
|
| 301 |
+
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
|
| 302 |
+
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
| 303 |
+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
|
| 304 |
+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
|
| 305 |
+
This is to establish a common default size for this argument across different libraries in the Hugging
|
| 306 |
+
Face ecosystem (`transformers`, and `accelerate`, for example).
|
| 307 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 308 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
| 309 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 310 |
+
namespace).
|
| 311 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 312 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 313 |
+
"""
|
| 314 |
+
if os.path.isfile(save_directory):
|
| 315 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 316 |
+
return
|
| 317 |
+
|
| 318 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
| 319 |
+
weights_name = _add_variant(weights_name, variant)
|
| 320 |
+
weight_name_split = weights_name.split(".")
|
| 321 |
+
if len(weight_name_split) in [2, 3]:
|
| 322 |
+
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
|
| 323 |
+
else:
|
| 324 |
+
raise ValueError(f"Invalid {weights_name} provided.")
|
| 325 |
+
|
| 326 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 327 |
+
|
| 328 |
+
if push_to_hub:
|
| 329 |
+
commit_message = kwargs.pop("commit_message", None)
|
| 330 |
+
private = kwargs.pop("private", False)
|
| 331 |
+
create_pr = kwargs.pop("create_pr", False)
|
| 332 |
+
token = kwargs.pop("token", None)
|
| 333 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
| 334 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
| 335 |
+
|
| 336 |
+
# Only save the model itself if we are using distributed training
|
| 337 |
+
model_to_save = self
|
| 338 |
+
|
| 339 |
+
# Attach architecture to the config
|
| 340 |
+
# Save the config
|
| 341 |
+
if is_main_process:
|
| 342 |
+
model_to_save.save_config(save_directory)
|
| 343 |
+
|
| 344 |
+
# Save the model
|
| 345 |
+
state_dict = model_to_save.state_dict()
|
| 346 |
+
|
| 347 |
+
# Save the model
|
| 348 |
+
state_dict_split = split_torch_state_dict_into_shards(
|
| 349 |
+
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Clean the folder from a previous save
|
| 353 |
+
if is_main_process:
|
| 354 |
+
for filename in os.listdir(save_directory):
|
| 355 |
+
if filename in state_dict_split.filename_to_tensors.keys():
|
| 356 |
+
continue
|
| 357 |
+
full_filename = os.path.join(save_directory, filename)
|
| 358 |
+
if not os.path.isfile(full_filename):
|
| 359 |
+
continue
|
| 360 |
+
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
| 361 |
+
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
| 362 |
+
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
| 363 |
+
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
| 364 |
+
if (
|
| 365 |
+
filename.startswith(weights_without_ext)
|
| 366 |
+
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
| 367 |
+
):
|
| 368 |
+
os.remove(full_filename)
|
| 369 |
+
|
| 370 |
+
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
| 371 |
+
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
| 372 |
+
filepath = os.path.join(save_directory, filename)
|
| 373 |
+
if safe_serialization:
|
| 374 |
+
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
| 375 |
+
# joyfulness), but for now this is enough.
|
| 376 |
+
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
| 377 |
+
else:
|
| 378 |
+
torch.save(shard, filepath)
|
| 379 |
+
|
| 380 |
+
if state_dict_split.is_sharded:
|
| 381 |
+
index = {
|
| 382 |
+
"metadata": state_dict_split.metadata,
|
| 383 |
+
"weight_map": state_dict_split.tensor_to_filename,
|
| 384 |
+
}
|
| 385 |
+
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
| 386 |
+
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
| 387 |
+
# Save the index as well
|
| 388 |
+
with open(save_index_file, "w", encoding="utf-8") as f:
|
| 389 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
| 390 |
+
f.write(content)
|
| 391 |
+
logger.info(
|
| 392 |
+
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
| 393 |
+
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
| 394 |
+
f"index located at {save_index_file}."
|
| 395 |
+
)
|
| 396 |
+
else:
|
| 397 |
+
path_to_weights = os.path.join(save_directory, weights_name)
|
| 398 |
+
logger.info(f"Model weights saved in {path_to_weights}")
|
| 399 |
+
|
| 400 |
+
if push_to_hub:
|
| 401 |
+
# Create a new empty model card and eventually tag it
|
| 402 |
+
model_card = load_or_create_model_card(repo_id, token=token)
|
| 403 |
+
model_card = populate_model_card(model_card)
|
| 404 |
+
model_card.save(Path(save_directory, "README.md").as_posix())
|
| 405 |
+
|
| 406 |
+
self._upload_folder(
|
| 407 |
+
save_directory,
|
| 408 |
+
repo_id,
|
| 409 |
+
token=token,
|
| 410 |
+
commit_message=commit_message,
|
| 411 |
+
create_pr=create_pr,
|
| 412 |
+
)
|
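A hedged usage sketch of the `save_pretrained` method above; `model` stands for any instance of a `ModelMixin` subclass and the output path and shard size are illustrative:

```py
# Write the config plus (possibly sharded) safetensors weights to a local folder.
model.save_pretrained(
    "./checkpoints/fashion_vlm_backbone",  # hypothetical output directory
    safe_serialization=True,               # pytorch_model*.safetensors instead of pickled .bin
    max_shard_size="2GB",                  # checkpoints above this size are split into shards
)
```

When the state dict is split, each shard is named from the `pytorch_model{suffix}.*` pattern built above and an index file maps every tensor name to its shard.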
| 413 |
+
|
| 414 |
+
@classmethod
|
| 415 |
+
@validate_hf_hub_args
|
| 416 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
| 417 |
+
r"""
|
| 418 |
+
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
| 419 |
+
|
| 420 |
+
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
| 421 |
+
train the model, set it back in training mode with `model.train()`.
|
| 422 |
+
|
| 423 |
+
Parameters:
|
| 424 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
| 425 |
+
Can be either:
|
| 426 |
+
|
| 427 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
| 428 |
+
the Hub.
|
| 429 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
| 430 |
+
with [`~ModelMixin.save_pretrained`].
|
| 431 |
+
|
| 432 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
| 433 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
| 434 |
+
is not used.
|
| 435 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
| 436 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
| 437 |
+
dtype is automatically derived from the model's weights.
|
| 438 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 439 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 440 |
+
cached versions if they exist.
|
| 441 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 442 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
| 443 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 444 |
+
output_loading_info (`bool`, *optional*, defaults to `False`):
|
| 445 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 446 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 447 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
| 448 |
+
won't be downloaded from the Hub.
|
| 449 |
+
token (`str` or *bool*, *optional*):
|
| 450 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
| 451 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
| 452 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 453 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
| 454 |
+
allowed by Git.
|
| 455 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
| 456 |
+
Load the model weights from a Flax checkpoint save file.
|
| 457 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
| 458 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
| 459 |
+
mirror (`str`, *optional*):
|
| 460 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
| 461 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
| 462 |
+
information.
|
| 463 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
| 464 |
+
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
| 465 |
+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
| 466 |
+
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
| 467 |
+
|
| 468 |
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
| 469 |
+
more information about each option see [designing a device
|
| 470 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
| 471 |
+
max_memory (`Dict`, *optional*):
|
| 472 |
+
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
| 473 |
+
each GPU and the available CPU RAM if unset.
|
| 474 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
| 475 |
+
The path to offload weights if `device_map` contains the value `"disk"`.
|
| 476 |
+
offload_state_dict (`bool`, *optional*):
|
| 477 |
+
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
| 478 |
+
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
| 479 |
+
when there is some disk offload.
|
| 480 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
| 481 |
+
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
| 482 |
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
| 483 |
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
| 484 |
+
argument to `True` will raise an error.
|
| 485 |
+
variant (`str`, *optional*):
|
| 486 |
+
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
| 487 |
+
loading `from_flax`.
|
| 488 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
| 489 |
+
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
| 490 |
+
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
| 491 |
+
weights. If set to `False`, `safetensors` weights are not loaded.
|
| 492 |
+
|
| 493 |
+
<Tip>
|
| 494 |
+
|
| 495 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
| 496 |
+
`huggingface-cli login`. You can also activate the special
|
| 497 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
| 498 |
+
firewalled environment.
|
| 499 |
+
|
| 500 |
+
</Tip>
|
| 501 |
+
|
| 502 |
+
Example:
|
| 503 |
+
|
| 504 |
+
```py
|
| 505 |
+
from diffusers import UNet2DConditionModel
|
| 506 |
+
|
| 507 |
+
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
| 508 |
+
```
|
| 509 |
+
|
| 510 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
| 511 |
+
|
| 512 |
+
```bash
|
| 513 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
| 514 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
| 515 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
| 516 |
+
```
|
| 517 |
+
"""
|
| 518 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 519 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
| 520 |
+
force_download = kwargs.pop("force_download", False)
|
| 521 |
+
from_flax = kwargs.pop("from_flax", False)
|
| 522 |
+
proxies = kwargs.pop("proxies", None)
|
| 523 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
| 524 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
| 525 |
+
token = kwargs.pop("token", None)
|
| 526 |
+
revision = kwargs.pop("revision", None)
|
| 527 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
| 528 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 529 |
+
device_map = kwargs.pop("device_map", None)
|
| 530 |
+
max_memory = kwargs.pop("max_memory", None)
|
| 531 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
| 532 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
| 533 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
| 534 |
+
variant = kwargs.pop("variant", None)
|
| 535 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
| 536 |
+
|
| 537 |
+
allow_pickle = False
|
| 538 |
+
if use_safetensors is None:
|
| 539 |
+
use_safetensors = True
|
| 540 |
+
allow_pickle = True
|
| 541 |
+
|
| 542 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
| 543 |
+
low_cpu_mem_usage = False
|
| 544 |
+
logger.warning(
|
| 545 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
| 546 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
| 547 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
| 548 |
+
" install accelerate\n```\n."
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
if device_map is not None and not is_accelerate_available():
|
| 552 |
+
raise NotImplementedError(
|
| 553 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
| 554 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# Check if we can handle device_map and dispatching the weights
|
| 558 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
| 559 |
+
raise NotImplementedError(
|
| 560 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
| 561 |
+
" `device_map=None`."
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
| 565 |
+
raise NotImplementedError(
|
| 566 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
| 567 |
+
" `low_cpu_mem_usage=False`."
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
| 571 |
+
raise ValueError(
|
| 572 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
| 573 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# change device_map into a map if we passed an int, a str or a torch.device
|
| 577 |
+
if isinstance(device_map, torch.device):
|
| 578 |
+
device_map = {"": device_map}
|
| 579 |
+
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
| 580 |
+
try:
|
| 581 |
+
device_map = {"": torch.device(device_map)}
|
| 582 |
+
except RuntimeError:
|
| 583 |
+
raise ValueError(
|
| 584 |
+
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
|
| 585 |
+
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
|
| 586 |
+
)
|
| 587 |
+
elif isinstance(device_map, int):
|
| 588 |
+
if device_map < 0:
|
| 589 |
+
raise ValueError(
|
| 590 |
+
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
|
| 591 |
+
)
|
| 592 |
+
else:
|
| 593 |
+
device_map = {"": device_map}
|
| 594 |
+
|
| 595 |
+
if device_map is not None:
|
| 596 |
+
if low_cpu_mem_usage is None:
|
| 597 |
+
low_cpu_mem_usage = True
|
| 598 |
+
elif not low_cpu_mem_usage:
|
| 599 |
+
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
|
| 600 |
+
|
| 601 |
+
if low_cpu_mem_usage:
|
| 602 |
+
if device_map is not None and not is_torch_version(">=", "1.10"):
|
| 603 |
+
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
|
| 604 |
+
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
|
| 605 |
+
|
| 606 |
+
# Load config if we don't provide a configuration
|
| 607 |
+
config_path = pretrained_model_name_or_path
|
| 608 |
+
|
| 609 |
+
user_agent = {
|
| 610 |
+
"diffusers": __version__,
|
| 611 |
+
"file_type": "model",
|
| 612 |
+
"framework": "pytorch",
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
# load config
|
| 616 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
| 617 |
+
config_path,
|
| 618 |
+
cache_dir=cache_dir,
|
| 619 |
+
return_unused_kwargs=True,
|
| 620 |
+
return_commit_hash=True,
|
| 621 |
+
force_download=force_download,
|
| 622 |
+
proxies=proxies,
|
| 623 |
+
local_files_only=local_files_only,
|
| 624 |
+
token=token,
|
| 625 |
+
revision=revision,
|
| 626 |
+
subfolder=subfolder,
|
| 627 |
+
user_agent=user_agent,
|
| 628 |
+
**kwargs,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Determine if we're loading from a directory of sharded checkpoints.
|
| 632 |
+
is_sharded = False
|
| 633 |
+
index_file = None
|
| 634 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
| 635 |
+
index_file = _fetch_index_file(
|
| 636 |
+
is_local=is_local,
|
| 637 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 638 |
+
subfolder=subfolder or "",
|
| 639 |
+
use_safetensors=use_safetensors,
|
| 640 |
+
cache_dir=cache_dir,
|
| 641 |
+
variant=variant,
|
| 642 |
+
force_download=force_download,
|
| 643 |
+
proxies=proxies,
|
| 644 |
+
local_files_only=local_files_only,
|
| 645 |
+
token=token,
|
| 646 |
+
revision=revision,
|
| 647 |
+
user_agent=user_agent,
|
| 648 |
+
commit_hash=commit_hash,
|
| 649 |
+
)
|
| 650 |
+
if index_file is not None and index_file.is_file():
|
| 651 |
+
is_sharded = True
|
| 652 |
+
|
| 653 |
+
if is_sharded and from_flax:
|
| 654 |
+
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
|
| 655 |
+
|
| 656 |
+
# load model
|
| 657 |
+
model_file = None
|
| 658 |
+
if from_flax:
|
| 659 |
+
model_file = _get_model_file(
|
| 660 |
+
pretrained_model_name_or_path,
|
| 661 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
| 662 |
+
cache_dir=cache_dir,
|
| 663 |
+
force_download=force_download,
|
| 664 |
+
proxies=proxies,
|
| 665 |
+
local_files_only=local_files_only,
|
| 666 |
+
token=token,
|
| 667 |
+
revision=revision,
|
| 668 |
+
subfolder=subfolder,
|
| 669 |
+
user_agent=user_agent,
|
| 670 |
+
commit_hash=commit_hash,
|
| 671 |
+
)
|
| 672 |
+
model = cls.from_config(config, **unused_kwargs)
|
| 673 |
+
|
| 674 |
+
# Convert the weights
|
| 675 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
| 676 |
+
|
| 677 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
| 678 |
+
else:
|
| 679 |
+
if is_sharded:
|
| 680 |
+
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
|
| 681 |
+
pretrained_model_name_or_path,
|
| 682 |
+
index_file,
|
| 683 |
+
cache_dir=cache_dir,
|
| 684 |
+
proxies=proxies,
|
| 685 |
+
local_files_only=local_files_only,
|
| 686 |
+
token=token,
|
| 687 |
+
user_agent=user_agent,
|
| 688 |
+
revision=revision,
|
| 689 |
+
subfolder=subfolder or "",
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
elif use_safetensors and not is_sharded:
|
| 693 |
+
try:
|
| 694 |
+
model_file = _get_model_file(
|
| 695 |
+
pretrained_model_name_or_path,
|
| 696 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
| 697 |
+
cache_dir=cache_dir,
|
| 698 |
+
force_download=force_download,
|
| 699 |
+
proxies=proxies,
|
| 700 |
+
local_files_only=local_files_only,
|
| 701 |
+
token=token,
|
| 702 |
+
revision=revision,
|
| 703 |
+
subfolder=subfolder,
|
| 704 |
+
user_agent=user_agent,
|
| 705 |
+
commit_hash=commit_hash,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
except IOError as e:
|
| 709 |
+
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
| 710 |
+
if not allow_pickle:
|
| 711 |
+
raise
|
| 712 |
+
logger.warning(
|
| 713 |
+
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
if model_file is None and not is_sharded:
|
| 717 |
+
model_file = _get_model_file(
|
| 718 |
+
pretrained_model_name_or_path,
|
| 719 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
| 720 |
+
cache_dir=cache_dir,
|
| 721 |
+
force_download=force_download,
|
| 722 |
+
proxies=proxies,
|
| 723 |
+
local_files_only=local_files_only,
|
| 724 |
+
token=token,
|
| 725 |
+
revision=revision,
|
| 726 |
+
subfolder=subfolder,
|
| 727 |
+
user_agent=user_agent,
|
| 728 |
+
commit_hash=commit_hash,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
if low_cpu_mem_usage:
|
| 732 |
+
# Instantiate model with empty weights
|
| 733 |
+
with accelerate.init_empty_weights():
|
| 734 |
+
model = cls.from_config(config, **unused_kwargs)
|
| 735 |
+
|
| 736 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
| 737 |
+
if device_map is None and not is_sharded:
|
| 738 |
+
param_device = "cpu"
|
| 739 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
| 740 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
| 741 |
+
# move the params from meta device to cpu
|
| 742 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
| 743 |
+
if len(missing_keys) > 0:
|
| 744 |
+
raise ValueError(
|
| 745 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
| 746 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
| 747 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
| 748 |
+
" those weights or else make sure your checkpoint file is correct."
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 752 |
+
model,
|
| 753 |
+
state_dict,
|
| 754 |
+
device=param_device,
|
| 755 |
+
dtype=torch_dtype,
|
| 756 |
+
model_name_or_path=pretrained_model_name_or_path,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 760 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 761 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 762 |
+
|
| 763 |
+
if len(unexpected_keys) > 0:
|
| 764 |
+
logger.warning(
|
| 765 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
else: # else let accelerate handle loading and dispatching.
|
| 769 |
+
# Load weights and dispatch according to the device_map
|
| 770 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
| 771 |
+
force_hook = True
|
| 772 |
+
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
|
| 773 |
+
if device_map is None and is_sharded:
|
| 774 |
+
# we load the parameters on the cpu
|
| 775 |
+
device_map = {"": "cpu"}
|
| 776 |
+
force_hook = False
|
| 777 |
+
try:
|
| 778 |
+
accelerate.load_checkpoint_and_dispatch(
|
| 779 |
+
model,
|
| 780 |
+
model_file if not is_sharded else index_file,
|
| 781 |
+
device_map,
|
| 782 |
+
max_memory=max_memory,
|
| 783 |
+
offload_folder=offload_folder,
|
| 784 |
+
offload_state_dict=offload_state_dict,
|
| 785 |
+
dtype=torch_dtype,
|
| 786 |
+
force_hooks=force_hook,
|
| 787 |
+
strict=True,
|
| 788 |
+
)
|
| 789 |
+
except AttributeError as e:
|
| 790 |
+
# When using accelerate loading, we do not have the ability to load the state
|
| 791 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
| 792 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
| 793 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
| 794 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
| 795 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
| 796 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
| 797 |
+
# the weights so we don't have to do this again.
|
| 798 |
+
|
| 799 |
+
if "'Attention' object has no attribute" in str(e):
|
| 800 |
+
logger.warning(
|
| 801 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
| 802 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
| 803 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
| 804 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
| 805 |
+
" please also re-upload it or open a PR on the original repository."
|
| 806 |
+
)
|
| 807 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
| 808 |
+
accelerate.load_checkpoint_and_dispatch(
|
| 809 |
+
model,
|
| 810 |
+
model_file if not is_sharded else index_file,
|
| 811 |
+
device_map,
|
| 812 |
+
max_memory=max_memory,
|
| 813 |
+
offload_folder=offload_folder,
|
| 814 |
+
offload_state_dict=offload_state_dict,
|
| 815 |
+
dtype=torch_dtype,
|
| 816 |
+
force_hooks=force_hook,
|
| 817 |
+
strict=True,
|
| 818 |
+
)
|
| 819 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
| 820 |
+
else:
|
| 821 |
+
raise e
|
| 822 |
+
|
| 823 |
+
loading_info = {
|
| 824 |
+
"missing_keys": [],
|
| 825 |
+
"unexpected_keys": [],
|
| 826 |
+
"mismatched_keys": [],
|
| 827 |
+
"error_msgs": [],
|
| 828 |
+
}
|
| 829 |
+
else:
|
| 830 |
+
model = cls.from_config(config, **unused_kwargs)
|
| 831 |
+
|
| 832 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
| 833 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
| 834 |
+
|
| 835 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
| 836 |
+
model,
|
| 837 |
+
state_dict,
|
| 838 |
+
model_file,
|
| 839 |
+
pretrained_model_name_or_path,
|
| 840 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
loading_info = {
|
| 844 |
+
"missing_keys": missing_keys,
|
| 845 |
+
"unexpected_keys": unexpected_keys,
|
| 846 |
+
"mismatched_keys": mismatched_keys,
|
| 847 |
+
"error_msgs": error_msgs,
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
| 851 |
+
raise ValueError(
|
| 852 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
| 853 |
+
)
|
| 854 |
+
elif torch_dtype is not None:
|
| 855 |
+
model = model.to(torch_dtype)
|
| 856 |
+
|
| 857 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
| 858 |
+
|
| 859 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
| 860 |
+
model.eval()
|
| 861 |
+
if output_loading_info:
|
| 862 |
+
return model, loading_info
|
| 863 |
+
|
| 864 |
+
return model
|
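And the matching load path through `from_pretrained`; `FashionShowoModel` is a hypothetical name for a subclass that, as in diffusers, is assumed to also inherit `ConfigMixin` so that `load_config`/`from_config` are available:

```py
import torch

model = FashionShowoModel.from_pretrained(  # hypothetical ModelMixin + ConfigMixin subclass
    "./checkpoints/fashion_vlm_backbone",
    torch_dtype=torch.float16,              # cast after loading, as done at the end of from_pretrained
    low_cpu_mem_usage=True,                 # take the accelerate init_empty_weights branch
)
model.eval()                                # from_pretrained already calls eval(); shown for clarity
```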
| 865 |
+
|
| 866 |
+
@classmethod
|
| 867 |
+
def _load_pretrained_model(
|
| 868 |
+
cls,
|
| 869 |
+
model,
|
| 870 |
+
state_dict: OrderedDict,
|
| 871 |
+
resolved_archive_file,
|
| 872 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
| 873 |
+
ignore_mismatched_sizes: bool = False,
|
| 874 |
+
):
|
| 875 |
+
# Retrieve missing & unexpected_keys
|
| 876 |
+
model_state_dict = model.state_dict()
|
| 877 |
+
loaded_keys = list(state_dict.keys())
|
| 878 |
+
|
| 879 |
+
expected_keys = list(model_state_dict.keys())
|
| 880 |
+
|
| 881 |
+
original_loaded_keys = loaded_keys
|
| 882 |
+
|
| 883 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
| 884 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
| 885 |
+
|
| 886 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
| 887 |
+
model_to_load = model
|
| 888 |
+
|
| 889 |
+
def _find_mismatched_keys(
|
| 890 |
+
state_dict,
|
| 891 |
+
model_state_dict,
|
| 892 |
+
loaded_keys,
|
| 893 |
+
ignore_mismatched_sizes,
|
| 894 |
+
):
|
| 895 |
+
mismatched_keys = []
|
| 896 |
+
if ignore_mismatched_sizes:
|
| 897 |
+
for checkpoint_key in loaded_keys:
|
| 898 |
+
model_key = checkpoint_key
|
| 899 |
+
|
| 900 |
+
if (
|
| 901 |
+
model_key in model_state_dict
|
| 902 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
| 903 |
+
):
|
| 904 |
+
mismatched_keys.append(
|
| 905 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
| 906 |
+
)
|
| 907 |
+
del state_dict[checkpoint_key]
|
| 908 |
+
return mismatched_keys
|
| 909 |
+
|
| 910 |
+
if state_dict is not None:
|
| 911 |
+
# Whole checkpoint
|
| 912 |
+
mismatched_keys = _find_mismatched_keys(
|
| 913 |
+
state_dict,
|
| 914 |
+
model_state_dict,
|
| 915 |
+
original_loaded_keys,
|
| 916 |
+
ignore_mismatched_sizes,
|
| 917 |
+
)
|
| 918 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
| 919 |
+
|
| 920 |
+
if len(error_msgs) > 0:
|
| 921 |
+
error_msg = "\n\t".join(error_msgs)
|
| 922 |
+
if "size mismatch" in error_msg:
|
| 923 |
+
error_msg += (
|
| 924 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
| 925 |
+
)
|
| 926 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
| 927 |
+
|
| 928 |
+
if len(unexpected_keys) > 0:
|
| 929 |
+
logger.warning(
|
| 930 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
| 931 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
| 932 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
| 933 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
| 934 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
| 935 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
| 936 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
| 937 |
+
" BertForSequenceClassification model)."
|
| 938 |
+
)
|
| 939 |
+
else:
|
| 940 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
| 941 |
+
if len(missing_keys) > 0:
|
| 942 |
+
logger.warning(
|
| 943 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
| 944 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
| 945 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
| 946 |
+
)
|
| 947 |
+
elif len(mismatched_keys) == 0:
|
| 948 |
+
logger.info(
|
| 949 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
| 950 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
| 951 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
| 952 |
+
" without further training."
|
| 953 |
+
)
|
| 954 |
+
if len(mismatched_keys) > 0:
|
| 955 |
+
mismatched_warning = "\n".join(
|
| 956 |
+
[
|
| 957 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
| 958 |
+
for key, shape1, shape2 in mismatched_keys
|
| 959 |
+
]
|
| 960 |
+
)
|
| 961 |
+
logger.warning(
|
| 962 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
| 963 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
| 964 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
| 965 |
+
" able to use it for predictions and inference."
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
| 969 |
+
|
| 970 |
+
@classmethod
|
| 971 |
+
def _get_signature_keys(cls, obj):
|
| 972 |
+
parameters = inspect.signature(obj.__init__).parameters
|
| 973 |
+
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
| 974 |
+
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
| 975 |
+
expected_modules = set(required_parameters.keys()) - {"self"}
|
| 976 |
+
|
| 977 |
+
return expected_modules, optional_parameters
|
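For reference, `_get_signature_keys` above simply splits the `__init__` parameters by whether they carry a default; a small self-contained illustration (the `Demo` class is hypothetical):

```py
import inspect

class Demo:
    def __init__(self, in_channels, out_channels, act_fn="silu"):
        pass

params = inspect.signature(Demo.__init__).parameters
required = {k for k, v in params.items() if v.default is inspect._empty} - {"self"}
optional = {k for k, v in params.items() if v.default is not inspect._empty}
print(required, optional)  # -> required: in_channels, out_channels; optional: act_fn
```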
| 978 |
+
|
| 979 |
+
# Adapted from `transformers` modeling_utils.py
|
| 980 |
+
def _get_no_split_modules(self, device_map: str):
|
| 981 |
+
"""
|
| 982 |
+
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
|
| 983 |
+
get the underlying `_no_split_modules`.
|
| 984 |
+
|
| 985 |
+
Args:
|
| 986 |
+
device_map (`str`):
|
| 987 |
+
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
`List[str]`: List of modules that should not be split
|
| 991 |
+
"""
|
| 992 |
+
_no_split_modules = set()
|
| 993 |
+
modules_to_check = [self]
|
| 994 |
+
while len(modules_to_check) > 0:
|
| 995 |
+
module = modules_to_check.pop(-1)
|
| 996 |
+
# if the module does not appear in _no_split_modules, we also check the children
|
| 997 |
+
if module.__class__.__name__ not in _no_split_modules:
|
| 998 |
+
if isinstance(module, ModelMixin):
|
| 999 |
+
if module._no_split_modules is None:
|
| 1000 |
+
raise ValueError(
|
| 1001 |
+
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
|
| 1002 |
+
"class needs to implement the `_no_split_modules` attribute."
|
| 1003 |
+
)
|
| 1004 |
+
else:
|
| 1005 |
+
_no_split_modules = _no_split_modules | set(module._no_split_modules)
|
| 1006 |
+
modules_to_check += list(module.children())
|
| 1007 |
+
return list(_no_split_modules)
|
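`_get_no_split_modules` only returns something useful when a subclass declares which block classes must stay together on one device for `device_map` placement; a hedged sketch of that declaration (the class and block names are hypothetical):

```py
class FashionBackbone(ModelMixin):        # hypothetical subclass
    # Keep whole transformer blocks on a single device when a device_map is computed;
    # _get_no_split_modules above collects these class names.
    _no_split_modules = ["TransformerBlock"]
```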
| 1008 |
+
|
| 1009 |
+
@property
|
| 1010 |
+
def device(self) -> torch.device:
|
| 1011 |
+
"""
|
| 1012 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
| 1013 |
+
device).
|
| 1014 |
+
"""
|
| 1015 |
+
return get_parameter_device(self)
|
| 1016 |
+
|
| 1017 |
+
@property
|
| 1018 |
+
def dtype(self) -> torch.dtype:
|
| 1019 |
+
"""
|
| 1020 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
| 1021 |
+
"""
|
| 1022 |
+
return get_parameter_dtype(self)
|
| 1023 |
+
|
| 1024 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
| 1025 |
+
"""
|
| 1026 |
+
Get number of (trainable or non-embedding) parameters in the module.
|
| 1027 |
+
|
| 1028 |
+
Args:
|
| 1029 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
| 1030 |
+
Whether or not to return only the number of trainable parameters.
|
| 1031 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
| 1032 |
+
Whether or not to return only the number of non-embedding parameters.
|
| 1033 |
+
|
| 1034 |
+
Returns:
|
| 1035 |
+
`int`: The number of parameters.
|
| 1036 |
+
|
| 1037 |
+
Example:
|
| 1038 |
+
|
| 1039 |
+
```py
|
| 1040 |
+
from diffusers import UNet2DConditionModel
|
| 1041 |
+
|
| 1042 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
| 1043 |
+
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
| 1044 |
+
unet.num_parameters(only_trainable=True)
|
| 1045 |
+
859520964
|
| 1046 |
+
```
|
| 1047 |
+
"""
|
| 1048 |
+
|
| 1049 |
+
if exclude_embeddings:
|
| 1050 |
+
embedding_param_names = [
|
| 1051 |
+
f"{name}.weight"
|
| 1052 |
+
for name, module_type in self.named_modules()
|
| 1053 |
+
if isinstance(module_type, torch.nn.Embedding)
|
| 1054 |
+
]
|
| 1055 |
+
non_embedding_parameters = [
|
| 1056 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
| 1057 |
+
]
|
| 1058 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
| 1059 |
+
else:
|
| 1060 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
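# A small sketch of the `exclude_embeddings` branch above on a toy module, so the counts are
# easy to verify by hand (the toy layers are illustrative only).
import torch.nn as nn

toy = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
emb_names = {f"{n}.weight" for n, m in toy.named_modules() if isinstance(m, nn.Embedding)}
total = sum(p.numel() for p in toy.parameters())                                    # 40 + 8 + 2 = 50
non_emb = sum(p.numel() for n, p in toy.named_parameters() if n not in emb_names)   # 8 + 2 = 10
print(total, non_emb)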
|
| 1061 |
+
|
| 1062 |
+
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
| 1063 |
+
deprecated_attention_block_paths = []
|
| 1064 |
+
|
| 1065 |
+
def recursive_find_attn_block(name, module):
|
| 1066 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
| 1067 |
+
deprecated_attention_block_paths.append(name)
|
| 1068 |
+
|
| 1069 |
+
for sub_name, sub_module in module.named_children():
|
| 1070 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
| 1071 |
+
recursive_find_attn_block(sub_name, sub_module)
|
| 1072 |
+
|
| 1073 |
+
recursive_find_attn_block("", self)
|
| 1074 |
+
|
| 1075 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
| 1076 |
+
# because it is possible we are loading from a state dict that was already
|
| 1077 |
+
# converted
|
| 1078 |
+
|
| 1079 |
+
for path in deprecated_attention_block_paths:
|
| 1080 |
+
# group_norm path stays the same
|
| 1081 |
+
|
| 1082 |
+
# query -> to_q
|
| 1083 |
+
if f"{path}.query.weight" in state_dict:
|
| 1084 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
| 1085 |
+
if f"{path}.query.bias" in state_dict:
|
| 1086 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
| 1087 |
+
|
| 1088 |
+
# key -> to_k
|
| 1089 |
+
if f"{path}.key.weight" in state_dict:
|
| 1090 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
| 1091 |
+
if f"{path}.key.bias" in state_dict:
|
| 1092 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
| 1093 |
+
|
| 1094 |
+
# value -> to_v
|
| 1095 |
+
if f"{path}.value.weight" in state_dict:
|
| 1096 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
| 1097 |
+
if f"{path}.value.bias" in state_dict:
|
| 1098 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
| 1099 |
+
|
| 1100 |
+
# proj_attn -> to_out.0
|
| 1101 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
| 1102 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
| 1103 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
| 1104 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
| 1105 |
+
|
| 1106 |
+
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
| 1107 |
+
deprecated_attention_block_modules = []
|
| 1108 |
+
|
| 1109 |
+
def recursive_find_attn_block(module):
|
| 1110 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
| 1111 |
+
deprecated_attention_block_modules.append(module)
|
| 1112 |
+
|
| 1113 |
+
for sub_module in module.children():
|
| 1114 |
+
recursive_find_attn_block(sub_module)
|
| 1115 |
+
|
| 1116 |
+
recursive_find_attn_block(self)
|
| 1117 |
+
|
| 1118 |
+
for module in deprecated_attention_block_modules:
|
| 1119 |
+
module.query = module.to_q
|
| 1120 |
+
module.key = module.to_k
|
| 1121 |
+
module.value = module.to_v
|
| 1122 |
+
module.proj_attn = module.to_out[0]
|
| 1123 |
+
|
| 1124 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
| 1125 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
| 1126 |
+
# making an incorrect assumption that this model should be converted when
|
| 1127 |
+
# it really shouldn't be.
|
| 1128 |
+
del module.to_q
|
| 1129 |
+
del module.to_k
|
| 1130 |
+
del module.to_v
|
| 1131 |
+
del module.to_out
|
| 1132 |
+
|
| 1133 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
| 1134 |
+
deprecated_attention_block_modules = []
|
| 1135 |
+
|
| 1136 |
+
def recursive_find_attn_block(module) -> None:
|
| 1137 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
| 1138 |
+
deprecated_attention_block_modules.append(module)
|
| 1139 |
+
|
| 1140 |
+
for sub_module in module.children():
|
| 1141 |
+
recursive_find_attn_block(sub_module)
|
| 1142 |
+
|
| 1143 |
+
recursive_find_attn_block(self)
|
| 1144 |
+
|
| 1145 |
+
for module in deprecated_attention_block_modules:
|
| 1146 |
+
module.to_q = module.query
|
| 1147 |
+
module.to_k = module.key
|
| 1148 |
+
module.to_v = module.value
|
| 1149 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
| 1150 |
+
|
| 1151 |
+
del module.query
|
| 1152 |
+
del module.key
|
| 1153 |
+
del module.value
|
| 1154 |
+
del module.proj_attn
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
class LegacyModelMixin(ModelMixin):
|
| 1158 |
+
r"""
|
| 1159 |
+
A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
|
| 1160 |
+
pipeline-specific classes (like `DiTTransformer2DModel`).
|
| 1161 |
+
"""
|
| 1162 |
+
|
| 1163 |
+
@classmethod
|
| 1164 |
+
@validate_hf_hub_args
|
| 1165 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
| 1166 |
+
# To prevent dependency import problem.
|
| 1167 |
+
from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
|
| 1168 |
+
|
| 1169 |
+
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
| 1170 |
+
kwargs_copy = kwargs.copy()
|
| 1171 |
+
|
| 1172 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 1173 |
+
force_download = kwargs.pop("force_download", False)
|
| 1174 |
+
proxies = kwargs.pop("proxies", None)
|
| 1175 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
| 1176 |
+
token = kwargs.pop("token", None)
|
| 1177 |
+
revision = kwargs.pop("revision", None)
|
| 1178 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 1179 |
+
|
| 1180 |
+
# Load config if we don't provide a configuration
|
| 1181 |
+
config_path = pretrained_model_name_or_path
|
| 1182 |
+
|
| 1183 |
+
user_agent = {
|
| 1184 |
+
"diffusers": __version__,
|
| 1185 |
+
"file_type": "model",
|
| 1186 |
+
"framework": "pytorch",
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
# load config
|
| 1190 |
+
config, _, _ = cls.load_config(
|
| 1191 |
+
config_path,
|
| 1192 |
+
cache_dir=cache_dir,
|
| 1193 |
+
return_unused_kwargs=True,
|
| 1194 |
+
return_commit_hash=True,
|
| 1195 |
+
force_download=force_download,
|
| 1196 |
+
proxies=proxies,
|
| 1197 |
+
local_files_only=local_files_only,
|
| 1198 |
+
token=token,
|
| 1199 |
+
revision=revision,
|
| 1200 |
+
subfolder=subfolder,
|
| 1201 |
+
user_agent=user_agent,
|
| 1202 |
+
**kwargs,
|
| 1203 |
+
)
|
| 1204 |
+
# resolve remapping
|
| 1205 |
+
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
| 1206 |
+
|
| 1207 |
+
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
mcp_servers/fashion_vlm/models/phi.py
ADDED
|
@@ -0,0 +1,1489 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""PyTorch Phi model."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from packaging import version
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
+
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 30 |
+
from transformers.modeling_attn_mask_utils import (
|
| 31 |
+
_prepare_4d_causal_attention_mask,
|
| 32 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_outputs import (
|
| 35 |
+
BaseModelOutputWithPast,
|
| 36 |
+
CausalLMOutputWithPast,
|
| 37 |
+
SequenceClassifierOutputWithPast,
|
| 38 |
+
TokenClassifierOutput,
|
| 39 |
+
)
|
| 40 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 41 |
+
from transformers.utils import (
|
| 42 |
+
add_code_sample_docstrings,
|
| 43 |
+
add_start_docstrings,
|
| 44 |
+
add_start_docstrings_to_model_forward,
|
| 45 |
+
get_torch_version,
|
| 46 |
+
is_flash_attn_2_available,
|
| 47 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 48 |
+
logging,
|
| 49 |
+
replace_return_docstrings,
|
| 50 |
+
)
|
| 51 |
+
from transformers.models.phi.configuration_phi import PhiConfig
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if is_flash_attn_2_available():
|
| 55 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 56 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
logger = logging.get_logger(__name__)
|
| 60 |
+
|
| 61 |
+
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
|
| 62 |
+
_CONFIG_FOR_DOC = "PhiConfig"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 66 |
+
def _get_unpad_data(attention_mask):
|
| 67 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 68 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 69 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 70 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 71 |
+
return (
|
| 72 |
+
indices,
|
| 73 |
+
cu_seqlens,
|
| 74 |
+
max_seqlen_in_batch,
|
| 75 |
+
)
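# Worked example for the helper above on a tiny padded batch (1 = real token, 0 = padding):
import torch
import torch.nn.functional as F

mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
seqlens = mask.sum(dim=-1, dtype=torch.int32)                                 # tensor([2, 3])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()             # tensor([0, 1, 3, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # tensor([0, 2, 5])
max_seqlen = seqlens.max().item()                                             # 3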
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
|
| 79 |
+
class PhiRotaryEmbedding(nn.Module):
|
| 80 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 81 |
+
super().__init__()
|
| 82 |
+
|
| 83 |
+
self.dim = dim
|
| 84 |
+
self.max_position_embeddings = max_position_embeddings
|
| 85 |
+
self.base = base
|
| 86 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
| 87 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 88 |
+
|
| 89 |
+
# Build here to make `torch.jit.trace` work.
|
| 90 |
+
self._set_cos_sin_cache(
|
| 91 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 95 |
+
self.max_seq_len_cached = seq_len
|
| 96 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
| 97 |
+
|
| 98 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 99 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 100 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 101 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 102 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 103 |
+
|
| 104 |
+
def forward(self, x, seq_len=None):
|
| 105 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 106 |
+
if seq_len > self.max_seq_len_cached:
|
| 107 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 108 |
+
|
| 109 |
+
return (
|
| 110 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 111 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
| 112 |
+
)
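# A minimal sketch of the cos/sin tables cached above, for a toy configuration (dim=4,
# 6 positions); each row holds position * inv_freq, duplicated so it spans the full dim.
import torch

dim, seq_len, base = 4, 6, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))     # shape (2,)
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)           # shape (6, 2)
emb = torch.cat((freqs, freqs), dim=-1)                                # shape (6, 4)
cos_cached, sin_cached = emb.cos(), emb.sin()                          # each (seq_len, dim)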
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
|
| 116 |
+
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
| 117 |
+
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 118 |
+
|
| 119 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 120 |
+
self.scaling_factor = scaling_factor
|
| 121 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 122 |
+
|
| 123 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 124 |
+
self.max_seq_len_cached = seq_len
|
| 125 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
| 126 |
+
t = t / self.scaling_factor
|
| 127 |
+
|
| 128 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 129 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 130 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 131 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 132 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
|
| 136 |
+
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
| 137 |
+
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 140 |
+
self.scaling_factor = scaling_factor
|
| 141 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
| 142 |
+
|
| 143 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 144 |
+
self.max_seq_len_cached = seq_len
|
| 145 |
+
|
| 146 |
+
if seq_len > self.max_position_embeddings:
|
| 147 |
+
base = self.base * (
|
| 148 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
| 149 |
+
) ** (self.dim / (self.dim - 2))
|
| 150 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
| 151 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 152 |
+
|
| 153 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
| 154 |
+
|
| 155 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 156 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 157 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 158 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 159 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 163 |
+
def rotate_half(x):
|
| 164 |
+
"""Rotates half the hidden dims of the input."""
|
| 165 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 166 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 167 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
| 171 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 172 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
q (`torch.Tensor`): The query tensor.
|
| 176 |
+
k (`torch.Tensor`): The key tensor.
|
| 177 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 178 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 179 |
+
position_ids (`torch.Tensor`):
|
| 180 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 181 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 182 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 183 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 184 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 185 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 186 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 187 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 188 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 189 |
+
Returns:
|
| 190 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 191 |
+
"""
|
| 192 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 193 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 194 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 195 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 196 |
+
return q_embed, k_embed
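# A short usage sketch of the two helpers above, assuming the PhiRotaryEmbedding class defined
# earlier in this file; toy shapes: (batch=1, heads=1, seq_len=6, head_dim=4).
import torch

head_dim, seq_len = 4, 6
rotary = PhiRotaryEmbedding(head_dim, max_position_embeddings=seq_len)
q = torch.randn(1, 1, seq_len, head_dim)
k = torch.randn(1, 1, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)   # shape (1, 6)
cos, sin = rotary(q, seq_len=seq_len)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_embed.shape == q.shape and k_embed.shape == k.shape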
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
|
| 200 |
+
class PhiMLP(nn.Module):
|
| 201 |
+
def __init__(self, config):
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.config = config
|
| 204 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 205 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 206 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 207 |
+
|
| 208 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 209 |
+
hidden_states = self.fc1(hidden_states)
|
| 210 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 211 |
+
hidden_states = self.fc2(hidden_states)
|
| 212 |
+
return hidden_states
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
|
| 216 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 217 |
+
"""
|
| 218 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 219 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 220 |
+
"""
|
| 221 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 222 |
+
if n_rep == 1:
|
| 223 |
+
return hidden_states
|
| 224 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 225 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
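# Grouped-query attention expansion with the function above: two KV heads serving four query
# heads means n_rep = 2, and each KV head is materialized twice.
import torch

kv = torch.randn(1, 2, 5, 8)           # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=2)
print(expanded.shape)                  # torch.Size([1, 4, 5, 8])
assert torch.equal(expanded[:, 0], expanded[:, 1])   # head 0 and its repeat are identical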
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class PhiAttention(nn.Module):
|
| 229 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 230 |
+
|
| 231 |
+
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
|
| 232 |
+
super().__init__()
|
| 233 |
+
self.config = config
|
| 234 |
+
self.layer_idx = layer_idx
|
| 235 |
+
if layer_idx is None:
|
| 236 |
+
logger.warning_once(
|
| 237 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
| 238 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
| 239 |
+
"when creating this class."
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
self.attention_dropout = config.attention_dropout
|
| 243 |
+
self.hidden_size = config.hidden_size
|
| 244 |
+
self.num_heads = config.num_attention_heads
|
| 245 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 246 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 247 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 248 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 249 |
+
self.rope_theta = config.rope_theta
|
| 250 |
+
self.partial_rotary_factor = config.partial_rotary_factor
|
| 251 |
+
self.is_causal = True
|
| 252 |
+
|
| 253 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 254 |
+
raise ValueError(
|
| 255 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 256 |
+
f" and `num_heads`: {self.num_heads})."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
| 260 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 261 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 262 |
+
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
|
| 263 |
+
|
| 264 |
+
self.qk_layernorm = config.qk_layernorm
|
| 265 |
+
if self.qk_layernorm:
|
| 266 |
+
self.q_layernorm = nn.LayerNorm(
|
| 267 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
| 268 |
+
)
|
| 269 |
+
self.k_layernorm = nn.LayerNorm(
|
| 270 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
self._init_rope()
|
| 274 |
+
|
| 275 |
+
def _init_rope(self):
|
| 276 |
+
if self.config.rope_scaling is None:
|
| 277 |
+
self.rotary_emb = PhiRotaryEmbedding(
|
| 278 |
+
int(self.partial_rotary_factor * self.head_dim),
|
| 279 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 280 |
+
base=self.rope_theta,
|
| 281 |
+
)
|
| 282 |
+
else:
|
| 283 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 284 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 285 |
+
if scaling_type == "linear":
|
| 286 |
+
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
|
| 287 |
+
int(self.partial_rotary_factor * self.head_dim),
|
| 288 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 289 |
+
scaling_factor=scaling_factor,
|
| 290 |
+
base=self.rope_theta,
|
| 291 |
+
)
|
| 292 |
+
elif scaling_type == "dynamic":
|
| 293 |
+
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
|
| 294 |
+
int(self.partial_rotary_factor * self.head_dim),
|
| 295 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 296 |
+
scaling_factor=scaling_factor,
|
| 297 |
+
base=self.rope_theta,
|
| 298 |
+
)
|
| 299 |
+
else:
|
| 300 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 301 |
+
|
| 302 |
+
def forward(
|
| 303 |
+
self,
|
| 304 |
+
hidden_states: torch.Tensor,
|
| 305 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 306 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 307 |
+
past_key_value: Optional[Cache] = None,
|
| 308 |
+
output_attentions: bool = False,
|
| 309 |
+
use_cache: bool = False,
|
| 310 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 311 |
+
bsz, q_len, _ = hidden_states.size()
|
| 312 |
+
|
| 313 |
+
query_states = self.q_proj(hidden_states)
|
| 314 |
+
key_states = self.k_proj(hidden_states)
|
| 315 |
+
value_states = self.v_proj(hidden_states)
|
| 316 |
+
|
| 317 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 318 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 319 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 320 |
+
|
| 321 |
+
if self.qk_layernorm:
|
| 322 |
+
query_states = self.q_layernorm(query_states)
|
| 323 |
+
key_states = self.k_layernorm(key_states)
|
| 324 |
+
|
| 325 |
+
kv_seq_len = key_states.shape[-2]
|
| 326 |
+
if past_key_value is not None:
|
| 327 |
+
if self.layer_idx is None:
|
| 328 |
+
raise ValueError(
|
| 329 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
| 330 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 331 |
+
"with a layer index."
|
| 332 |
+
)
|
| 333 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 334 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 335 |
+
|
| 336 |
+
# Partial rotary embedding
|
| 337 |
+
query_rot, query_pass = (
|
| 338 |
+
query_states[..., : self.rotary_emb.dim],
|
| 339 |
+
query_states[..., self.rotary_emb.dim :],
|
| 340 |
+
)
|
| 341 |
+
key_rot, key_pass = (
|
| 342 |
+
key_states[..., : self.rotary_emb.dim],
|
| 343 |
+
key_states[..., self.rotary_emb.dim :],
|
| 344 |
+
)
|
| 345 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
| 346 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
| 347 |
+
|
| 348 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
| 349 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
| 350 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
| 351 |
+
|
| 352 |
+
if past_key_value is not None:
|
| 353 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
| 354 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 355 |
+
|
| 356 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 357 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 358 |
+
|
| 359 |
+
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
| 360 |
+
attn_weights = torch.matmul(
|
| 361 |
+
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
| 362 |
+
) / math.sqrt(self.head_dim)
|
| 363 |
+
|
| 364 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 365 |
+
raise ValueError(
|
| 366 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
| 367 |
+
f" {attn_weights.size()}"
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
if attention_mask is not None:
|
| 371 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 372 |
+
raise ValueError(
|
| 373 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
| 374 |
+
)
|
| 375 |
+
attn_weights = attn_weights + attention_mask
|
| 376 |
+
|
| 377 |
+
# upcast attention to fp32
|
| 378 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
|
| 379 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 380 |
+
|
| 381 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 382 |
+
|
| 383 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 384 |
+
raise ValueError(
|
| 385 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 386 |
+
f" {attn_output.size()}"
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 390 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 391 |
+
|
| 392 |
+
attn_output = self.dense(attn_output)
|
| 393 |
+
|
| 394 |
+
if not output_attentions:
|
| 395 |
+
attn_weights = None
|
| 396 |
+
|
| 397 |
+
return attn_output, attn_weights, past_key_value
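# A sketch of the "partial rotary" split used in the forward above: only the first
# partial_rotary_factor * head_dim channels are rotated, the rest pass through unchanged.
import torch

head_dim, rotary_dim = 8, 4            # e.g. partial_rotary_factor = 0.5
q = torch.randn(1, 2, 3, head_dim)     # (batch, heads, seq_len, head_dim)
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
# q_rot would be rotated with cos/sin via apply_rotary_pos_emb; q_pass is left untouched
q_out = torch.cat((q_rot, q_pass), dim=-1)
assert q_out.shape == q.shape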
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class PhiFlashAttention2(PhiAttention):
|
| 401 |
+
"""
|
| 402 |
+
Phi flash attention module. This module inherits from `PhiAttention`, as the weights of the module stay
|
| 403 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 404 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
| 408 |
+
def __init__(self, *args, **kwargs):
|
| 409 |
+
super().__init__(*args, **kwargs)
|
| 410 |
+
|
| 411 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 412 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 413 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 414 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 415 |
+
|
| 416 |
+
def forward(
|
| 417 |
+
self,
|
| 418 |
+
hidden_states: torch.Tensor,
|
| 419 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 420 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 421 |
+
past_key_value: Optional[Cache] = None,
|
| 422 |
+
output_attentions: bool = False,
|
| 423 |
+
use_cache: bool = False,
|
| 424 |
+
**kwargs,
|
| 425 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 426 |
+
# PhiFlashAttention2 attention does not support output_attentions
|
| 427 |
+
|
| 428 |
+
output_attentions = False
|
| 429 |
+
|
| 430 |
+
bsz, q_len, _ = hidden_states.size()
|
| 431 |
+
|
| 432 |
+
query_states = self.q_proj(hidden_states)
|
| 433 |
+
key_states = self.k_proj(hidden_states)
|
| 434 |
+
value_states = self.v_proj(hidden_states)
|
| 435 |
+
|
| 436 |
+
# Flash attention requires the input to have the shape
|
| 437 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 438 |
+
# therefore we just need to keep the original shape
|
| 439 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 440 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 441 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 442 |
+
|
| 443 |
+
if self.qk_layernorm:
|
| 444 |
+
query_states = self.q_layernorm(query_states)
|
| 445 |
+
key_states = self.k_layernorm(key_states)
|
| 446 |
+
|
| 447 |
+
kv_seq_len = key_states.shape[-2]
|
| 448 |
+
if past_key_value is not None:
|
| 449 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 450 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 451 |
+
|
| 452 |
+
# Partial rotary embedding
|
| 453 |
+
query_rot, query_pass = (
|
| 454 |
+
query_states[..., : self.rotary_emb.dim],
|
| 455 |
+
query_states[..., self.rotary_emb.dim :],
|
| 456 |
+
)
|
| 457 |
+
key_rot, key_pass = (
|
| 458 |
+
key_states[..., : self.rotary_emb.dim],
|
| 459 |
+
key_states[..., self.rotary_emb.dim :],
|
| 460 |
+
)
|
| 461 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
| 462 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
| 463 |
+
|
| 464 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
| 465 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
| 466 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
| 467 |
+
|
| 468 |
+
if past_key_value is not None:
|
| 469 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
| 470 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 471 |
+
|
| 472 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 473 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 474 |
+
query_states = query_states.transpose(1, 2)
|
| 475 |
+
key_states = key_states.transpose(1, 2)
|
| 476 |
+
value_states = value_states.transpose(1, 2)
|
| 477 |
+
|
| 478 |
+
attn_dropout = self.attention_dropout if self.training else 0.0
|
| 479 |
+
|
| 480 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 481 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 482 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 483 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 484 |
+
# in fp32.
|
| 485 |
+
|
| 486 |
+
if query_states.dtype == torch.float32:
|
| 487 |
+
if torch.is_autocast_enabled():
|
| 488 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 489 |
+
# Handle the case where the model is quantized
|
| 490 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 491 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 492 |
+
else:
|
| 493 |
+
target_dtype = self.q_proj.weight.dtype
|
| 494 |
+
|
| 495 |
+
logger.warning_once(
|
| 496 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 497 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 498 |
+
f" {target_dtype}."
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
query_states = query_states.to(target_dtype)
|
| 502 |
+
key_states = key_states.to(target_dtype)
|
| 503 |
+
value_states = value_states.to(target_dtype)
|
| 504 |
+
|
| 505 |
+
attn_output = self._flash_attention_forward(
|
| 506 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 510 |
+
attn_output = self.dense(attn_output)
|
| 511 |
+
|
| 512 |
+
if not output_attentions:
|
| 513 |
+
attn_weights = None
|
| 514 |
+
|
| 515 |
+
return attn_output, attn_weights, past_key_value
|
| 516 |
+
|
| 517 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
| 518 |
+
def _flash_attention_forward(
|
| 519 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
| 520 |
+
):
|
| 521 |
+
"""
|
| 522 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
| 523 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
query_states (`torch.Tensor`):
|
| 527 |
+
Input query states to be passed to Flash Attention API
|
| 528 |
+
key_states (`torch.Tensor`):
|
| 529 |
+
Input key states to be passed to Flash Attention API
|
| 530 |
+
value_states (`torch.Tensor`):
|
| 531 |
+
Input value states to be passed to Flash Attention API
|
| 532 |
+
attention_mask (`torch.Tensor`):
|
| 533 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
| 534 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
| 535 |
+
dropout (`float`):
|
| 536 |
+
Attention dropout
|
| 537 |
+
softmax_scale (`float`, *optional*):
|
| 538 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
| 539 |
+
"""
|
| 540 |
+
if not self._flash_attn_uses_top_left_mask:
|
| 541 |
+
causal = self.is_causal
|
| 542 |
+
else:
|
| 543 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
| 544 |
+
causal = self.is_causal and query_length != 1
|
| 545 |
+
|
| 546 |
+
# Contains at least one padding token in the sequence
|
| 547 |
+
if attention_mask is not None:
|
| 548 |
+
batch_size = query_states.shape[0]
|
| 549 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
| 550 |
+
query_states, key_states, value_states, attention_mask, query_length
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
| 554 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
| 555 |
+
|
| 556 |
+
attn_output_unpad = flash_attn_varlen_func(
|
| 557 |
+
query_states,
|
| 558 |
+
key_states,
|
| 559 |
+
value_states,
|
| 560 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 561 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 562 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
| 563 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
| 564 |
+
dropout_p=dropout,
|
| 565 |
+
softmax_scale=softmax_scale,
|
| 566 |
+
causal=causal,
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
| 570 |
+
else:
|
| 571 |
+
attn_output = flash_attn_func(
|
| 572 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
return attn_output
|
| 576 |
+
|
| 577 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
| 578 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
| 579 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
| 580 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
| 581 |
+
|
| 582 |
+
key_layer = index_first_axis(
|
| 583 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
| 584 |
+
)
|
| 585 |
+
value_layer = index_first_axis(
|
| 586 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
| 587 |
+
)
|
| 588 |
+
if query_length == kv_seq_len:
|
| 589 |
+
query_layer = index_first_axis(
|
| 590 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
| 591 |
+
)
|
| 592 |
+
cu_seqlens_q = cu_seqlens_k
|
| 593 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
| 594 |
+
indices_q = indices_k
|
| 595 |
+
elif query_length == 1:
|
| 596 |
+
max_seqlen_in_batch_q = 1
|
| 597 |
+
cu_seqlens_q = torch.arange(
|
| 598 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
| 599 |
+
) # There is a memcpy here, that is very bad.
|
| 600 |
+
indices_q = cu_seqlens_q[:-1]
|
| 601 |
+
query_layer = query_layer.squeeze(1)
|
| 602 |
+
else:
|
| 603 |
+
# The -q_len: slice assumes left padding.
|
| 604 |
+
attention_mask = attention_mask[:, -query_length:]
|
| 605 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
| 606 |
+
|
| 607 |
+
return (
|
| 608 |
+
query_layer,
|
| 609 |
+
key_layer,
|
| 610 |
+
value_layer,
|
| 611 |
+
indices_q,
|
| 612 |
+
(cu_seqlens_q, cu_seqlens_k),
|
| 613 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class PhiSdpaAttention(PhiAttention):
|
| 618 |
+
def __init__(self, *args, **kwargs):
|
| 619 |
+
super().__init__(*args, **kwargs)
|
| 620 |
+
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
| 621 |
+
|
| 622 |
+
"""
|
| 623 |
+
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 624 |
+
`PhiAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
|
| 625 |
+
SDPA API.
|
| 626 |
+
"""
|
| 627 |
+
|
| 628 |
+
# Adapted from PhiAttention.forward
|
| 629 |
+
def forward(
|
| 630 |
+
self,
|
| 631 |
+
hidden_states: torch.Tensor,
|
| 632 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 633 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 634 |
+
past_key_value: Optional[Cache] = None,
|
| 635 |
+
output_attentions: bool = False,
|
| 636 |
+
use_cache: bool = False,
|
| 637 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 638 |
+
if output_attentions:
|
| 639 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 640 |
+
logger.warning_once(
|
| 641 |
+
"PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
|
| 642 |
+
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
|
| 643 |
+
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
|
| 644 |
+
'be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 645 |
+
)
|
| 646 |
+
return super().forward(
|
| 647 |
+
hidden_states=hidden_states,
|
| 648 |
+
attention_mask=attention_mask,
|
| 649 |
+
position_ids=position_ids,
|
| 650 |
+
past_key_value=past_key_value,
|
| 651 |
+
output_attentions=output_attentions,
|
| 652 |
+
use_cache=use_cache,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
bsz, q_len, _ = hidden_states.size()
|
| 656 |
+
|
| 657 |
+
query_states = self.q_proj(hidden_states)
|
| 658 |
+
key_states = self.k_proj(hidden_states)
|
| 659 |
+
value_states = self.v_proj(hidden_states)
|
| 660 |
+
|
| 661 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 662 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 663 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 664 |
+
|
| 665 |
+
if self.qk_layernorm:
|
| 666 |
+
query_states = self.q_layernorm(query_states)
|
| 667 |
+
key_states = self.k_layernorm(key_states)
|
| 668 |
+
|
| 669 |
+
kv_seq_len = key_states.shape[-2]
|
| 670 |
+
if past_key_value is not None:
|
| 671 |
+
if self.layer_idx is None:
|
| 672 |
+
raise ValueError(
|
| 673 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
| 674 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 675 |
+
"with a layer index."
|
| 676 |
+
)
|
| 677 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 678 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 679 |
+
|
| 680 |
+
# Partial rotary embedding
|
| 681 |
+
query_rot, query_pass = (
|
| 682 |
+
query_states[..., : self.rotary_emb.dim],
|
| 683 |
+
query_states[..., self.rotary_emb.dim :],
|
| 684 |
+
)
|
| 685 |
+
key_rot, key_pass = (
|
| 686 |
+
key_states[..., : self.rotary_emb.dim],
|
| 687 |
+
key_states[..., self.rotary_emb.dim :],
|
| 688 |
+
)
|
| 689 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
| 690 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
| 691 |
+
|
| 692 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
| 693 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
| 694 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
| 695 |
+
|
| 696 |
+
if past_key_value is not None:
|
| 697 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
| 698 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 699 |
+
|
| 700 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 701 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 702 |
+
|
| 703 |
+
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
| 704 |
+
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
| 705 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
| 706 |
+
if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
|
| 707 |
+
query_states = query_states.contiguous()
|
| 708 |
+
key_states = key_states.contiguous()
|
| 709 |
+
value_states = value_states.contiguous()
|
| 710 |
+
|
| 711 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 712 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 713 |
+
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
|
| 714 |
+
|
| 715 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 716 |
+
query_states,
|
| 717 |
+
key_states,
|
| 718 |
+
value_states,
|
| 719 |
+
attn_mask=attention_mask,
|
| 720 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 721 |
+
is_causal=is_causal,
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 725 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 726 |
+
|
| 727 |
+
attn_output = self.dense(attn_output)
|
| 728 |
+
|
| 729 |
+
return attn_output, None, past_key_value
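# The core call above, in isolation on toy tensors (shapes are illustrative):
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 6, 16)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 6, 16)
v = torch.randn(1, 4, 6, 16)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)               # torch.Size([1, 4, 6, 16])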
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
PHI_ATTENTION_CLASSES = {
|
| 733 |
+
"eager": PhiAttention,
|
| 734 |
+
"flash_attention_2": PhiFlashAttention2,
|
| 735 |
+
"sdpa": PhiSdpaAttention,
|
| 736 |
+
}
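# A small sketch of how this mapping is used, assuming PhiConfig defaults for everything not
# set explicitly (a real config would come from the checkpoint):
from transformers.models.phi.configuration_phi import PhiConfig

cfg = PhiConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4,
                intermediate_size=128, num_hidden_layers=2)
attn_cls = PHI_ATTENTION_CLASSES["eager"]    # -> the PhiAttention class defined above
self_attn = attn_cls(cfg, layer_idx=0)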
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
class PhiDecoderLayer(nn.Module):
|
| 740 |
+
def __init__(self, config: PhiConfig, layer_idx: int):
|
| 741 |
+
super().__init__()
|
| 742 |
+
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
| 743 |
+
self.mlp = PhiMLP(config)
|
| 744 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 745 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 746 |
+
|
| 747 |
+
def forward(
|
| 748 |
+
self,
|
| 749 |
+
hidden_states: torch.Tensor,
|
| 750 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 751 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 752 |
+
output_attentions: Optional[bool] = False,
|
| 753 |
+
use_cache: Optional[bool] = False,
|
| 754 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 755 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 756 |
+
"""
|
| 757 |
+
Args:
|
| 758 |
+
hidden_states (`torch.FloatTensor`):
|
| 759 |
+
input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 760 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 761 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 762 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 763 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
|
| 764 |
+
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
| 765 |
+
output_attentions (`bool`, *optional*):
|
| 766 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 767 |
+
returned tensors for more detail.
|
| 768 |
+
use_cache (`bool`, *optional*):
|
| 769 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 770 |
+
(see `past_key_values`).
|
| 771 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 772 |
+
"""
|
| 773 |
+
|
| 774 |
+
residual = hidden_states
|
| 775 |
+
|
| 776 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 777 |
+
|
| 778 |
+
# Self Attention
|
| 779 |
+
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
|
| 780 |
+
hidden_states=hidden_states,
|
| 781 |
+
attention_mask=attention_mask,
|
| 782 |
+
position_ids=position_ids,
|
| 783 |
+
past_key_value=past_key_value,
|
| 784 |
+
output_attentions=output_attentions,
|
| 785 |
+
use_cache=use_cache,
|
| 786 |
+
)
|
| 787 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
| 788 |
+
|
| 789 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
| 790 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
| 791 |
+
outputs = (hidden_states,)
|
| 792 |
+
|
| 793 |
+
if output_attentions:
|
| 794 |
+
outputs += (self_attn_weights,)
|
| 795 |
+
|
| 796 |
+
if use_cache:
|
| 797 |
+
outputs += (present_key_value,)
|
| 798 |
+
|
| 799 |
+
return outputs
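
Worth noting while reading this diff: `PhiDecoderLayer.forward` above is a *parallel* residual block — attention and the MLP both consume the same layer-normalized input, and their outputs are added back onto the untouched residual, rather than being chained one after the other. The snippet below is a minimal, self-contained sketch of that wiring only; it uses stock `nn.MultiheadAttention` and `nn.Linear` stand-ins with made-up sizes, not the actual `PhiAttention`/`PhiMLP` modules from this file.

```python
import torch
import torch.nn as nn

class ParallelResidualBlock(nn.Module):
    """Illustrative only: same residual wiring as PhiDecoderLayer, simplified modules."""
    def __init__(self, hidden_size=64, num_heads=4):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, hidden_states):
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        mlp_out = self.mlp(normed)             # computed from the same normed input
        return attn_out + mlp_out + residual   # parallel residual, as in the layer above

x = torch.randn(2, 5, 64)
print(ParallelResidualBlock()(x).shape)  # torch.Size([2, 5, 64])
```
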
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
PHI_START_DOCSTRING = r"""
|
| 803 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 804 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 805 |
+
etc.)
|
| 806 |
+
|
| 807 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 808 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 809 |
+
and behavior.
|
| 810 |
+
|
| 811 |
+
Parameters:
|
| 812 |
+
config ([`PhiConfig`]):
|
| 813 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 814 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 815 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
@add_start_docstrings(
|
| 820 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
| 821 |
+
PHI_START_DOCSTRING,
|
| 822 |
+
)
|
| 823 |
+
class PhiPreTrainedModel(PreTrainedModel):
|
| 824 |
+
config_class = PhiConfig
|
| 825 |
+
base_model_prefix = "model"
|
| 826 |
+
supports_gradient_checkpointing = True
|
| 827 |
+
_no_split_modules = ["PhiDecoderLayer"]
|
| 828 |
+
_skip_keys_device_placement = "past_key_values"
|
| 829 |
+
_supports_flash_attn_2 = True
|
| 830 |
+
_supports_sdpa = True
|
| 831 |
+
_supports_cache_class = True
|
| 832 |
+
|
| 833 |
+
def _init_weights(self, module):
|
| 834 |
+
std = self.config.initializer_range
|
| 835 |
+
if isinstance(module, nn.Linear):
|
| 836 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 837 |
+
if module.bias is not None:
|
| 838 |
+
module.bias.data.zero_()
|
| 839 |
+
elif isinstance(module, nn.Embedding):
|
| 840 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 841 |
+
if module.padding_idx is not None:
|
| 842 |
+
module.weight.data[module.padding_idx].zero_()
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
PHI_INPUTS_DOCSTRING = r"""
|
| 846 |
+
Args:
|
| 847 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 848 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 849 |
+
it.
|
| 850 |
+
|
| 851 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 852 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 853 |
+
|
| 854 |
+
[What are input IDs?](../glossary#input-ids)
|
| 855 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 856 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 857 |
+
|
| 858 |
+
- 1 for tokens that are **not masked**,
|
| 859 |
+
- 0 for tokens that are **masked**.
|
| 860 |
+
|
| 861 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 862 |
+
|
| 863 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 864 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 865 |
+
|
| 866 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 867 |
+
`past_key_values`).
|
| 868 |
+
|
| 869 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 870 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 871 |
+
information on the default strategy.
|
| 872 |
+
|
| 873 |
+
- 1 indicates the head is **not masked**,
|
| 874 |
+
- 0 indicates the head is **masked**.
|
| 875 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 876 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 877 |
+
config.n_positions - 1]`.
|
| 878 |
+
|
| 879 |
+
[What are position IDs?](../glossary#position-ids)
|
| 880 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 881 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 882 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 883 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 884 |
+
|
| 885 |
+
Two formats are allowed:
|
| 886 |
+
- a [`~cache_utils.Cache`] instance;
|
| 887 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 888 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 889 |
+
cache format.
|
| 890 |
+
|
| 891 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 892 |
+
legacy cache format will be returned.
|
| 893 |
+
|
| 894 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 895 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 896 |
+
of shape `(batch_size, sequence_length)`.
|
| 897 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 898 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 899 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 900 |
+
model's internal embedding lookup matrix.
|
| 901 |
+
use_cache (`bool`, *optional*):
|
| 902 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 903 |
+
`past_key_values`).
|
| 904 |
+
output_attentions (`bool`, *optional*):
|
| 905 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 906 |
+
tensors for more detail.
|
| 907 |
+
output_hidden_states (`bool`, *optional*):
|
| 908 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 909 |
+
more detail.
|
| 910 |
+
return_dict (`bool`, *optional*):
|
| 911 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 912 |
+
"""
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
@add_start_docstrings(
|
| 916 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
| 917 |
+
PHI_START_DOCSTRING,
|
| 918 |
+
)
|
| 919 |
+
class PhiModel(PhiPreTrainedModel):
|
| 920 |
+
"""
|
| 921 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
|
| 922 |
+
|
| 923 |
+
Args:
|
| 924 |
+
config: PhiConfig
|
| 925 |
+
"""
|
| 926 |
+
|
| 927 |
+
def __init__(self, config: PhiConfig):
|
| 928 |
+
super().__init__(config)
|
| 929 |
+
self.padding_idx = config.pad_token_id
|
| 930 |
+
self.vocab_size = config.vocab_size
|
| 931 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 932 |
+
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
| 933 |
+
print("attention implementation: ", config._attn_implementation)
|
| 934 |
+
self.layers = nn.ModuleList(
|
| 935 |
+
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 936 |
+
)
|
| 937 |
+
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 938 |
+
|
| 939 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 940 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 941 |
+
|
| 942 |
+
self.gradient_checkpointing = False
|
| 943 |
+
# Initialize weights and apply final processing
|
| 944 |
+
self.post_init()
|
| 945 |
+
|
| 946 |
+
def get_input_embeddings(self):
|
| 947 |
+
return self.embed_tokens
|
| 948 |
+
|
| 949 |
+
def set_input_embeddings(self, value):
|
| 950 |
+
self.embed_tokens = value
|
| 951 |
+
|
| 952 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
| 953 |
+
def forward(
|
| 954 |
+
self,
|
| 955 |
+
input_ids: torch.LongTensor = None,
|
| 956 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 957 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 958 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 959 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 960 |
+
use_cache: Optional[bool] = None,
|
| 961 |
+
output_attentions: Optional[bool] = None,
|
| 962 |
+
output_hidden_states: Optional[bool] = None,
|
| 963 |
+
return_dict: Optional[bool] = None,
|
| 964 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 965 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 966 |
+
output_hidden_states = (
|
| 967 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 968 |
+
)
|
| 969 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 970 |
+
|
| 971 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 972 |
+
|
| 973 |
+
# retrieve input_ids and inputs_embeds
|
| 974 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 975 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 976 |
+
elif input_ids is not None:
|
| 977 |
+
batch_size, seq_length = input_ids.shape[:2]
|
| 978 |
+
elif inputs_embeds is not None:
|
| 979 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
| 980 |
+
else:
|
| 981 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 982 |
+
|
| 983 |
+
past_key_values_length = 0
|
| 984 |
+
|
| 985 |
+
if self.gradient_checkpointing and self.training:
|
| 986 |
+
if use_cache:
|
| 987 |
+
logger.warning_once(
|
| 988 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 989 |
+
)
|
| 990 |
+
use_cache = False
|
| 991 |
+
|
| 992 |
+
if use_cache:
|
| 993 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 994 |
+
if use_legacy_cache:
|
| 995 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 996 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
| 997 |
+
|
| 998 |
+
if position_ids is None:
|
| 999 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 1000 |
+
position_ids = torch.arange(
|
| 1001 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
| 1002 |
+
)
|
| 1003 |
+
position_ids = position_ids.unsqueeze(0)
|
| 1004 |
+
|
| 1005 |
+
if inputs_embeds is None:
|
| 1006 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 1007 |
+
|
| 1008 |
+
inputs_embeds = self.embed_dropout(inputs_embeds)
|
| 1009 |
+
# NOTE: the default attention-mask preparation below is intentionally left disabled;
# the mask passed in by the caller is forwarded to the decoder layers as-is.
|
| 1010 |
+
# Attention mask.
|
| 1011 |
+
# if self._use_flash_attention_2:
|
| 1012 |
+
# # 2d mask is passed through the layers
|
| 1013 |
+
# attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
| 1014 |
+
# elif self._use_sdpa and not output_attentions:
|
| 1015 |
+
# attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 1016 |
+
# attention_mask,
|
| 1017 |
+
# (batch_size, seq_length),
|
| 1018 |
+
# inputs_embeds,
|
| 1019 |
+
# past_key_values_length,
|
| 1020 |
+
# )
|
| 1021 |
+
# else:
|
| 1022 |
+
# # 4d mask is passed through the layers
|
| 1023 |
+
# attention_mask = _prepare_4d_causal_attention_mask(
|
| 1024 |
+
# attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
| 1025 |
+
# )
|
| 1026 |
+
# end of the disabled attention-mask preparation block
|
| 1027 |
+
|
| 1028 |
+
hidden_states = inputs_embeds
|
| 1029 |
+
|
| 1030 |
+
# decoder layers
|
| 1031 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1032 |
+
all_self_attns = () if output_attentions else None
|
| 1033 |
+
next_decoder_cache = None
|
| 1034 |
+
for decoder_layer in self.layers:
|
| 1035 |
+
if output_hidden_states:
|
| 1036 |
+
all_hidden_states += (hidden_states,)
|
| 1037 |
+
|
| 1038 |
+
if self.gradient_checkpointing and self.training:
|
| 1039 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1040 |
+
decoder_layer.__call__,
|
| 1041 |
+
hidden_states,
|
| 1042 |
+
attention_mask,
|
| 1043 |
+
position_ids,
|
| 1044 |
+
past_key_values,
|
| 1045 |
+
output_attentions,
|
| 1046 |
+
)
|
| 1047 |
+
else:
|
| 1048 |
+
layer_outputs = decoder_layer(
|
| 1049 |
+
hidden_states,
|
| 1050 |
+
attention_mask=attention_mask,
|
| 1051 |
+
position_ids=position_ids,
|
| 1052 |
+
past_key_value=past_key_values,
|
| 1053 |
+
output_attentions=output_attentions,
|
| 1054 |
+
use_cache=use_cache,
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
hidden_states = layer_outputs[0]
|
| 1058 |
+
|
| 1059 |
+
if use_cache:
|
| 1060 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 1061 |
+
|
| 1062 |
+
if output_attentions:
|
| 1063 |
+
all_self_attns += (layer_outputs[1],)
|
| 1064 |
+
|
| 1065 |
+
hidden_states = self.final_layernorm(hidden_states)
|
| 1066 |
+
|
| 1067 |
+
# add hidden states from the last decoder layer
|
| 1068 |
+
if output_hidden_states:
|
| 1069 |
+
all_hidden_states += (hidden_states,)
|
| 1070 |
+
|
| 1071 |
+
next_cache = None
|
| 1072 |
+
if use_cache:
|
| 1073 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
| 1074 |
+
if not return_dict:
|
| 1075 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 1076 |
+
return BaseModelOutputWithPast(
|
| 1077 |
+
last_hidden_state=hidden_states,
|
| 1078 |
+
past_key_values=next_cache,
|
| 1079 |
+
hidden_states=all_hidden_states,
|
| 1080 |
+
attentions=all_self_attns,
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
|
| 1084 |
+
class PhiForCausalLM(PhiPreTrainedModel):
|
| 1085 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1086 |
+
def __init__(self, config):
|
| 1087 |
+
super().__init__(config)
|
| 1088 |
+
config.qk_layernorm = True
|
| 1089 |
+
config.use_cache = False
|
| 1090 |
+
self.model = PhiModel(config)
|
| 1091 |
+
self.vocab_size = config.vocab_size
|
| 1092 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
| 1093 |
+
|
| 1094 |
+
# Initialize weights and apply final processing
|
| 1095 |
+
self.post_init()
|
| 1096 |
+
|
| 1097 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
|
| 1098 |
+
def get_input_embeddings(self):
|
| 1099 |
+
return self.model.embed_tokens
|
| 1100 |
+
|
| 1101 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
|
| 1102 |
+
def set_input_embeddings(self, value):
|
| 1103 |
+
self.model.embed_tokens = value
|
| 1104 |
+
|
| 1105 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
|
| 1106 |
+
def get_output_embeddings(self):
|
| 1107 |
+
return self.lm_head
|
| 1108 |
+
|
| 1109 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
|
| 1110 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1111 |
+
self.lm_head = new_embeddings
|
| 1112 |
+
|
| 1113 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
| 1114 |
+
def set_decoder(self, decoder):
|
| 1115 |
+
self.model = decoder
|
| 1116 |
+
|
| 1117 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
| 1118 |
+
def get_decoder(self):
|
| 1119 |
+
return self.model
|
| 1120 |
+
|
| 1121 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
| 1122 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1123 |
+
def forward(
|
| 1124 |
+
self,
|
| 1125 |
+
input_ids: torch.LongTensor = None,
|
| 1126 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1127 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1128 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1129 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1130 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1131 |
+
use_cache: Optional[bool] = None,
|
| 1132 |
+
output_attentions: Optional[bool] = None,
|
| 1133 |
+
output_hidden_states: Optional[bool] = None,
|
| 1134 |
+
return_dict: Optional[bool] = None,
|
| 1135 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1136 |
+
r"""
|
| 1137 |
+
Args:
|
| 1138 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1139 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1140 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1141 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1142 |
+
|
| 1143 |
+
Returns:
|
| 1144 |
+
|
| 1145 |
+
Example:
|
| 1146 |
+
|
| 1147 |
+
```python
|
| 1148 |
+
>>> from transformers import AutoTokenizer, PhiForCausalLM
|
| 1149 |
+
|
| 1150 |
+
>>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
|
| 1151 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
|
| 1152 |
+
|
| 1153 |
+
>>> prompt = "This is an example script ."
|
| 1154 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1155 |
+
|
| 1156 |
+
>>> # Generate
|
| 1157 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1158 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1159 |
+
'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
|
| 1160 |
+
```"""
|
| 1161 |
+
|
| 1162 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1163 |
+
output_hidden_states = (
|
| 1164 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1165 |
+
)
|
| 1166 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1167 |
+
|
| 1168 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1169 |
+
outputs = self.model(
|
| 1170 |
+
input_ids=input_ids,
|
| 1171 |
+
attention_mask=attention_mask,
|
| 1172 |
+
position_ids=position_ids,
|
| 1173 |
+
past_key_values=past_key_values,
|
| 1174 |
+
inputs_embeds=inputs_embeds,
|
| 1175 |
+
use_cache=use_cache,
|
| 1176 |
+
output_attentions=output_attentions,
|
| 1177 |
+
output_hidden_states=output_hidden_states,
|
| 1178 |
+
return_dict=return_dict,
|
| 1179 |
+
)
|
| 1180 |
+
|
| 1181 |
+
hidden_states = outputs[0]
|
| 1182 |
+
logits = self.lm_head(hidden_states)
|
| 1183 |
+
logits = logits.float()
|
| 1184 |
+
|
| 1185 |
+
loss = None
|
| 1186 |
+
if labels is not None:
|
| 1187 |
+
# Shift so that tokens < n predict n
|
| 1188 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1189 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1190 |
+
# Flatten the tokens
|
| 1191 |
+
loss_fct = CrossEntropyLoss()
|
| 1192 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1193 |
+
shift_labels = shift_labels.view(-1)
|
| 1194 |
+
# Enable model parallelism
|
| 1195 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1196 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1197 |
+
|
| 1198 |
+
if not return_dict:
|
| 1199 |
+
output = (logits,) + outputs[1:]
|
| 1200 |
+
return (loss,) + output if loss is not None else output
|
| 1201 |
+
|
| 1202 |
+
return CausalLMOutputWithPast(
|
| 1203 |
+
loss=loss,
|
| 1204 |
+
logits=logits,
|
| 1205 |
+
past_key_values=outputs.past_key_values,
|
| 1206 |
+
hidden_states=outputs.hidden_states,
|
| 1207 |
+
attentions=outputs.attentions,
|
| 1208 |
+
)
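
The loss branch of `PhiForCausalLM.forward` above is plain next-token prediction: logits are dropped at the last position, labels at the first, and a flat cross-entropy is taken over the rest. A tiny standalone sketch of that shift on toy tensors (shapes and values are made up) for anyone checking the indexing:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(2, 6, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 6))   # same shape as input_ids

# position t predicts token t+1, so drop the last logit and the first label
shift_logits = logits[..., :-1, :].contiguous()  # (2, 5, vocab)
shift_labels = labels[..., 1:].contiguous()      # (2, 5)

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())
```
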
|
| 1209 |
+
|
| 1210 |
+
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
|
| 1211 |
+
def prepare_inputs_for_generation(
|
| 1212 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1213 |
+
):
|
| 1214 |
+
if past_key_values is not None:
|
| 1215 |
+
if isinstance(past_key_values, Cache):
|
| 1216 |
+
cache_length = past_key_values.get_seq_length()
|
| 1217 |
+
past_length = past_key_values.seen_tokens
|
| 1218 |
+
max_cache_length = past_key_values.get_max_length()
|
| 1219 |
+
else:
|
| 1220 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1221 |
+
max_cache_length = None
|
| 1222 |
+
|
| 1223 |
+
# Keep only the unprocessed tokens:
|
| 1224 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 1225 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 1226 |
+
# input)
|
| 1227 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 1228 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 1229 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 1230 |
+
# input_ids based on the past_length.
|
| 1231 |
+
elif past_length < input_ids.shape[1]:
|
| 1232 |
+
input_ids = input_ids[:, past_length:]
|
| 1233 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 1234 |
+
|
| 1235 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 1236 |
+
if (
|
| 1237 |
+
max_cache_length is not None
|
| 1238 |
+
and attention_mask is not None
|
| 1239 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
| 1240 |
+
):
|
| 1241 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
| 1242 |
+
|
| 1243 |
+
position_ids = kwargs.get("position_ids", None)
|
| 1244 |
+
if attention_mask is not None and position_ids is None:
|
| 1245 |
+
# create position_ids on the fly for batch generation
|
| 1246 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1247 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1248 |
+
if past_key_values:
|
| 1249 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 1250 |
+
|
| 1251 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1252 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 1253 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 1254 |
+
else:
|
| 1255 |
+
model_inputs = {"input_ids": input_ids}
|
| 1256 |
+
|
| 1257 |
+
model_inputs.update(
|
| 1258 |
+
{
|
| 1259 |
+
"position_ids": position_ids,
|
| 1260 |
+
"past_key_values": past_key_values,
|
| 1261 |
+
"use_cache": kwargs.get("use_cache"),
|
| 1262 |
+
"attention_mask": attention_mask,
|
| 1263 |
+
}
|
| 1264 |
+
)
|
| 1265 |
+
return model_inputs
|
| 1266 |
+
|
| 1267 |
+
@staticmethod
|
| 1268 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
| 1269 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 1270 |
+
reordered_past = ()
|
| 1271 |
+
for layer_past in past_key_values:
|
| 1272 |
+
reordered_past += (
|
| 1273 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 1274 |
+
)
|
| 1275 |
+
return reordered_past
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
@add_start_docstrings(
|
| 1279 |
+
"""
|
| 1280 |
+
The PhiModel with a sequence classification head on top (linear layer).
|
| 1281 |
+
|
| 1282 |
+
[`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 1283 |
+
(e.g. GPT-2) do.
|
| 1284 |
+
|
| 1285 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
| 1286 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 1287 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 1288 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 1289 |
+
each row of the batch).
|
| 1290 |
+
""",
|
| 1291 |
+
PHI_START_DOCSTRING,
|
| 1292 |
+
)
|
| 1293 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
|
| 1294 |
+
class PhiForSequenceClassification(PhiPreTrainedModel):
|
| 1295 |
+
def __init__(self, config):
|
| 1296 |
+
super().__init__(config)
|
| 1297 |
+
self.num_labels = config.num_labels
|
| 1298 |
+
self.model = PhiModel(config)
|
| 1299 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1300 |
+
|
| 1301 |
+
# Initialize weights and apply final processing
|
| 1302 |
+
self.post_init()
|
| 1303 |
+
|
| 1304 |
+
def get_input_embeddings(self):
|
| 1305 |
+
return self.model.embed_tokens
|
| 1306 |
+
|
| 1307 |
+
def set_input_embeddings(self, value):
|
| 1308 |
+
self.model.embed_tokens = value
|
| 1309 |
+
|
| 1310 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
| 1311 |
+
def forward(
|
| 1312 |
+
self,
|
| 1313 |
+
input_ids: torch.LongTensor = None,
|
| 1314 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1315 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1316 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 1317 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1318 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1319 |
+
use_cache: Optional[bool] = None,
|
| 1320 |
+
output_attentions: Optional[bool] = None,
|
| 1321 |
+
output_hidden_states: Optional[bool] = None,
|
| 1322 |
+
return_dict: Optional[bool] = None,
|
| 1323 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
| 1324 |
+
r"""
|
| 1325 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1326 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1327 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1328 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1329 |
+
"""
|
| 1330 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1331 |
+
|
| 1332 |
+
model_outputs = self.model(
|
| 1333 |
+
input_ids,
|
| 1334 |
+
attention_mask=attention_mask,
|
| 1335 |
+
position_ids=position_ids,
|
| 1336 |
+
past_key_values=past_key_values,
|
| 1337 |
+
inputs_embeds=inputs_embeds,
|
| 1338 |
+
use_cache=use_cache,
|
| 1339 |
+
output_attentions=output_attentions,
|
| 1340 |
+
output_hidden_states=output_hidden_states,
|
| 1341 |
+
return_dict=return_dict,
|
| 1342 |
+
)
|
| 1343 |
+
hidden_states = model_outputs[0]
|
| 1344 |
+
logits = self.score(hidden_states)
|
| 1345 |
+
|
| 1346 |
+
if input_ids is not None:
|
| 1347 |
+
batch_size = input_ids.shape[0]
|
| 1348 |
+
else:
|
| 1349 |
+
batch_size = inputs_embeds.shape[0]
|
| 1350 |
+
|
| 1351 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
| 1352 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 1353 |
+
if self.config.pad_token_id is None:
|
| 1354 |
+
sequence_lengths = -1
|
| 1355 |
+
else:
|
| 1356 |
+
if input_ids is not None:
|
| 1357 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
| 1358 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
| 1359 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
| 1360 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
| 1361 |
+
else:
|
| 1362 |
+
sequence_lengths = -1
|
| 1363 |
+
|
| 1364 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
| 1365 |
+
|
| 1366 |
+
loss = None
|
| 1367 |
+
if labels is not None:
|
| 1368 |
+
labels = labels.to(logits.device)
|
| 1369 |
+
if self.config.problem_type is None:
|
| 1370 |
+
if self.num_labels == 1:
|
| 1371 |
+
self.config.problem_type = "regression"
|
| 1372 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1373 |
+
self.config.problem_type = "single_label_classification"
|
| 1374 |
+
else:
|
| 1375 |
+
self.config.problem_type = "multi_label_classification"
|
| 1376 |
+
|
| 1377 |
+
if self.config.problem_type == "regression":
|
| 1378 |
+
loss_fct = MSELoss()
|
| 1379 |
+
if self.num_labels == 1:
|
| 1380 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 1381 |
+
else:
|
| 1382 |
+
loss = loss_fct(pooled_logits, labels)
|
| 1383 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1384 |
+
loss_fct = CrossEntropyLoss()
|
| 1385 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 1386 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1387 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1388 |
+
loss = loss_fct(pooled_logits, labels)
|
| 1389 |
+
if not return_dict:
|
| 1390 |
+
output = (pooled_logits,) + model_outputs[1:]
|
| 1391 |
+
return ((loss,) + output) if loss is not None else output
|
| 1392 |
+
|
| 1393 |
+
return SequenceClassifierOutputWithPast(
|
| 1394 |
+
loss=loss,
|
| 1395 |
+
logits=pooled_logits,
|
| 1396 |
+
past_key_values=model_outputs.past_key_values,
|
| 1397 |
+
hidden_states=model_outputs.hidden_states,
|
| 1398 |
+
attentions=model_outputs.attentions,
|
| 1399 |
+
)
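
The pooling in `PhiForSequenceClassification.forward` above classifies from the hidden state of the last non-padding token in each row, falling back to the final position when no `pad_token_id` is configured. The following toy sketch reproduces just that index computation, with a hypothetical pad id of 0:

```python
import torch

pad_token_id = 0  # hypothetical pad id for the sketch
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 2, 8]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# index of the first pad token minus one == last real token; the modulo keeps
# fully unpadded rows (where argmax returns 0) pointing at the final position
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths)     # tensor([2, 4])
print(pooled_logits.shape)  # torch.Size([2, 3])
```
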
|
| 1400 |
+
|
| 1401 |
+
|
| 1402 |
+
@add_start_docstrings(
|
| 1403 |
+
"""
|
| 1404 |
+
PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1405 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1406 |
+
""",
|
| 1407 |
+
PHI_START_DOCSTRING,
|
| 1408 |
+
)
|
| 1409 |
+
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
|
| 1410 |
+
class PhiForTokenClassification(PhiPreTrainedModel):
|
| 1411 |
+
def __init__(self, config: PhiConfig):
|
| 1412 |
+
super().__init__(config)
|
| 1413 |
+
self.num_labels = config.num_labels
|
| 1414 |
+
|
| 1415 |
+
self.model = PhiModel(config)
|
| 1416 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
| 1417 |
+
classifier_dropout = config.classifier_dropout
|
| 1418 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
| 1419 |
+
classifier_dropout = config.hidden_dropout
|
| 1420 |
+
else:
|
| 1421 |
+
classifier_dropout = 0.1
|
| 1422 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 1423 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1424 |
+
|
| 1425 |
+
# Initialize weights and apply final processing
|
| 1426 |
+
self.post_init()
|
| 1427 |
+
|
| 1428 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
| 1429 |
+
@add_code_sample_docstrings(
|
| 1430 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1431 |
+
output_type=TokenClassifierOutput,
|
| 1432 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1433 |
+
)
|
| 1434 |
+
def forward(
|
| 1435 |
+
self,
|
| 1436 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1437 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 1438 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1439 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1440 |
+
labels: Optional[torch.Tensor] = None,
|
| 1441 |
+
use_cache: Optional[bool] = None,
|
| 1442 |
+
output_attentions: Optional[bool] = None,
|
| 1443 |
+
output_hidden_states: Optional[bool] = None,
|
| 1444 |
+
return_dict: Optional[bool] = None,
|
| 1445 |
+
**deprecated_arguments,
|
| 1446 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
| 1447 |
+
r"""
|
| 1448 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1449 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1450 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1451 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1452 |
+
"""
|
| 1453 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1454 |
+
|
| 1455 |
+
model_outputs = self.model(
|
| 1456 |
+
input_ids,
|
| 1457 |
+
past_key_values=past_key_values,
|
| 1458 |
+
attention_mask=attention_mask,
|
| 1459 |
+
inputs_embeds=inputs_embeds,
|
| 1460 |
+
use_cache=use_cache,
|
| 1461 |
+
output_attentions=output_attentions,
|
| 1462 |
+
output_hidden_states=output_hidden_states,
|
| 1463 |
+
return_dict=return_dict,
|
| 1464 |
+
)
|
| 1465 |
+
|
| 1466 |
+
hidden_states = model_outputs[0]
|
| 1467 |
+
hidden_states = self.dropout(hidden_states)
|
| 1468 |
+
logits = self.classifier(hidden_states)
|
| 1469 |
+
|
| 1470 |
+
loss = None
|
| 1471 |
+
if labels is not None:
|
| 1472 |
+
# move labels to correct device to enable model parallelism
|
| 1473 |
+
labels = labels.to(logits.device)
|
| 1474 |
+
batch_size, seq_length = labels.shape
|
| 1475 |
+
loss_fct = CrossEntropyLoss()
|
| 1476 |
+
loss = loss_fct(
|
| 1477 |
+
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
|
| 1478 |
+
)
|
| 1479 |
+
|
| 1480 |
+
if not return_dict:
|
| 1481 |
+
output = (logits,) + model_outputs[2:]
|
| 1482 |
+
return ((loss,) + output) if loss is not None else output
|
| 1483 |
+
|
| 1484 |
+
return TokenClassifierOutput(
|
| 1485 |
+
loss=loss,
|
| 1486 |
+
logits=logits,
|
| 1487 |
+
hidden_states=model_outputs.hidden_states,
|
| 1488 |
+
attentions=model_outputs.attentions,
|
| 1489 |
+
)
|
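
A quick way to smoke-test this vendored modeling file is to build a tiny, randomly initialized model straight from a config and run one forward pass; the `_attn_implementation` field picks which entry of `PHI_ATTENTION_CLASSES` gets instantiated. The sketch below assumes the repository root is importable and borrows `PhiConfig` from `transformers`; the import path and the toy hyperparameters are illustrative, not something this commit pins down.

```python
import torch
from transformers import PhiConfig  # assumed source of the config class

# assumed import path for the vendored modeling file; adjust to your checkout
from mcp_servers.fashion_vlm.models.phi import PhiForCausalLM

config = PhiConfig(
    vocab_size=1000, hidden_size=64, intermediate_size=256,
    num_hidden_layers=2, num_attention_heads=4,
)
config._attn_implementation = "sdpa"  # or "eager" / "flash_attention_2"

model = PhiForCausalLM(config)        # randomly initialized, CPU-sized toy model

input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)     # scalar loss, torch.Size([1, 8, 1000])
```

With `"sdpa"` selected, `PhiModel.__init__` reports the chosen backend, and the forward pass returns a scalar loss plus logits of shape `(1, 8, 1000)`.
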
mcp_servers/fashion_vlm/models/sampling.py
ADDED
|
@@ -0,0 +1,118 @@
# Adapted from https://github.com/lucidrains/muse-maskgit-pytorch

import math
from functools import partial

import torch
import torch.nn.functional as F


def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t, generator=None):
    noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)


def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(2, ind, val)
    return probs


def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
    sorted_confidence = torch.sort(confidence, dim=-1).values
    cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
    masking = confidence < cut_off
    return masking


def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)


def linear_schedule(t):
    mask_ratio = 1 - t
    mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
    return mask_ratio


def pow(t, method):
    exponent = float(method.replace("pow", ""))
    mask_ratio = 1.0 - t**exponent
    mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
    return mask_ratio


def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
    for item in [t, start, end, tau]:
        item = torch.tensor(item) if not torch.is_tensor(item) else item

    # A gamma function based on sigmoid function.
    v_start = torch.sigmoid(torch.tensor(start / tau))
    v_end = torch.sigmoid(torch.tensor(end / tau))
    output = torch.sigmoid((t * (end - start) + start) / tau)
    output = (v_end - output) / (v_end - v_start)
    return torch.clip(output, clip_min, 1.0)


def get_mask_chedule(method, **schedule_kwargs):
    if method == "cosine":
        return cosine_schedule
    elif method == "linear":
        return linear_schedule
    elif "pow" in method:
        return partial(pow, method=method)
    elif method == "sigmoid":
        return partial(sigmoid_schedule, **schedule_kwargs)
    else:
        raise ValueError("Unknown schedule method: {}".format(method))


def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
|
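
To see how these helpers fit together, the short example below (toy tensors only, assuming the repository root is on `PYTHONPATH`) asks the schedule getter for a cosine mask schedule and prunes a random logit row with top-k/top-p before Gumbel sampling. Note that the getter really is spelled `get_mask_chedule` in this file, so callers must use that name.

```python
import torch

# assumed module path, matching where this file sits in the repo
from mcp_servers.fashion_vlm.models.sampling import (
    get_mask_chedule, top_k_top_p_filtering, gumbel_sample,
)

# mask ratio 30% of the way through iterative decoding
schedule = get_mask_chedule("cosine")
print(schedule(torch.tensor(0.3)))  # ~0.891

# prune a toy (batch=1, vocab=8) logit row to top-3 / nucleus 0.9, then sample
logits = torch.randn(1, 8)
filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=0.9)
token = gumbel_sample(filtered, temperature=1.0)
print(token)
```
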
mcp_servers/fashion_vlm/prompting_utils.py
ADDED
|
@@ -0,0 +1,628 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 NUS Show Lab.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
# TODO - SHOULD BE FURTHER IMPROVED.
|
| 18 |
+
class UniversalPrompting():
|
| 19 |
+
def __init__(self, text_tokenizer,
|
| 20 |
+
special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
|
| 21 |
+
max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1):
|
| 22 |
+
"""
|
| 23 |
+
:param text_tokenizer: original text tokenizer
|
| 24 |
+
"""
|
| 25 |
+
self.text_tokenizer = text_tokenizer
|
| 26 |
+
self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
| 27 |
+
self.text_tokenizer.add_tokens(list(special_tokens))
|
| 28 |
+
self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
|
| 29 |
+
special_tokens}
|
| 30 |
+
self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
|
| 31 |
+
self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
|
| 32 |
+
self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
|
| 33 |
+
# plus 1 because at this time we add a task token before
|
| 34 |
+
self.max_text_len = max_text_len + 1
|
| 35 |
+
self.pad_id = self.text_tokenizer.convert_tokens_to_ids('[PAD]')
|
| 36 |
+
self.ignore_id = ignore_id
|
| 37 |
+
self.cond_dropout_prob = cond_dropout_prob
|
| 38 |
+
|
| 39 |
+
def t2i_prompt(self, text_ids, image_ids, labels):
|
| 40 |
+
|
| 41 |
+
device = image_ids.device
|
| 42 |
+
sequence_ids = []
|
| 43 |
+
attention_masks = []
|
| 44 |
+
label_ids = []
|
| 45 |
+
probs = torch.rand(len(text_ids))
|
| 46 |
+
for i in range(len(text_ids)):
|
| 47 |
+
|
| 48 |
+
if len(text_ids[i]) == 0:
|
| 49 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 50 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 51 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 52 |
+
|
| 53 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 54 |
+
|
| 55 |
+
# randomly dropout text condition
|
| 56 |
+
if probs[i] < self.cond_dropout_prob:
|
| 57 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]
|
| 58 |
+
|
| 59 |
+
if self.max_text_len >= len(temp_ids):
|
| 60 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
| 61 |
+
temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
|
| 62 |
+
else:
|
| 63 |
+
# should add the eos token
|
| 64 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 65 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3)  # attend over the text tokens, the image tokens, and the surrounding special tokens
|
| 66 |
+
|
| 67 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 68 |
+
temp_label_ids = torch.cat([
|
| 69 |
+
# should we predict text tokens when doing image reconstruction?
|
| 70 |
+
torch.tensor(temp_ids).to(device),
|
| 71 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 72 |
+
labels[i],
|
| 73 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
| 74 |
+
], dim=0)
|
| 75 |
+
|
| 76 |
+
temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
|
| 77 |
+
|
| 78 |
+
temp_ids = torch.cat([
|
| 79 |
+
torch.tensor(temp_ids).to(device),
|
| 80 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 81 |
+
image_ids[i],
|
| 82 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
| 83 |
+
], dim=0)
|
| 84 |
+
|
| 85 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 86 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 87 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 88 |
+
label_ids.append(temp_label_ids.unsqueeze(0))
|
| 89 |
+
|
| 90 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
| 91 |
+
|
| 92 |
+
def t2i_gen_prompt(self, text_ids, image_ids):
|
| 93 |
+
|
| 94 |
+
device = image_ids.device
|
| 95 |
+
sequence_ids = []
|
| 96 |
+
attention_masks = []
|
| 97 |
+
for i in range(len(text_ids)):
|
| 98 |
+
if len(text_ids[i]) == 0:
|
| 99 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 100 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 101 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 102 |
+
# note that, llama3 tokenizer automatically add the bot token at first but without eot
|
| 103 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 104 |
+
if self.max_text_len >= len(temp_ids):
|
| 105 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
| 106 |
+
temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
|
| 107 |
+
else:
|
| 108 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 109 |
+
temp_masks = [1] * len(temp_ids)  # all kept text tokens are attended to
|
| 110 |
+
|
| 111 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 112 |
+
temp_ids = torch.cat([
|
| 113 |
+
torch.tensor(temp_ids).to(device),
|
| 114 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 115 |
+
image_ids[i],
|
| 116 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
| 117 |
+
], dim=0)
|
| 118 |
+
|
| 119 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 120 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 121 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 122 |
+
|
| 123 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
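
Both t2i builders above assemble the layout described in their comments: `<|t2i|> <|sot|> text tokens <|eot|> <|soi|> image tokens <|eoi|>`, with the text half left-padded to a fixed length and the attention mask zeroing the padding. The sketch below walks through that assembly with made-up token ids (the real class resolves every id through the text tokenizer and its `sptids_dict`):

```python
import torch

# hypothetical ids, standing in for tokenizer lookups
PAD, SOT, EOT, T2I, SOI, EOI = 0, 1, 2, 10, 11, 12
MAX_TEXT_LEN = 8

text_ids = [SOT, 101, 102, 103, EOT]
image_ids = torch.tensor([500, 501, 502, 503])

text_part = [T2I] + text_ids
text_part = [PAD] * (MAX_TEXT_LEN - len(text_part)) + text_part  # left-pad the text block

sequence = torch.cat([torch.tensor(text_part),
                      torch.tensor([SOI]), image_ids, torch.tensor([EOI])])
attention_mask = (sequence != PAD).long()
print(sequence)
print(attention_mask)
```
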
|
| 124 |
+
|
| 125 |
+
# language modeling
|
| 126 |
+
def lm_prompt(self, text_ids, max_seq_len):
|
| 127 |
+
|
| 128 |
+
sequence_ids = []
|
| 129 |
+
attention_masks = []
|
| 130 |
+
label_ids = []
|
| 131 |
+
for i in range(len(text_ids)):
|
| 132 |
+
if len(text_ids[i]) == 0:
|
| 133 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 134 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 135 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 136 |
+
|
| 137 |
+
temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 138 |
+
|
| 139 |
+
if max_seq_len >= len(temp_ids):
|
| 140 |
+
temp_labels_ids = temp_ids + [self.ignore_id] * (max_seq_len - len(temp_ids))
|
| 141 |
+
temp_ids = temp_ids + [self.pad_id] * (max_seq_len - len(temp_ids))
|
| 142 |
+
temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
|
| 143 |
+
else:
|
| 144 |
+
# In language modeling, we only process text tokens. We do not add the eos token if the text length
|
| 145 |
+
# exceeds the max sequence length
|
| 146 |
+
temp_labels_ids = temp_ids[:max_seq_len]
|
| 147 |
+
temp_ids = temp_ids[:max_seq_len]
|
| 148 |
+
temp_masks = [1] * len(temp_ids)  # all kept text tokens are attended to
|
| 149 |
+
|
| 150 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 151 |
+
temp_ids = torch.tensor(temp_ids)
|
| 152 |
+
temp_masks = torch.tensor(temp_masks)
|
| 153 |
+
temp_labels_ids = torch.tensor(temp_labels_ids)
|
| 154 |
+
|
| 155 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 156 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 157 |
+
label_ids.append(temp_labels_ids.unsqueeze(0))
|
| 158 |
+
|
| 159 |
+
# input_ids, masks, labels
|
| 160 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
| 161 |
+
|
| 162 |
+
def mmu_prompt(self, image_ids, text_ids):
|
| 163 |
+
device = image_ids.device
|
| 164 |
+
sequence_ids = []
|
| 165 |
+
attention_masks = []
|
| 166 |
+
label_ids = []
|
| 167 |
+
max_text_len = self.max_text_len - 1
|
| 168 |
+
for i in range(len(text_ids)):
|
| 169 |
+
# note that, llama3 tokenizer automatically add the bot token at first but without eot
|
| 170 |
+
# for empty list []
|
| 171 |
+
|
| 172 |
+
if len(text_ids[i]) == 0:
|
| 173 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 174 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 175 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 176 |
+
|
| 177 |
+
temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 178 |
+
|
| 179 |
+
if max_text_len >= len(temp_ids):
|
| 180 |
+
# minus 1 because task token was prepended to the former image tokens
|
| 181 |
+
temp_ids = temp_ids + [self.pad_id] * (max_text_len - len(temp_ids))
|
| 182 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
|
| 183 |
+
else:
|
| 184 |
+
# should add the eos token
|
| 185 |
+
temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 186 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3)  # attend over the text tokens, the image tokens, and the surrounding special tokens
|
| 187 |
+
|
| 188 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 189 |
+
temp_label_ids = torch.cat([
|
| 190 |
+
torch.tensor([self.ignore_id]).to(device),
|
| 191 |
+
torch.tensor([self.ignore_id]).to(device),
|
| 192 |
+
torch.ones_like(image_ids[i]) * self.ignore_id,
|
| 193 |
+
torch.tensor([self.ignore_id]).to(device),
|
| 194 |
+
torch.tensor(temp_ids).to(device),
|
| 195 |
+
], dim=0)
|
| 196 |
+
|
| 197 |
+
temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
|
| 198 |
+
|
| 199 |
+
temp_ids = torch.cat([
|
| 200 |
+
self.sptids_dict['<|mmu|>'].to(device), # task token
|
| 201 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 202 |
+
image_ids[i],
|
| 203 |
+
self.sptids_dict['<|eoi|>'].to(device),
|
| 204 |
+
torch.tensor(temp_ids).to(device),
|
| 205 |
+
], dim=0)
|
| 206 |
+
|
| 207 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 208 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 209 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 210 |
+
label_ids.append(temp_label_ids.unsqueeze(0))
|
| 211 |
+
|
| 212 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
| 213 |
+
|
| 214 |
+
def t2v_prompt(self, text_ids, image_ids, labels):
|
| 215 |
+
|
| 216 |
+
device = image_ids.device
|
| 217 |
+
sequence_ids = []
|
| 218 |
+
attention_masks = []
|
| 219 |
+
label_ids = []
|
| 220 |
+
probs = torch.rand(len(text_ids))
|
| 221 |
+
for i in range(len(text_ids)):
|
| 222 |
+
|
| 223 |
+
if len(text_ids[i]) == 0:
|
| 224 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 225 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 226 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 227 |
+
|
| 228 |
+
temp_ids = [int(self.sptids_dict['<|t2v|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 229 |
+
|
| 230 |
+
# randomly dropout text condition
|
| 231 |
+
if probs[i] < self.cond_dropout_prob:
|
| 232 |
+
temp_ids = [int(self.sptids_dict['<|t2v|>']), self.text_tokenizer.bos_token_id,
|
| 233 |
+
self.text_tokenizer.eos_token_id]
|
| 234 |
+
|
| 235 |
+
if self.max_text_len >= len(temp_ids):
|
| 236 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
| 237 |
+
temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
|
| 238 |
+
else:
|
| 239 |
+
# should add the eos token
|
| 240 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 241 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
|
| 242 |
+
|
| 243 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 244 |
+
temp_label_ids = torch.cat([
|
| 245 |
+
# should we predict text tokens when doing image reconstruction?
|
| 246 |
+
torch.tensor(temp_ids).to(device),
|
| 247 |
+
self.sptids_dict['<|sov|>'].to(device),
|
| 248 |
+
labels[i],
|
| 249 |
+
self.sptids_dict['<|eov|>'].to(device)
|
| 250 |
+
], dim=0)
|
| 251 |
+
|
| 252 |
+
temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
|
| 253 |
+
|
| 254 |
+
temp_ids = torch.cat([
|
| 255 |
+
torch.tensor(temp_ids).to(device),
|
| 256 |
+
self.sptids_dict['<|sov|>'].to(device),
|
| 257 |
+
image_ids[i],
|
| 258 |
+
self.sptids_dict['<|eov|>'].to(device)
|
| 259 |
+
], dim=0)
|
| 260 |
+
|
| 261 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 262 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 263 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 264 |
+
label_ids.append(temp_label_ids.unsqueeze(0))
|
| 265 |
+
|
| 266 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
| 267 |
+
|
| 268 |
+
def t2v_gen_prompt(self, text_ids, image_ids):
|
| 269 |
+
|
| 270 |
+
device = image_ids.device
|
| 271 |
+
sequence_ids = []
|
| 272 |
+
attention_masks = []
|
| 273 |
+
for i in range(len(text_ids)):
|
| 274 |
+
if len(text_ids[i]) == 0:
|
| 275 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 276 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 277 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 278 |
+
# note that the llama3 tokenizer automatically adds the begin-of-text token, but not the end-of-text token
|
| 279 |
+
temp_ids = [int(self.sptids_dict['<|t2v|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 280 |
+
if self.max_text_len >= len(temp_ids):
|
| 281 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
| 282 |
+
temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
|
| 283 |
+
else:
|
| 284 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 285 |
+
temp_masks = [1] * len(temp_ids) # +2 for two special tokens
|
| 286 |
+
|
| 287 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 288 |
+
temp_ids = torch.cat([
|
| 289 |
+
torch.tensor(temp_ids).to(device),
|
| 290 |
+
self.sptids_dict['<|sov|>'].to(device),
|
| 291 |
+
image_ids[i],
|
| 292 |
+
self.sptids_dict['<|eov|>'].to(device)
|
| 293 |
+
], dim=0)
|
| 294 |
+
|
| 295 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 296 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 297 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 298 |
+
|
| 299 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
|
| 300 |
+
|
| 301 |
+
def i2v_prompt(self, image_ids, video_ids):
|
| 302 |
+
"""
|
| 303 |
+
:param image_ids:
|
| 304 |
+
:param video_ids:
|
| 305 |
+
:return:
|
| 306 |
+
"""
|
| 307 |
+
pass
|
| 308 |
+
|
| 309 |
+
def lvg_prompt(self, text_ids, image_ids, labels):
|
| 310 |
+
|
| 311 |
+
device = image_ids.device
|
| 312 |
+
sequence_ids = []
|
| 313 |
+
attention_masks = []
|
| 314 |
+
label_ids = []
|
| 315 |
+
probs = torch.rand(len(text_ids))
|
| 316 |
+
probs2 = torch.rand(len(text_ids))
|
| 317 |
+
for i in range(len(text_ids)):
|
| 318 |
+
|
| 319 |
+
if len(text_ids[i]) == 0:
|
| 320 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 321 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 322 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 323 |
+
|
| 324 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 325 |
+
|
| 326 |
+
# randomly dropout text condition
|
| 327 |
+
if probs[i] < self.cond_dropout_prob:
|
| 328 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id,
|
| 329 |
+
self.text_tokenizer.eos_token_id]
|
| 330 |
+
|
| 331 |
+
if self.max_text_len >= len(temp_ids):
|
| 332 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
| 333 |
+
temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
|
| 334 |
+
else:
|
| 335 |
+
# should add the eos token
|
| 336 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 337 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
|
| 338 |
+
|
| 339 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 340 |
+
temp_label_ids = torch.cat([
|
| 341 |
+
# should we predict text tokens when doing image reconstruction?
|
| 342 |
+
torch.tensor(temp_ids).to(device),
|
| 343 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 344 |
+
labels[i],
|
| 345 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
| 346 |
+
], dim=0)
|
| 347 |
+
|
| 348 |
+
temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
|
| 349 |
+
|
| 350 |
+
temp_ids = torch.cat([
|
| 351 |
+
torch.tensor(temp_ids).to(device),
|
| 352 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 353 |
+
image_ids[i],
|
| 354 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
| 355 |
+
], dim=0)
|
| 356 |
+
|
| 357 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 358 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 359 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 360 |
+
label_ids.append(temp_label_ids.unsqueeze(0))
|
| 361 |
+
|
| 362 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
| 363 |
+
|
| 364 |
+
def lvg_gen_prompt(self, text_ids, image_ids):
|
| 365 |
+
|
| 366 |
+
device = image_ids.device
|
| 367 |
+
sequence_ids = []
|
| 368 |
+
attention_masks = []
|
| 369 |
+
for i in range(len(text_ids)):
|
| 370 |
+
if len(text_ids[i]) == 0:
|
| 371 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
| 372 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
| 373 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
| 374 |
+
# note that the llama3 tokenizer automatically adds the begin-of-text token, but not the end-of-text token
|
| 375 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
| 376 |
+
if self.max_text_len >= len(temp_ids):
|
| 377 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
| 378 |
+
temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
|
| 379 |
+
else:
|
| 380 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
| 381 |
+
temp_masks = [1] * len(temp_ids) # +2 for two special tokens
|
| 382 |
+
|
| 383 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
| 384 |
+
temp_ids = torch.cat([
|
| 385 |
+
torch.tensor(temp_ids).to(device),
|
| 386 |
+
self.sptids_dict['<|soi|>'].to(device),
|
| 387 |
+
image_ids[i],
|
| 388 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
| 389 |
+
], dim=0)
|
| 390 |
+
|
| 391 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
| 392 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
| 393 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
| 394 |
+
|
| 395 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
|
| 396 |
+
|
| 397 |
+
def mask_prompt(self):
|
| 398 |
+
pass
|
| 399 |
+
|
| 400 |
+
def __call__(self, input, task, padding=True, config=None):
|
| 401 |
+
"""
|
| 402 |
+
input (tuple) : data pairs contain text(str), image(tensor), or videos(tensor).
|
| 403 |
+
task (str) : a flag indicates the current task.
|
| 404 |
+
"""
|
| 405 |
+
if task == "t2i":
|
| 406 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 407 |
+
image_ids = input[1] # (B, #tokens)
|
| 408 |
+
sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])
|
| 409 |
+
|
| 410 |
+
elif task == "t2v":
|
| 411 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 412 |
+
image_ids = input[1] # (B, #tokens)
|
| 413 |
+
sequence_ids_with_masks = self.t2v_prompt(text_ids, image_ids, input[2])
|
| 414 |
+
|
| 415 |
+
elif task == "t2i_plus_lm":
|
| 416 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 417 |
+
image_ids = input[1] # (B, #tokens)
|
| 418 |
+
sequence_ids_with_masks = self.t2i_prompt(text_ids[:config.training.batch_size], image_ids,
|
| 419 |
+
input[2])
|
| 420 |
+
sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
|
| 421 |
+
return sequence_ids_with_masks, sequence_ids_with_masks_lm
|
| 422 |
+
|
| 423 |
+
elif task == "t2i_gen":
|
| 424 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 425 |
+
image_ids = input[1] # (B, #tokens)
|
| 426 |
+
sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)
|
| 427 |
+
|
| 428 |
+
elif task == "t2v_gen":
|
| 429 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 430 |
+
image_ids = input[1] # (B, #tokens)
|
| 431 |
+
sequence_ids_with_masks = self.t2v_gen_prompt(text_ids, image_ids)
|
| 432 |
+
|
| 433 |
+
elif task == "lm":
|
| 434 |
+
text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
|
| 435 |
+
sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])
|
| 436 |
+
|
| 437 |
+
elif task == "mmu":
|
| 438 |
+
image_ids = input[0]
|
| 439 |
+
text_ids = self.text_tokenizer(input[1])['input_ids']
|
| 440 |
+
sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)
|
| 441 |
+
|
| 442 |
+
elif task == "t2v":
|
| 443 |
+
text_ids = self.text_tokenizer(input[0]['input_ids'])
|
| 444 |
+
video_ids = self.vision_tokenizer(input[1])
|
| 445 |
+
sequence_ids_with_masks = self.t2v_prompt(text_ids, video_ids)
|
| 446 |
+
|
| 447 |
+
elif task == "i2v":
|
| 448 |
+
image_ids = self.text_tokenizer(input[0])
|
| 449 |
+
video_ids = self.vision_tokenizer(input[1])
|
| 450 |
+
sequence_ids_with_masks = self.i2v_prompt(image_ids, video_ids)
|
| 451 |
+
|
| 452 |
+
elif task == "lvg":
|
| 453 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 454 |
+
image_ids = input[1] # (B, #tokens)
|
| 455 |
+
sequence_ids_with_masks = self.lvg_prompt(text_ids, image_ids, input[2])
|
| 456 |
+
|
| 457 |
+
elif task == "lvg_gen":
|
| 458 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
| 459 |
+
image_ids = input[1] # (B, #tokens)
|
| 460 |
+
sequence_ids_with_masks = self.lvg_gen_prompt(text_ids, image_ids)
|
| 461 |
+
else:
|
| 462 |
+
raise NotImplementedError
|
| 463 |
+
|
| 464 |
+
return sequence_ids_with_masks
|
| 465 |
+
|
| 466 |
+
def create_attention_mask_predict_next(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, rm_pad_in_image=False,
|
| 467 |
+
return_inverse_mask=True):
|
| 468 |
+
# sequence is expected to be of shape [N, L]
|
| 469 |
+
N, L = sequence.shape
|
| 470 |
+
|
| 471 |
+
# Masks to identify different types of tokens
|
| 472 |
+
is_padding = sequence == pad_id
|
| 473 |
+
|
| 474 |
+
is_start_image = sequence == soi_id
|
| 475 |
+
|
| 476 |
+
is_end_image = sequence == eoi_id
|
| 477 |
+
|
| 478 |
+
# Create cumulative sum masks to identify regions of image tokens
|
| 479 |
+
cumulative_start = torch.cumsum(is_start_image, dim=1)
|
| 480 |
+
cumulative_end = torch.cumsum(is_end_image, dim=1)
|
| 481 |
+
in_image_segment = (cumulative_start > cumulative_end) | is_start_image | is_end_image
|
| 482 |
+
|
| 483 |
+
is_text = ~(in_image_segment)
|
| 484 |
+
|
| 485 |
+
causal_mask = torch.tril(torch.ones((L, L), dtype=torch.bool)).to(sequence.device)
|
| 486 |
+
|
| 487 |
+
mask_text = is_text[:, :, None] * causal_mask[None, :, :]
|
| 488 |
+
|
| 489 |
+
is_text_image = is_text | in_image_segment
|
| 490 |
+
|
| 491 |
+
mask_text_image_bi = is_text_image[:, :, None] * is_text_image[:, None, :]
|
| 492 |
+
if rm_pad_in_image: # remove padding token in image
|
| 493 |
+
sid_img = torch.where(sequence == soi_id)[1]
|
| 494 |
+
for i in range(mask_text_image_bi.shape[0]):
|
| 495 |
+
pad_end_idx = torch.where(sequence[i] == pad_id)
|
| 496 |
+
if len(pad_end_idx[0]) != 0:
|
| 497 |
+
pad_end_idx = pad_end_idx[0][-1]
|
| 498 |
+
mask_text[i][pad_end_idx + 1:, :pad_end_idx + 1] = 0
|
| 499 |
+
id_padding = torch.where(is_padding[i] == True)
|
| 500 |
+
mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
|
| 501 |
+
|
| 502 |
+
mask_text[in_image_segment] = mask_text_image_bi[in_image_segment]
|
| 503 |
+
# No token attends to padding tokens and padding tokens do not attend to any token
|
| 504 |
+
if return_inverse_mask:
|
| 505 |
+
inverted_mask = 1.0 - mask_text.type(sequence.dtype)
|
| 506 |
+
inverted_mask = inverted_mask.masked_fill(
|
| 507 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
| 508 |
+
)
|
| 509 |
+
return inverted_mask.unsqueeze(1)
|
| 510 |
+
else:
|
| 511 |
+
return mask_text.unsqueeze(1)
|
| 512 |
+
|
| 513 |
+
def create_attention_mask_lvg(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, return_inverse_mask=True):
|
| 514 |
+
# sequence is expected to be of shape [N, L]
|
| 515 |
+
N, L = sequence.shape
|
| 516 |
+
# Masks to identify different types of tokens
|
| 517 |
+
is_padding = sequence == pad_id
|
| 518 |
+
mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device)
|
| 519 |
+
|
| 520 |
+
sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
|
| 521 |
+
sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
| 522 |
+
eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
| 523 |
+
for i in range(N):
|
| 524 |
+
id_padding = torch.where(is_padding[i] == True)
|
| 525 |
+
mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
|
| 526 |
+
for j in range(sid_img_for_bi.shape[-1]):
|
| 527 |
+
mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1,
|
| 528 |
+
sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1] = 1
|
| 529 |
+
|
| 530 |
+
# No token attends to padding tokens and padding tokens do not attend to any token
|
| 531 |
+
if return_inverse_mask:
|
| 532 |
+
inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
|
| 533 |
+
inverted_mask = inverted_mask.masked_fill(
|
| 534 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
| 535 |
+
)
|
| 536 |
+
return inverted_mask.unsqueeze(1)
|
| 537 |
+
else:
|
| 538 |
+
return mask_text_image_bi.unsqueeze(1)
|
| 539 |
+
|
| 540 |
+
# texts without attending image regions
|
| 541 |
+
def create_attention_mask_lvg_v2(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, sot_id=1000, eot_id=1001, return_inverse_mask=True):
|
| 542 |
+
# sequence is expected to be of shape [N, L]
|
| 543 |
+
N, L = sequence.shape
|
| 544 |
+
# Masks to identify different types of tokens
|
| 545 |
+
is_padding = sequence == pad_id
|
| 546 |
+
# is_text = torch.where(sequence < 2000, True, False)
|
| 547 |
+
is_text = torch.where(sequence < pad_id, True, False)
|
| 548 |
+
mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device).int()
|
| 549 |
+
sid_text_for_bi = torch.where(sequence == sot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
| 550 |
+
eid_text_for_bi = torch.where(sequence == eot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
| 551 |
+
# import ipdb
|
| 552 |
+
# ipdb.set_trace()
|
| 553 |
+
if sot_id == eot_id:
|
| 554 |
+
if sid_text_for_bi.shape[-1] % 2 != 0:
|
| 555 |
+
sid_text_for_bi = sid_text_for_bi[:, :-1]
|
| 556 |
+
eid_text_for_bi = eid_text_for_bi[:, :-1]
|
| 557 |
+
select_idx = [i for i in range(0, sid_text_for_bi.shape[1], 2)]
|
| 558 |
+
sid_text_for_bi = sid_text_for_bi[:, select_idx]
|
| 559 |
+
select_idx = [i+1 for i in range(0, eid_text_for_bi.shape[1], 2)]
|
| 560 |
+
eid_text_for_bi = eid_text_for_bi[:, select_idx]
|
| 561 |
+
sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
| 562 |
+
eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
| 563 |
+
all_zeros = torch.zeros_like(mask_text_image_bi).int()
|
| 564 |
+
for i in range(N):
|
| 565 |
+
all_zeros[i, :, is_text[i]] = 1
|
| 566 |
+
for j in range(sid_text_for_bi.shape[-1]):
|
| 567 |
+
all_zeros[i][is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
|
| 568 |
+
all_zeros[i][~is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
|
| 569 |
+
for j in range(sid_img_for_bi.shape[-1]):
|
| 570 |
+
all_zeros[i][~is_text[i], sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
|
| 571 |
+
mask_text_image_bi = mask_text_image_bi * all_zeros
|
| 572 |
+
sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
|
| 573 |
+
|
| 574 |
+
for i in range(N):
|
| 575 |
+
id_padding = torch.where(is_padding[i] == True)
|
| 576 |
+
mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
|
| 577 |
+
for j in range(sid_img_for_bi.shape[-1]):
|
| 578 |
+
mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1, sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
|
| 579 |
+
|
| 580 |
+
mask_text_image_bi[:, :, 0] = 1
|
| 581 |
+
# No token attends to padding tokens and padding tokens do not attend to any token
|
| 582 |
+
if return_inverse_mask:
|
| 583 |
+
inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
|
| 584 |
+
inverted_mask = inverted_mask.masked_fill(
|
| 585 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
| 586 |
+
)
|
| 587 |
+
return inverted_mask.unsqueeze(1)
|
| 588 |
+
else:
|
| 589 |
+
return mask_text_image_bi.unsqueeze(1)
|
| 590 |
+
|
| 591 |
+
def create_attention_mask_for_mmu(sequence, eoi_id=128258, return_inverse_mask=True):
|
| 592 |
+
N, L = sequence.shape
|
| 593 |
+
causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
|
| 594 |
+
eoi_image = torch.where(sequence == eoi_id)[1]
|
| 595 |
+
causal_mask[:, :, :, :eoi_image[0] + 1] = 1
|
| 596 |
+
|
| 597 |
+
if return_inverse_mask:
|
| 598 |
+
inverted_mask = 1.0 - causal_mask.type(sequence.dtype)
|
| 599 |
+
inverted_mask = inverted_mask.masked_fill(
|
| 600 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
| 601 |
+
)
|
| 602 |
+
return inverted_mask
|
| 603 |
+
else:
|
| 604 |
+
return causal_mask
|
| 605 |
+
|
| 606 |
+
def create_attention_mask_for_mmu_vit(
|
| 607 |
+
sequence,
|
| 608 |
+
return_inverse_mask=True,
|
| 609 |
+
system_prompt_len=0
|
| 610 |
+
):
|
| 611 |
+
N, L, H = sequence.shape
|
| 612 |
+
causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
|
| 613 |
+
index = 1 + system_prompt_len + 1 + 576
|
| 614 |
+
# PART OF SYSTEM PROMPT SHOULD BE CAUSAL ALSO
|
| 615 |
+
# causal_mask[:, :, :, :index] = 1
|
| 616 |
+
causal_mask[:, :, :, 1+system_prompt_len+1:index] = 1  # let every position attend to all image-token columns
|
| 617 |
+
if return_inverse_mask:
|
| 618 |
+
inverted_mask = 1.0 - causal_mask.type(torch.int64)
|
| 619 |
+
inverted_mask = inverted_mask.masked_fill(
|
| 620 |
+
inverted_mask.to(torch.bool), torch.iinfo(torch.int64).min
|
| 621 |
+
)
|
| 622 |
+
return inverted_mask
|
| 623 |
+
else:
|
| 624 |
+
return causal_mask
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
if __name__ == '__main__':
|
| 628 |
+
pass
|
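For reference, a minimal standalone sketch of the masking rule implemented by create_attention_mask_predict_next: text positions attend causally, while positions inside an <|soi|> ... <|eoi|> image segment get full bidirectional attention. The token IDs, toy sequence, and shapes below are illustrative assumptions only, not values from this repository.

import torch

SOI, EOI = 98, 99                                   # illustrative special-token IDs, not the repo's real ones
seq = torch.tensor([[1, 2, SOI, 5, 6, EOI, 3]])     # [N=1, L=7] toy sequence
N, L = seq.shape

causal = torch.tril(torch.ones(L, L, dtype=torch.bool))

# A position is "inside an image segment" from <|soi|> to <|eoi|> (inclusive).
in_image = (torch.cumsum(seq == SOI, dim=1) > torch.cumsum(seq == EOI, dim=1)) | (seq == SOI) | (seq == EOI)
is_text = ~in_image

# Text rows attend causally; image rows attend to every position (bidirectional), mirroring the repo's rule.
mask = torch.where(in_image[:, :, None],
                   torch.ones(N, L, L, dtype=torch.bool),
                   is_text[:, :, None] & causal[None])
print(mask[0].int())                                 # [L, L] attention pattern for the toy sequence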
mcp_servers/product_user_database.py
ADDED
|
@@ -0,0 +1,463 @@
|
| 1 |
+
# mcp_server/product_user_database.py
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from itertools import combinations
|
| 9 |
+
|
| 10 |
+
from scipy import sparse
|
| 11 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 12 |
+
from mcp.server.fastmcp import FastMCP
|
| 13 |
+
from mcp.server.sse import SseServerTransport
|
| 14 |
+
from starlette.applications import Starlette
|
| 15 |
+
from starlette.routing import Route, Mount
|
| 16 |
+
import uvicorn
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 20 |
+
from openai import AsyncOpenAI
|
| 21 |
+
|
| 22 |
+
# Load environment variables
|
| 23 |
+
load_dotenv()
|
| 24 |
+
FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT", "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec")
|
| 25 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 26 |
+
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
|
| 27 |
+
openai = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
|
| 28 |
+
|
| 29 |
+
###################################
|
| 30 |
+
#########Loading Model#############
|
| 31 |
+
###################################
|
| 32 |
+
# Load CLIP model and processor
|
| 33 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
|
| 34 |
+
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
|
| 35 |
+
clip_model.eval()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Load item metadata
|
| 39 |
+
items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet").set_index("item_id")
|
| 40 |
+
outfits_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/outfits_lite.parquet").set_index("outfit_id")
|
| 41 |
+
users_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/users_lite.parquet").set_index("user_id")
|
| 42 |
+
image_paths = items_df["path"].to_dict()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class InteractionDataManager:
|
| 46 |
+
def __init__(self, users_df, outfits_df, items_df):
|
| 47 |
+
"""
|
| 48 |
+
Initialize the manager: load the data frames and set up basic lookups.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
- users_df: user metadata DataFrame (loaded from parquet)
|
| 52 |
+
- outfits_df: outfit metadata DataFrame (loaded from parquet)
|
| 53 |
+
- items_df: item metadata DataFrame (loaded from parquet)
|
| 54 |
+
"""
|
| 55 |
+
self.users_df = users_df
|
| 56 |
+
self.outfits_df = outfits_df
|
| 57 |
+
self.items_df = items_df
|
| 58 |
+
|
| 59 |
+
# Build ID <-> index mappings
|
| 60 |
+
self.item_id_to_index = {item_id: index for index, item_id in enumerate(self.items_df.index)}
|
| 61 |
+
self.index_to_item_id = {index: item_id for index, item_id in enumerate(self.items_df.index)}
|
| 62 |
+
self.user_id_to_index = {user_id: index for index, user_id in enumerate(self.users_df.index)}
|
| 63 |
+
self.index_to_user_id = {index: user_id for index, user_id in enumerate(self.users_df.index)}
|
| 64 |
+
self.outfit_ids_dict = self.outfits_df['item_ids'].to_dict() # get outfit's item ids from outfit id
|
| 65 |
+
self.item_category_dict = self.items_df['category'].to_dict() # get item's category from item id
|
| 66 |
+
self.item_subcategory_dict = self.items_df['subcategory'].to_dict() # get item's subcategory from item id
|
| 67 |
+
self.n_items = len(self.items_df)
|
| 68 |
+
self.n_users = len(self.users_df)
|
| 69 |
+
|
| 70 |
+
self.user_outfit_pairs = []
|
| 71 |
+
outfit_set = set(self.outfits_df.index)
|
| 72 |
+
for uid, user in self.users_df.iterrows():
|
| 73 |
+
oids = user.outfit_ids.split(",")
|
| 74 |
+
self.user_outfit_pairs.extend([(uid, oid) for oid in oids if oid in outfit_set])
|
| 75 |
+
|
| 76 |
+
# Precompute the subcategory -> item-ID mapping (via groupby)
|
| 77 |
+
self.subcategory_to_items = self.items_df.groupby('subcategory').apply(lambda x: set(x.index)).to_dict()
|
| 78 |
+
|
| 79 |
+
# Precompute the subcategory -> item-index mapping (for faster lookups)
|
| 80 |
+
self.subcategory_to_indices = {}
|
| 81 |
+
for subcategory, item_ids in self.subcategory_to_items.items():
|
| 82 |
+
self.subcategory_to_indices[subcategory] = set([self.item_id_to_index[item_id]
|
| 83 |
+
for item_id in item_ids
|
| 84 |
+
if item_id in self.item_id_to_index])
|
| 85 |
+
|
| 86 |
+
item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/item_matrix.npz'
|
| 87 |
+
try:
|
| 88 |
+
self.load_matrix('item', item_interaction_matrix_path)
|
| 89 |
+
except FileNotFoundError:
|
| 90 |
+
self.build_item_interaction_matrix()
|
| 91 |
+
self.save_matrix('item', item_interaction_matrix_path)
|
| 92 |
+
|
| 93 |
+
user_item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/user_item_matrix.npz'
|
| 94 |
+
try:
|
| 95 |
+
self.load_matrix('user_item', user_item_interaction_matrix_path)
|
| 96 |
+
except FileNotFoundError:
|
| 97 |
+
self.build_user_item_interaction_matrix()
|
| 98 |
+
self.save_matrix('user_item', user_item_interaction_matrix_path)
|
| 99 |
+
|
| 100 |
+
# Load precomputed item CLIP features
|
| 101 |
+
with open(f"{FASHION_DATA_ROOT}/meta/clip_features.pkl", "rb") as f:
|
| 102 |
+
print("Loading Fashion Features...")
|
| 103 |
+
self.clip_features = pickle.load(f)
|
| 104 |
+
print("Loading Fashion Features Successfully")
|
| 105 |
+
|
| 106 |
+
# Prepare embeddings and item IDs
|
| 107 |
+
self.item_ids = list(self.clip_features.keys())
|
| 108 |
+
self.image_embeddings = np.array([self.clip_features[item_id]["image_embeds"] for item_id in self.item_ids])
|
| 109 |
+
|
| 110 |
+
def save_matrix(self, matrix_type, filepath):
|
| 111 |
+
"""
|
| 112 |
+
Save an interaction matrix to disk.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
- matrix_type: 'item' or 'user_item', selecting which matrix to save
|
| 116 |
+
- filepath: output path (e.g. 'temp/item_matrix.npz')
|
| 117 |
+
"""
|
| 118 |
+
if matrix_type == 'item':
|
| 119 |
+
matrix = self.item_interaction_matrix
|
| 120 |
+
elif matrix_type == 'user_item':
|
| 121 |
+
matrix = self.user_item_interaction_matrix
|
| 122 |
+
else:
|
| 123 |
+
raise ValueError("matrix_type must be 'item' or 'user_item'")
|
| 124 |
+
|
| 125 |
+
if matrix is None:
|
| 126 |
+
raise ValueError(f"{matrix_type} matrix has not been built yet.")
|
| 127 |
+
|
| 128 |
+
sparse.save_npz(filepath, matrix)
|
| 129 |
+
print(f"Saved {matrix_type} matrix to {filepath}")
|
| 130 |
+
|
| 131 |
+
def load_matrix(self, matrix_type, filepath):
|
| 132 |
+
"""
|
| 133 |
+
Load an interaction matrix from disk.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
- matrix_type: 'item' or 'user_item', selecting which matrix to load
|
| 137 |
+
- filepath: input path (e.g. 'temp/item_matrix.npz')
|
| 138 |
+
"""
|
| 139 |
+
if not os.path.exists(filepath):
|
| 140 |
+
raise FileNotFoundError(f"File {filepath} does not exist.")
|
| 141 |
+
|
| 142 |
+
matrix = sparse.load_npz(filepath)
|
| 143 |
+
if matrix_type == 'item':
|
| 144 |
+
self.item_interaction_matrix = matrix
|
| 145 |
+
elif matrix_type == 'user_item':
|
| 146 |
+
self.user_item_interaction_matrix = matrix
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError("matrix_type must be 'item' or 'user_item'")
|
| 149 |
+
|
| 150 |
+
print(f"Loaded {matrix_type} matrix from {filepath}")
|
| 151 |
+
return matrix
|
| 152 |
+
|
| 153 |
+
def build_item_interaction_matrix(self):
|
| 154 |
+
"""构建 Item-Item 交互矩阵"""
|
| 155 |
+
# Initialize the item interaction matrix
|
| 156 |
+
self.item_interaction_matrix = sparse.lil_matrix((self.n_items, self.n_items), dtype=int)
|
| 157 |
+
|
| 158 |
+
for index, outfit in tqdm(self.outfits_df.iterrows(), total=len(self.outfits_df)):
|
| 159 |
+
item_ids = outfit['item_ids'].split(',')
|
| 160 |
+
# Record co-occurrence for every item pair in the outfit
|
| 161 |
+
for item_id1, item_id2 in combinations(item_ids, r=2):
|
| 162 |
+
if item_id1 in self.item_id_to_index and item_id2 in self.item_id_to_index:
|
| 163 |
+
idx1 = self.item_id_to_index[item_id1]
|
| 164 |
+
idx2 = self.item_id_to_index[item_id2]
|
| 165 |
+
self.item_interaction_matrix[idx1, idx2] += 1
|
| 166 |
+
self.item_interaction_matrix[idx2, idx1] += 1  # unordered pair, keep the matrix symmetric
|
| 167 |
+
|
| 168 |
+
# Convert to CSR format
|
| 169 |
+
self.item_interaction_matrix = self.item_interaction_matrix.tocsr()
|
| 170 |
+
return self.item_interaction_matrix
|
| 171 |
+
|
| 172 |
+
def build_user_item_interaction_matrix(self):
|
| 173 |
+
"""构建 User-Item 交互矩阵"""
|
| 174 |
+
# Initialize the user-item interaction matrix
|
| 175 |
+
self.user_item_interaction_matrix = sparse.lil_matrix((self.n_users, self.n_items), dtype=int)
|
| 176 |
+
|
| 177 |
+
for uid, user in tqdm(self.users_df.iterrows(), total=len(self.users_df)):
|
| 178 |
+
oids = user["outfit_ids"].split(",")
|
| 179 |
+
outfits = self.outfits_df.loc[self.outfits_df.index.isin(oids)]
|
| 180 |
+
for oid, outfit in outfits.iterrows():
|
| 181 |
+
item_ids = outfit['item_ids'].split(',')
|
| 182 |
+
# Record each user-item occurrence
|
| 183 |
+
for iid in item_ids:
|
| 184 |
+
if iid in self.item_id_to_index:
|
| 185 |
+
uidx = self.user_id_to_index[uid]
|
| 186 |
+
iidx = self.item_id_to_index[iid]
|
| 187 |
+
self.user_item_interaction_matrix[uidx, iidx] += 1
|
| 188 |
+
|
| 189 |
+
# Convert to CSR format
|
| 190 |
+
self.user_item_interaction_matrix = self.user_item_interaction_matrix.tocsr()
|
| 191 |
+
return self.user_item_interaction_matrix
|
| 192 |
+
|
| 193 |
+
def _process_interactions_for_category(
|
| 194 |
+
self,
|
| 195 |
+
matrix,
|
| 196 |
+
given_id,
|
| 197 |
+
category_indices,
|
| 198 |
+
id_to_index
|
| 199 |
+
):
|
| 200 |
+
"""
|
| 201 |
+
Collect the interactions of a single entity (user or item) with the target category.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
- matrix: interaction matrix
|
| 205 |
+
- given_id: ID of the given entity (user or item)
|
| 206 |
+
- category_indices: set of item indices belonging to the target category
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
- a list of interactions, each a dict with item_id, interaction_count and score
|
| 210 |
+
"""
|
| 211 |
+
interactions = []
|
| 212 |
+
|
| 213 |
+
given_index = id_to_index[given_id]
|
| 214 |
+
row = matrix[given_index]
|
| 215 |
+
|
| 216 |
+
# Extract the non-zero entries of this row
|
| 217 |
+
row_start = row.indptr[0]
|
| 218 |
+
row_end = row.indptr[1]
|
| 219 |
+
col_indices = row.indices[row_start:row_end]
|
| 220 |
+
data_values = row.data[row_start:row_end]
|
| 221 |
+
|
| 222 |
+
# Keep only items that belong to the target category
|
| 223 |
+
for col_idx, value in zip(col_indices, data_values):
|
| 224 |
+
# Check whether the column corresponds to an item of the target category
|
| 225 |
+
if col_idx in category_indices:
|
| 226 |
+
# Look up the item ID
|
| 227 |
+
output_id = self.index_to_item_id[col_idx]
|
| 228 |
+
interactions.append({
|
| 229 |
+
'item_id': output_id,
|
| 230 |
+
'interaction_count': int(value),
|
| 231 |
+
'score': 0.0
|
| 232 |
+
})
|
| 233 |
+
|
| 234 |
+
return interactions
|
| 235 |
+
|
| 236 |
+
def get_item_category_interactions(
|
| 237 |
+
self,
|
| 238 |
+
target_category: str,
|
| 239 |
+
given_ids: List[str],
|
| 240 |
+
query_type='item', # item or user
|
| 241 |
+
top_k=None,
|
| 242 |
+
):
|
| 243 |
+
"""
|
| 244 |
+
Get all interactions between the given entities (users or items) and the target category.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
- target_category: subcategory to query
|
| 248 |
+
- given_ids: list of entity IDs (item IDs or user IDs) to query
|
| 249 |
+
- query_type: type of the given entities, 'item' or 'user'
|
| 250 |
+
- top_k: return only the k items with the most interactions; if None, return all
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
- a list of interaction statistics with the target category, sorted by interaction count
|
| 254 |
+
"""
|
| 255 |
+
if query_type == 'item':
|
| 256 |
+
matrix = self.item_interaction_matrix
|
| 257 |
+
id_to_index = self.item_id_to_index
|
| 258 |
+
elif query_type == 'user':
|
| 259 |
+
matrix = self.user_item_interaction_matrix
|
| 260 |
+
id_to_index = self.user_id_to_index
|
| 261 |
+
else:
|
| 262 |
+
print(f'query_type must be either item or user but got {query_type}')
|
| 263 |
+
return []
|
| 264 |
+
|
| 265 |
+
# Collect all interaction records
|
| 266 |
+
all_interactions = []
|
| 267 |
+
category = target_category
|
| 268 |
+
category_indices = self.subcategory_to_indices.get(category, set())  # all item indices of this subcategory
|
| 269 |
+
|
| 270 |
+
# Gather all interactions for each given entity
|
| 271 |
+
for given_id in given_ids:
|
| 272 |
+
interactions = self._process_interactions_for_category(
|
| 273 |
+
matrix, given_id, category_indices, id_to_index
|
| 274 |
+
)
|
| 275 |
+
# Append to the overall result list
|
| 276 |
+
all_interactions.extend(interactions)
|
| 277 |
+
|
| 278 |
+
# Merge interaction counts for duplicate items
|
| 279 |
+
item_interactions = {}
|
| 280 |
+
for interaction in all_interactions:
|
| 281 |
+
item_id = interaction['item_id']
|
| 282 |
+
count = interaction['interaction_count']
|
| 283 |
+
|
| 284 |
+
if item_id in item_interactions:
|
| 285 |
+
item_interactions[item_id] += count
|
| 286 |
+
else:
|
| 287 |
+
item_interactions[item_id] = count
|
| 288 |
+
|
| 289 |
+
# Convert back to the output format
|
| 290 |
+
merged_interactions = [
|
| 291 |
+
{'item_id': item_id, 'interaction_count': count, 'score': 0.0}
|
| 292 |
+
for item_id, count in item_interactions.items()
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
# Sort by interaction count
|
| 296 |
+
if merged_interactions:
|
| 297 |
+
merged_interactions.sort(key=lambda x: x['interaction_count'], reverse=True)
|
| 298 |
+
|
| 299 |
+
# Keep only the top-k entries
|
| 300 |
+
if top_k and merged_interactions:
|
| 301 |
+
merged_interactions = merged_interactions[:top_k]
|
| 302 |
+
|
| 303 |
+
# Return the merged result
|
| 304 |
+
return merged_interactions
|
| 305 |
+
|
| 306 |
+
def rank_by_similarity(self, item_interactions, user_interactions, beta=2.0):
|
| 307 |
+
"""
|
| 308 |
+
Score the user-interacted items by similarity to the item-interaction profile and sort them.
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
def get_combined_features(feature_dict):
|
| 312 |
+
return (feature_dict['image_embeds'] + feature_dict['text_embeds']) / 2
|
| 313 |
+
|
| 314 |
+
item_feature_list = []
|
| 315 |
+
for item in item_interactions:
|
| 316 |
+
item_id = item['item_id']
|
| 317 |
+
if item_id not in self.clip_features:
|
| 318 |
+
raise ValueError(f"Didn't find clip feature of item with id: {item_id}")
|
| 319 |
+
|
| 320 |
+
item_features = get_combined_features(self.clip_features[item_id])
|
| 321 |
+
item_feature_list.append(item_features)
|
| 322 |
+
|
| 323 |
+
weights = np.array([x['interaction_count'] for x in item_interactions], dtype=np.float32)
|
| 324 |
+
weights = weights / np.sum(weights)
|
| 325 |
+
item_feature = np.sum(np.stack(item_feature_list, axis=0) * weights[:, np.newaxis], axis=0).reshape(1, -1)
|
| 326 |
+
|
| 327 |
+
max_count = max((user_item.get('interaction_count', 1) for user_item in user_interactions), default=1)
|
| 328 |
+
for user_item in user_interactions:
|
| 329 |
+
user_item_id = user_item['item_id']
|
| 330 |
+
if user_item_id not in self.clip_features:
|
| 331 |
+
raise ValueError(f"Didn't find clip feature of item with id: {user_item_id}")
|
| 332 |
+
|
| 333 |
+
user_item_features = get_combined_features(self.clip_features[user_item_id]).reshape(1, -1)
|
| 334 |
+
similarity = cosine_similarity(user_item_features, item_feature).item()
|
| 335 |
+
interaction_count = user_item['interaction_count']
|
| 336 |
+
count_factor = (interaction_count / max_count) * beta + 1
|
| 337 |
+
user_item['score'] = float(similarity) * count_factor
|
| 338 |
+
|
| 339 |
+
user_interactions.sort(key=lambda x: x.get('score', 0), reverse=True)
|
| 340 |
+
return user_interactions
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
data_manager = InteractionDataManager(users_df, outfits_df, items_df)
|
| 344 |
+
mcp = FastMCP('image-retrieval-server')
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
@mcp.tool()
|
| 348 |
+
async def summary_user_history(user_id: str, target_category: str, list_of_items: List[str]) -> str:
|
| 349 |
+
"""Summary user's buying history of specific fashion category given user_id, target_category, list_of_items
|
| 350 |
+
After collecting the user's buying history, the descriptions of these historical items are summarized by an LLM.
|
| 351 |
+
The tool therefore returns the user's preference for target_category as a short sentence.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
user_id (str): User id. Will be provided through prompt
|
| 355 |
+
target_category (str): We care about user's buying history of this specific category.
|
| 356 |
+
list_of_items: List of item ids for history filtering. Will be provided through prompt
|
| 357 |
+
"""
|
| 358 |
+
# We need to find the most appropriate item to become the target item
|
| 359 |
+
# It should have enough relationship with user and other items
|
| 360 |
+
# Specifically, item_interaction larger than 3, history larger than 10
|
| 361 |
+
item_interaction_result = data_manager.get_item_category_interactions(
|
| 362 |
+
target_category, list_of_items, query_type='item'
|
| 363 |
+
)
|
| 364 |
+
user_interaction_result = data_manager.get_item_category_interactions(
|
| 365 |
+
target_category, [user_id], query_type='user'
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
def get_description(item_id: str) -> str:
|
| 369 |
+
return data_manager.items_df.loc[item_id].gen_description
|
| 370 |
+
|
| 371 |
+
descriptions_for_summary = []
|
| 372 |
+
if len(item_interaction_result) == 0:
|
| 373 |
+
descriptions_for_summary = [get_description(x['item_id']) for x in user_interaction_result]
|
| 374 |
+
else:
|
| 375 |
+
if len(user_interaction_result) >= 0:
|
| 376 |
+
user_interaction_result = data_manager.rank_by_similarity(
|
| 377 |
+
item_interaction_result,
|
| 378 |
+
user_interaction_result
|
| 379 |
+
)
|
| 380 |
+
descriptions_for_summary = [get_description(x['item_id']) for x in user_interaction_result[:5]]
|
| 381 |
+
|
| 382 |
+
if descriptions_for_summary:
|
| 383 |
+
user_message = f"Summary user's preference of {target_category} based on following descriptions of fashion items that user brought previously:"
|
| 384 |
+
for x in descriptions_for_summary:
|
| 385 |
+
user_message += f"\n{x}"
|
| 386 |
+
# Get summary using OpenAI API call
|
| 387 |
+
response = await openai.chat.completions.create(
|
| 388 |
+
model="gpt-4o-mini",
|
| 389 |
+
messages=[
|
| 390 |
+
{"role": "system", "content": f"You are a user preference summary assistant. Your response is limited in one sentence, staring at 'I prefer ...'"},
|
| 391 |
+
{"role": "user", "content": user_message}
|
| 392 |
+
],
|
| 393 |
+
max_tokens=1000,
|
| 394 |
+
)
|
| 395 |
+
return response.choices[0].message.content
|
| 396 |
+
else:
|
| 397 |
+
return ""
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
user_id = "115"
|
| 401 |
+
# Find this user's historical interactions given the target category and the partial outfit
|
| 402 |
+
partial_outfit = ["25479e5dacebbfaed18a7dc4830bd5cd19114486", "becc7b46236e9abb6f6760e7a1569b06bbc236c1",
|
| 403 |
+
"180c32b5c8c164f3c632f3e73d6002ccfa6fea57"]
|
| 404 |
+
target_category = "Skirts"
|
| 405 |
+
# summary_user_history(user_id, target_category, partial_outfit)  # async MCP tool; run it via an event loop (e.g. asyncio.run) when testing locally
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
async def compute_text_embedding(text: str) -> np.ndarray:
|
| 409 |
+
inputs = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True)
|
| 410 |
+
with torch.no_grad():
|
| 411 |
+
text_embedding = clip_model.get_text_features(**inputs).numpy()
|
| 412 |
+
return text_embedding / np.linalg.norm(text_embedding, axis=1, keepdims=True)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
async def find_most_similar_image(text_embedding: np.ndarray) -> Dict[str, Any]:
|
| 416 |
+
similarities = np.dot(data_manager.image_embeddings, text_embedding.T).flatten()
|
| 417 |
+
most_similar_idx = np.argmax(similarities)
|
| 418 |
+
most_similar_item_id = data_manager.item_ids[most_similar_idx]
|
| 419 |
+
return {
|
| 420 |
+
"image_path": image_paths[most_similar_item_id],
|
| 421 |
+
"similarity": float(similarities[most_similar_idx])
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
@mcp.tool()
|
| 426 |
+
async def retrieve_image(text: str) -> Dict[str, Any]:
|
| 427 |
+
"""Search for the most similar fashion image based on a text description.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
text (str): Text description of the fashion item to search.
|
| 431 |
+
"""
|
| 432 |
+
print(f"Searching for {text}")
|
| 433 |
+
text_embedding = await compute_text_embedding(text)
|
| 434 |
+
return await find_most_similar_image(text_embedding)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
mcp_server = mcp._mcp_server  # grab the underlying low-level Server object
|
| 438 |
+
sse_transport = SseServerTransport("/messages/")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
async def handle_sse(request):
|
| 442 |
+
print("Handling SSE connection")
|
| 443 |
+
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
|
| 444 |
+
read_stream, write_stream = streams
|
| 445 |
+
await mcp_server.run(
|
| 446 |
+
read_stream,
|
| 447 |
+
write_stream,
|
| 448 |
+
mcp_server.create_initialization_options(),
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# Define routes
|
| 452 |
+
routes = [
|
| 453 |
+
Route("/sse", endpoint=handle_sse),
|
| 454 |
+
Mount("/messages/", app=sse_transport.handle_post_message),
|
| 455 |
+
]
|
| 456 |
+
|
| 457 |
+
# Create the Starlette application
|
| 458 |
+
starlette_app = Starlette(routes=routes)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
if __name__ == "__main__":
|
| 462 |
+
print("Starting Image Retrieval server with HTTP and SSE...")
|
| 463 |
+
uvicorn.run(starlette_app, host="0.0.0.0", port=8001)  # use port 8001 to avoid clashing with the FashionVLM server
|
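As a companion to rank_by_similarity above, here is a minimal numpy sketch of the scoring rule: build a count-weighted average of the co-occurring items' CLIP features, then score each history item by its cosine similarity to that profile, boosted by (count / max_count) * beta + 1. The embeddings and counts below are random placeholders, not FashionRec data.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
dim = 512                                             # CLIP ViT-B/32 embedding size

# Placeholder features: 3 co-occurring items and 4 items from the user's history.
item_feats = rng.normal(size=(3, dim))
item_counts = np.array([5, 2, 1], dtype=np.float32)   # co-occurrence counts
user_feats = rng.normal(size=(4, dim))
user_counts = np.array([3, 1, 7, 2], dtype=np.float32)

# Count-weighted profile of the co-occurring items.
w = item_counts / item_counts.sum()
profile = (item_feats * w[:, None]).sum(axis=0, keepdims=True)

# Score each history item: cosine similarity boosted by its own interaction count.
beta = 2.0
sims = cosine_similarity(user_feats, profile).ravel()
scores = sims * (user_counts / user_counts.max() * beta + 1)
print(np.argsort(-scores))                            # history items, best match first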
mcp_servers/virtual_try_on.py
ADDED
|
@@ -0,0 +1,94 @@
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import uuid
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
import PIL
|
| 7 |
+
from google import genai
|
| 8 |
+
from google.genai import types
|
| 9 |
+
from mcp.server.fastmcp import FastMCP
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
| 15 |
+
GEN_IMG_DIR = os.getenv("GEN_IMG_DIR")
|
| 16 |
+
# os.environ["HTTP_PROXY"] = "http://127.0.0.1:10809"
|
| 17 |
+
# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:10809"
|
| 18 |
+
|
| 19 |
+
client = genai.Client(api_key=GEMINI_API_KEY)
|
| 20 |
+
mcp = FastMCP("virtual_try_on")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def save_image(response, path):
|
| 24 |
+
for part in response.candidates[0].content.parts:
|
| 25 |
+
if part.text is not None:
|
| 26 |
+
continue
|
| 27 |
+
elif part.inline_data is not None:
|
| 28 |
+
mime = part.inline_data.mime_type
|
| 29 |
+
data = part.inline_data.data
|
| 30 |
+
pathlib.Path(path).write_bytes(data)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@mcp.tool()
|
| 34 |
+
async def try_on(image_path: str) -> str:
|
| 35 |
+
"""Generate a virtual try-on image based on image path and return the saved file path.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
image_path (str): Path to the input image file for try-on image generation
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
str: File path of the generated image
|
| 42 |
+
"""
|
| 43 |
+
try:
|
| 44 |
+
print(image_path)
|
| 45 |
+
response = client.models.generate_content(
|
| 46 |
+
model="models/gemini-2.0-flash-exp",
|
| 47 |
+
contents=[
|
| 48 |
+
"You are a virtual try on tool. Put all the clothes uploaded on a real person and create a picture. Only clothings should be put on, excluding shoes or bags or accessories.",
|
| 49 |
+
PIL.Image.open(os.path.abspath(image_path))
|
| 50 |
+
],
|
| 51 |
+
config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
|
| 52 |
+
)
|
| 53 |
+
gen_img_filename = f'{GEN_IMG_DIR}/{uuid.uuid4().hex}.png'
|
| 54 |
+
await save_image(response, gen_img_filename)
|
| 55 |
+
return os.path.abspath(gen_img_filename)
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(e)
|
| 58 |
+
return image_path
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
print("Started MCP server 'virtual_try_on'...")
|
| 63 |
+
mcp.run(transport='stdio')
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
# image_path = "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec/data/basic_recommendation/train/temp/0000000_target.jpg"
|
| 68 |
+
# print(image_path)
|
| 69 |
+
#
|
| 70 |
+
# def save_image(response, path):
|
| 71 |
+
# for part in response.candidates[0].content.parts:
|
| 72 |
+
# if part.inline_data is not None:
|
| 73 |
+
# data = part.inline_data.data
|
| 74 |
+
# pathlib.Path(path).write_bytes(data)
|
| 75 |
+
#
|
| 76 |
+
#
|
| 77 |
+
# client = genai.Client(api_key=GEMINI_API_KEY)  # never hard-code the API key; load it from the environment
|
| 78 |
+
#
|
| 79 |
+
# response = client.models.generate_content(
|
| 80 |
+
# model="models/gemini-2.0-flash-exp",
|
| 81 |
+
# contents=[
|
| 82 |
+
# "你是虚拟穿衣工具,把这一套衣服都穿到模特身上,输出一张图片,全身图",
|
| 83 |
+
# PIL.Image.open(image_path)
|
| 84 |
+
# ],
|
| 85 |
+
# config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
|
| 86 |
+
# )
|
| 87 |
+
#
|
| 88 |
+
# for part in response.candidates[0].content.parts:
|
| 89 |
+
# if part.text is not None:
|
| 90 |
+
# print(part.text)
|
| 91 |
+
#
|
| 92 |
+
# save_image(response, 'edited_image3.png')
|
| 93 |
+
|
| 94 |
+
main()
|
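A sketch of how a client could exercise the try_on tool over stdio, assuming the standard mcp Python SDK client API (StdioServerParameters, stdio_client, ClientSession); the server path and the input image path are assumptions for illustration.

import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

async def main() -> None:
    # Launch the try-on server as a stdio subprocess (path is an assumption).
    params = StdioServerParameters(command="python", args=["mcp_servers/virtual_try_on.py"])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool(
                "try_on", arguments={"image_path": "generated_images/outfit_grid.png"}
            )
            print(result)  # path of the generated try-on image (or the input path on failure)

if __name__ == "__main__":
    asyncio.run(main())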
requirements.txt
ADDED
|
@@ -0,0 +1,228 @@
|
| 1 |
+
accelerate==0.21.0
|
| 2 |
+
aiohttp==3.9.5
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
albumentations==0.3.2
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
antlr4-python3-runtime==4.9.3
|
| 7 |
+
anykeystore==0.2
|
| 8 |
+
asn1crypto==1.5.1
|
| 9 |
+
asttokens==2.4.1
|
| 10 |
+
async-timeout==4.0.3
|
| 11 |
+
attrs==21.2.0
|
| 12 |
+
bidict==0.23.1
|
| 13 |
+
blessed==1.20.0
|
| 14 |
+
boto3==1.34.113
|
| 15 |
+
botocore==1.34.113
|
| 16 |
+
braceexpand==0.1.7
|
| 17 |
+
cachetools==5.3.3
|
| 18 |
+
certifi==2024.2.2
|
| 19 |
+
cffi==1.16.0
|
| 20 |
+
chardet==5.2.0
|
| 21 |
+
charset-normalizer==3.3.2
|
| 22 |
+
click==8.1.7
|
| 23 |
+
clip==0.2.0
|
| 24 |
+
clip-openai==1.0.post20230121
|
| 25 |
+
cmake==3.29.3
|
| 26 |
+
cramjam==2.8.3
|
| 27 |
+
crcmod==1.7
|
| 28 |
+
cryptacular==1.6.2
|
| 29 |
+
cryptography==39.0.2
|
| 30 |
+
cycler==0.12.1
|
| 31 |
+
datasets==2.2.1
|
| 32 |
+
diffusers==0.30.1
|
| 33 |
+
decorator==5.1.1
|
| 34 |
+
decord==0.6.0
|
| 35 |
+
deepspeed==0.14.2
|
| 36 |
+
defusedxml==0.7.1
|
| 37 |
+
Deprecated==1.2.14
|
| 38 |
+
descartes==1.1.0
|
| 39 |
+
dill==0.3.8
|
| 40 |
+
distlib==0.3.8
|
| 41 |
+
distro-info==1.0
|
| 42 |
+
dnspython==2.6.1
|
| 43 |
+
docker-pycreds==0.4.0
|
| 44 |
+
docstring_parser==0.16
|
| 45 |
+
ecdsa==0.19.0
|
| 46 |
+
einops==0.6.0
|
| 47 |
+
exceptiongroup==1.2.1
|
| 48 |
+
executing==2.0.1
|
| 49 |
+
fairscale==0.4.13
|
| 50 |
+
fastparquet==2024.5.0
|
| 51 |
+
ffmpegcv==0.3.13
|
| 52 |
+
filelock==3.14.0
|
| 53 |
+
fire==0.6.0
|
| 54 |
+
fonttools==4.51.0
|
| 55 |
+
frozenlist==1.4.1
|
| 56 |
+
fsspec==2023.6.0
|
| 57 |
+
ftfy==6.2.0
|
| 58 |
+
gitdb==4.0.11
|
| 59 |
+
GitPython==3.1.43
|
| 60 |
+
gpustat==1.1.1
|
| 61 |
+
greenlet==3.0.3
|
| 62 |
+
grpcio==1.64.0
|
| 63 |
+
h11==0.14.0
|
| 64 |
+
hjson==3.1.0
|
| 65 |
+
huggingface-hub==0.23.2
|
| 66 |
+
hupper==1.12.1
|
| 67 |
+
idna==3.7
|
| 68 |
+
imageio==2.34.1
|
| 69 |
+
imgaug==0.2.6
|
| 70 |
+
iniconfig==2.0.0
|
| 71 |
+
ipaddress==1.0.23
|
| 72 |
+
ipdb==0.13.13
|
| 73 |
+
ipython==8.18.1
|
| 74 |
+
jaxtyping==0.2.28
|
| 75 |
+
jedi==0.19.1
|
| 76 |
+
Jinja2==3.1.4
|
| 77 |
+
jmespath==1.0.1
|
| 78 |
+
joblib==1.4.2
|
| 79 |
+
jsonargparse==4.14.1
|
| 80 |
+
jsonlines==4.0.0
|
| 81 |
+
kiwisolver==1.4.5
|
| 82 |
+
kornia==0.7.2
|
| 83 |
+
kornia_rs==0.1.3
|
| 84 |
+
lazy_loader==0.4
|
| 85 |
+
lightning==2.2.3
|
| 86 |
+
lightning-utilities==0.11.2
|
| 87 |
+
lit==18.1.6
|
| 88 |
+
MarkupSafe==2.1.5
|
| 89 |
+
matplotlib==3.5.3
|
| 90 |
+
matplotlib-inline==0.1.7
|
| 91 |
+
miscreant==0.3.0
|
| 92 |
+
mpmath==1.3.0
|
| 93 |
+
msgpack==1.0.8
|
| 94 |
+
multidict==6.0.5
|
| 95 |
+
multiprocess==0.70.16
|
| 96 |
+
natsort==8.4.0
|
| 97 |
+
networkx==3.2.1
|
| 98 |
+
ninja==1.11.1.1
|
| 99 |
+
numpy==1.24.4
|
| 100 |
+
nuscenes-devkit==1.1.11
|
| 101 |
+
oauthlib==3.2.2
|
| 102 |
+
omegaconf==2.3.0
|
| 103 |
+
open-clip-torch==2.24.0
|
| 104 |
+
openai-clip
|
| 105 |
+
opencv-python==4.9.0.80
|
| 106 |
+
opencv-python-headless==3.4.18.65
|
| 107 |
+
packaging==22.0
|
| 108 |
+
pandas==1.5.3
|
| 109 |
+
parquet==1.3.1
|
| 110 |
+
parso==0.8.4
|
| 111 |
+
PasteDeploy==3.1.0
|
| 112 |
+
pathlib2==2.3.7.post1
|
| 113 |
+
pathtools==0.1.2
|
| 114 |
+
pbkdf2==1.3
|
| 115 |
+
pexpect==4.9.0
|
| 116 |
+
pillow==10.3.0
|
| 117 |
+
plaster==1.1.2
|
| 118 |
+
plaster-pastedeploy==1.0.1
|
| 119 |
+
platformdirs==4.2.2
|
| 120 |
+
plotly==5.22.0
|
| 121 |
+
pluggy==1.5.0
|
| 122 |
+
ply==3.11
|
| 123 |
+
promise==2.3
|
| 124 |
+
prompt-toolkit==3.0.43
|
| 125 |
+
protobuf==3.20.3
|
| 126 |
+
psutil==5.9.8
|
| 127 |
+
ptyprocess==0.7.0
|
| 128 |
+
pure-eval==0.2.2
|
| 129 |
+
py==1.11.0
|
| 130 |
+
py-cpuinfo==9.0.0
|
| 131 |
+
py-spy==0.3.14
|
| 132 |
+
pyarrow==11.0.0
|
| 133 |
+
pyarrow-hotfix==0.6
|
| 134 |
+
pyasn1==0.6.0
|
| 135 |
+
pycocotools==2.0.7
|
| 136 |
+
pycparser==2.22
|
| 137 |
+
pycryptodomex==3.20.0
|
| 138 |
+
pycurl==7.43.0.6
|
| 139 |
+
pydantic==1.10.15
|
| 140 |
+
pydantic_core==2.18.3
|
| 141 |
+
Pygments==2.18.0
|
| 142 |
+
PyJWT==2.8.0
|
| 143 |
+
pynvml==11.5.0
|
| 144 |
+
pyope==0.2.2
|
| 145 |
+
pyOpenSSL==23.2.0
|
| 146 |
+
pyparsing==3.1.2
|
| 147 |
+
pyquaternion==0.9.9
|
| 148 |
+
pyramid==2.0.2
|
| 149 |
+
pyramid-mailer==0.15.1
|
| 150 |
+
pytest==6.2.5
|
| 151 |
+
python-consul==1.1.0
|
| 152 |
+
python-dateutil==2.9.0.post0
|
| 153 |
+
python-engineio==4.9.1
|
| 154 |
+
python-etcd==0.4.5
|
| 155 |
+
python-jose==3.3.0
|
| 156 |
+
python-socketio==5.11.2
|
| 157 |
+
python3-openid==3.2.0
|
| 158 |
+
pytorch-extension==0.2
|
| 159 |
+
pytorch-lightning==2.2.3
|
| 160 |
+
pytz==2024.1
|
| 161 |
+
PyYAML==6.0.1
|
| 162 |
+
regex==2024.5.15
|
| 163 |
+
repoze.sendmail==4.4.1
|
| 164 |
+
requests==2.31.0
|
| 165 |
+
requests-oauthlib==2.0.0
|
| 166 |
+
rsa==4.9
|
| 167 |
+
s3transfer==0.10.1
|
| 168 |
+
safetensors==0.4.3
|
| 169 |
+
schedule==1.2.2
|
| 170 |
+
scikit-image==0.22.0
|
| 171 |
+
scikit-learn==1.5.0
|
| 172 |
+
scipy==1.13.1
|
| 173 |
+
sentencepiece==0.2.0
|
| 174 |
+
sentry-sdk==2.3.1
|
| 175 |
+
setproctitle==1.3.3
|
| 176 |
+
Shapely==1.8.5.post1
|
| 177 |
+
shortuuid==1.0.13
|
| 178 |
+
simple-websocket==1.0.0
|
| 179 |
+
six==1.16.0
|
| 180 |
+
smmap==5.0.1
|
| 181 |
+
SQLAlchemy==2.0.30
|
| 182 |
+
stack-data==0.6.3
|
| 183 |
+
sympy==1.12
|
| 184 |
+
taming-transformers-rom1504==0.0.6
|
| 185 |
+
tenacity==8.3.0
|
| 186 |
+
tensorboardX==2.6.2.2
|
| 187 |
+
termcolor==2.4.0
|
| 188 |
+
threadpoolctl==3.5.0
|
| 189 |
+
thriftpy2==0.5.0
|
| 190 |
+
tifffile==2024.5.22
|
| 191 |
+
timm==1.0.3
|
| 192 |
+
tokenizers==0.19.1
|
| 193 |
+
toml==0.10.2
|
| 194 |
+
tomli==2.0.1
|
| 195 |
+
torch==2.2.1
|
| 196 |
+
torch-fidelity==0.3.0
|
| 197 |
+
torchmetrics==1.4.0.post0
|
| 198 |
+
torchvision==0.17.1
|
| 199 |
+
tox==3.28.0
|
| 200 |
+
tqdm==4.66.4
|
| 201 |
+
traitlets==5.14.3
|
| 202 |
+
transaction==4.0
|
| 203 |
+
transformers==4.41.1
|
| 204 |
+
translationstring==1.4
|
| 205 |
+
triton==2.2.0
|
| 206 |
+
typeguard==2.13.3
|
| 207 |
+
typing_extensions==4.12.0
|
| 208 |
+
tzdata==2024.1
|
| 209 |
+
urllib3==1.26.18
|
| 210 |
+
velruse==1.1.1
|
| 211 |
+
venusian==3.1.0
|
| 212 |
+
virtualenv==20.26.2
|
| 213 |
+
wandb==0.17.2
|
| 214 |
+
watchdog==4.0.1
|
| 215 |
+
wcwidth==0.2.13
|
| 216 |
+
webdataset==0.2.86
|
| 217 |
+
WebOb==1.8.7
|
| 218 |
+
websocket-client==1.8.0
|
| 219 |
+
wrapt==1.16.0
|
| 220 |
+
wsproto==1.2.0
|
| 221 |
+
WTForms==3.1.2
|
| 222 |
+
wtforms-recaptcha==0.3.2
|
| 223 |
+
xformers==0.0.25
|
| 224 |
+
xxhash==3.4.1
|
| 225 |
+
yarl==1.9.4
|
| 226 |
+
zope.deprecation==5.0
|
| 227 |
+
zope.interface==6.4.post2
|
| 228 |
+
zope.sqlalchemy==3.1
|
system_message.py
ADDED
|
@@ -0,0 +1,9 @@
|
| 1 |
+
SYSTEM_MESSAGE = """
|
| 2 |
+
You are a fashion assistant. Use the weather tool to get weather alerts.
|
| 3 |
+
|
| 4 |
+
Use the retrieve_image tool when the user asks you to find a product from the database based on their description.
|
| 5 |
+
Use the image_generate tool to generate a fashion item image from a description. The description must be in English!
|
| 6 |
+
Use the fashion_recommend_without_image tool to generate recommendations when there are no image paths included.
|
| 7 |
+
Use the fashion_recommend tool for queries that include uploaded images. Provide recommendations according to the query and the uploaded images.
|
| 8 |
+
Use the try_on tool when the user asks to try on uploaded images. The clothing image paths will be provided in the prompt.
|
| 9 |
+
"""
|
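For context, a minimal sketch of how SYSTEM_MESSAGE might be wired into a chat completion with the AsyncOpenAI client used elsewhere in this repo; the model name and user message are assumptions, and in the real app the listed tools are exposed through the MCP servers rather than in this bare call.

import asyncio
from openai import AsyncOpenAI
from system_message import SYSTEM_MESSAGE

async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # assumed model, matching the one used in product_user_database.py
        messages=[
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": "Find me a light summer dress for a beach holiday."},
        ],
    )
    print(response.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(main())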
utils.py
ADDED
|
@@ -0,0 +1,74 @@
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def create_image_grid(image_paths: List[str], output_path: str, grid_size: int = 2) -> None:
|
| 6 |
+
"""
|
| 7 |
+
Combine up to four images into a 2x2 grid; missing slots stay blank, transparent backgrounds are filled with white, and each image is resized proportionally.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
- image_paths: list of input image paths
|
| 11 |
+
- output_path: path of the output grid image
|
| 12 |
+
- grid_size: grid size (default 2x2)
|
| 13 |
+
"""
|
| 14 |
+
images = []
|
| 15 |
+
target_size = (256, 256)  # target size of each grid cell
|
| 16 |
+
|
| 17 |
+
for path in image_paths[:4]:
|
| 18 |
+
try:
|
| 19 |
+
# Load the image, preserving transparency
|
| 20 |
+
img = Image.open(path).convert('RGBA')
|
| 21 |
+
|
| 22 |
+
# If the image has an alpha channel, fill transparent areas with white
|
| 23 |
+
if img.mode == 'RGBA':
|
| 24 |
+
background = Image.new('RGBA', img.size, (255, 255, 255, 255))  # white background
|
| 25 |
+
img = Image.alpha_composite(background, img)
|
| 26 |
+
|
| 27 |
+
# Convert to RGB mode
|
| 28 |
+
img = img.convert('RGB')
|
| 29 |
+
original_width, original_height = img.size
|
| 30 |
+
aspect_ratio = original_width / original_height
|
| 31 |
+
if original_width >= original_height:
|
| 32 |
+
# Width is the longer side: scale width to 256
|
| 33 |
+
new_width = 256
|
| 34 |
+
new_height = int(256 / aspect_ratio)
|
| 35 |
+
else:
|
| 36 |
+
# Height is the longer side: scale height to 256
|
| 37 |
+
new_height = 256
|
| 38 |
+
new_width = int(256 * aspect_ratio)
|
| 39 |
+
|
| 40 |
+
# Proportional resize
|
| 41 |
+
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 42 |
+
# img.thumbnail(target_size, Image.Resampling.LANCZOS)  # high-quality resampling alternative
|
| 43 |
+
|
| 44 |
+
# Create a blank 256x256 canvas (white background)
|
| 45 |
+
canvas = Image.new('RGB', target_size, (255, 255, 255))
|
| 46 |
+
|
| 47 |
+
# Compute the centered offset
|
| 48 |
+
offset_x = (target_size[0] - img.size[0]) // 2
|
| 49 |
+
offset_y = (target_size[1] - img.size[1]) // 2
|
| 50 |
+
|
| 51 |
+
# Paste the resized image centered on the canvas
|
| 52 |
+
canvas.paste(img, (offset_x, offset_y))
|
| 53 |
+
|
| 54 |
+
images.append(canvas)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(f"Error loading image {path}: {e}")
|
| 57 |
+
images.append(None)
|
| 58 |
+
|
| 59 |
+
# Pad with empty slots if there are fewer than 4 images
|
| 60 |
+
while len(images) < 4:
|
| 61 |
+
images.append(None)
|
| 62 |
+
|
| 63 |
+
# Create the blank 512x512 canvas (white background)
|
| 64 |
+
grid_image = Image.new('RGB', (512, 512), (255, 255, 255)) # 白色背景
|
| 65 |
+
|
| 66 |
+
# Paste images in a 2x2 layout
|
| 67 |
+
for idx, img in enumerate(images):
|
| 68 |
+
if img is not None:
|
| 69 |
+
x = (idx % 2) * 256
|
| 70 |
+
y = (idx // 2) * 256
|
| 71 |
+
grid_image.paste(img, (x, y))
|
| 72 |
+
|
| 73 |
+
# Save the grid image as JPG
|
| 74 |
+
grid_image.save(output_path, quality=95)  # quality 95 to avoid over-compression
|
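A small usage sketch for create_image_grid with hypothetical image paths; up to four images are tiled into a 512x512 grid and any missing cells stay white.

from utils import create_image_grid

# Hypothetical input paths; any 1-4 images work, extra paths are ignored.
paths = [
    "generated_images/top.png",
    "generated_images/skirt.png",
    "generated_images/shoes.png",
]
create_image_grid(paths, "generated_images/outfit_grid.jpg")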