pangkaicheng committed
Commit f8a73ec · 0 Parent(s)

first commit
.gitignore ADDED
@@ -0,0 +1,15 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv/
11
+ .env
12
+
13
+ generated_images/
14
+ .chainlit/
15
+ .idea/
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.11
README.md ADDED
@@ -0,0 +1,13 @@
1
+ The FashionRec dataset must be downloaded.
2
+
3
+ A .env file is required, containing the following variables:
4
+
5
+ CHAINLIT_PORT=8888
6
+
7
+ PROXY=http://127.0.0.1:10809
8
+
9
+ OPENAI_API_KEY=
10
+ GEMINI_API_KEY=
11
+
12
+ FASHION_DATA_ROOT=path_to_FashionRec
13
+ GEN_IMG_DIR="./generated_images"
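For convenience, a minimal environment check (a sketch, not part of this commit) that loads the `.env` described above with python-dotenv, which `chainlit_app.py` already depends on, and fails fast if a required key is missing:

```python
# Sketch only: verify the .env keys this README lists before launching the app.
import os
from dotenv import load_dotenv

load_dotenv()

REQUIRED_KEYS = ["CHAINLIT_PORT", "OPENAI_API_KEY", "FASHION_DATA_ROOT", "GEN_IMG_DIR"]
missing = [key for key in REQUIRED_KEYS if not os.getenv(key)]
if missing:
    raise RuntimeError(f"Missing required .env keys: {missing}")
print("Environment looks complete.")
```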
chainlit.md ADDED
@@ -0,0 +1,14 @@
1
+ # Welcome to Chainlit! 🚀🤖
2
+
3
+ Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.
4
+
5
+ ## Useful Links 🔗
6
+
7
+ - **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
8
+ - **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬
9
+
10
+ We can't wait to see what you create with Chainlit! Happy coding! 💻😊
11
+
12
+ ## Welcome screen
13
+
14
+ To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
chainlit_app.py ADDED
@@ -0,0 +1,159 @@
1
+ import json
2
+ import os
3
+ import pandas as pd
4
+ import uuid
5
+
6
+ from openai import AsyncOpenAI
7
+ from dotenv import load_dotenv
8
+ import chainlit as cl
9
+
10
+ from system_message import SYSTEM_MESSAGE
11
+ from mcp_client import MCPClient
12
+ from utils import create_image_grid
13
+ import httpx
14
+
15
+
16
+ # Load environment variables from .env
17
+ load_dotenv()
18
+
19
+ # Initialize OpenAI client
20
+ CHAINLIT_PORT = os.getenv("CHAINLIT_PORT", "8888")
21
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
22
+ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
23
+ FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT")
24
+ PROXY = os.getenv("PROXY")
25
+ items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet")
26
+ item_id_set = set(items_df.item_id)
27
+
28
+ http_client = httpx.AsyncClient(proxy=PROXY) if PROXY else httpx.AsyncClient()  # AsyncOpenAI requires an async client
29
+
30
+
31
+ class FashionAgent:
32
+ def __init__(self, user_id=None):
33
+ self.mcp_client = MCPClient("mcp_server_config.json", user_id)
34
+ self.openai = AsyncOpenAI(api_key=OPENAI_API_KEY, http_client=http_client)
35
+ self.user_id = user_id
36
+
37
+
38
+ # Global FashionAgent instance
39
+ agent = FashionAgent(user_id=None)
40
+
41
+
42
+ @cl.on_chat_start
43
+ async def on_chat_start():
44
+ await agent.mcp_client.connect_to_servers()
45
+ cl.user_session.set("agent", agent)
46
+ await cl.Message(content="Hello Sophia! Welcome to FashionM3. How can I assist you today?").send()
47
+
48
+
49
+ @cl.on_message
50
+ async def on_message(message: cl.Message):
51
+ agent = cl.user_session.get("agent")
52
+ user_id = cl.user_session.get("user_id")
53
+ chat_history = cl.user_session.get("chat_history", [])
54
+
55
+ user_message = message.content
56
+
57
+ upload_image = [x.path for x in message.elements if isinstance(x, cl.Image)]
58
+ if len(upload_image) == 1:
59
+ user_message += f"\nThe uploaded image path is: {os.path.abspath(upload_image[0])}"
60
+ elif len(upload_image) > 1:
61
+ merged_image_path = f".files/{uuid.uuid4().hex}.jpg"
62
+ create_image_grid(upload_image[:4], merged_image_path)
63
+
64
+ user_message += f"\nThe uploaded image path is: {os.path.abspath(merged_image_path)}"
65
+
66
+ image_in_database = []
67
+ for image in message.elements:
68
+ if isinstance(image, cl.Image):
69
+ item_id = image.name.split(".")[0]
70
+ if item_id in item_id_set:
71
+ image_in_database.append(item_id)
72
+ if len(image_in_database) > 0:
73
+ user_message += f"\nUser id is: {user_id}"
74
+ user_message += f"\nlist_of_items are: {image_in_database}"
75
+ elif user_id:
76
+ user_message += f"\nUser id is: {user_id}"
77
+
78
+ # Prepare messages for OpenAI API
79
+ messages = [
80
+ {"role": "system", "content": SYSTEM_MESSAGE},
81
+ *[{"role": "user" if isinstance(msg, cl.Message) else "assistant", "content": msg.content} for msg in chat_history],
82
+ {"role": "user", "content": user_message}
83
+ ]
84
+
85
+ # Fetch available tools
86
+ available_tools = await agent.mcp_client.get_tools()
87
+
88
+ # Initial OpenAI API call
89
+ response = await agent.openai.chat.completions.create(
90
+ model="gpt-4o-mini",
91
+ messages=messages,
92
+ max_tokens=1000,
93
+ tools=available_tools if available_tools else None,
94
+ tool_choice="auto" if available_tools else None
95
+ )
96
+
97
+ # Process the response
98
+ response_message = response.choices[0].message
99
+ if response_message.tool_calls:
100
+ # Handle tool calls
101
+ for tool_call in response_message.tool_calls:
102
+ tool_name = tool_call.function.name
103
+ params = json.loads(tool_call.function.arguments)
104
+ try:
105
+ print(f"Agent execute {tool_name} with params: {params}")
106
+ result = await agent.mcp_client.execute_tool(tool_name, params)
107
+ if tool_name == "retrieve_image":
108
+ image_path = json.loads(result['result'][0].text)['image_path']
109
+ similarity = json.loads(result['result'][0].text)['similarity']
110
+ output = f"I found a matching fashion item with a similarity score of {similarity:.2f}"
111
+
112
+ images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
113
+ await cl.Message(content=output, elements=images, author="Fashion Agent").send()
114
+ if tool_name == "image_generate":
115
+ image_path = result['result'][0].text
116
+ images = [cl.Image(path=image_path, name="Product image", display="inline", size="medium")]
117
+ output = f"Here is the generated image."
118
+ await cl.Message(content=output, elements=images, author="Fashion Agent").send()
119
+ if tool_name == "fashion_recommend_without_image":
120
+ output = result['result'][0].text
121
+ await cl.Message(content=output, author="Fashion Agent").send()
122
+ if tool_name == "fashion_recommend":
123
+ output = json.loads(result['result'][0].text)['recommendation']
124
+ # user_preference = json.loads(result['result'][0].text)['user_preference']
125
+ # await cl.Message(content=user_preference, author="Fashion Agent").send()
126
+ await cl.Message(content=output, author="Fashion Agent").send()
127
+ if tool_name == "try_on":
128
+ image_path = result['result'][0].text
129
+ images = [cl.Image(path=image_path, name="Try-on image", display="inline", size="large")]
130
+ output = f"Here is the virtual try-on image."
131
+ await cl.Message(content=output, elements=images, author="Fashion Agent").send()
132
+ else:
133
+ output = result
134
+ except Exception as e:
135
+ output = f"Error executing tool {tool_name}: {str(e)}"
136
+
137
+ # Update chat history
138
+ chat_history.append(cl.Message(content=message.content, author="user"))
139
+ chat_history.append(cl.Message(content=output, author="assistant"))
140
+ cl.user_session.set("chat_history", chat_history)
141
+ else:
142
+ # Direct response from the model
143
+ output = response_message.content
144
+ chat_history.append(cl.Message(content=message.content, author="user"))
145
+ chat_history.append(cl.Message(content=output, author="assistant"))
146
+ cl.user_session.set("chat_history", chat_history)
147
+ await cl.Message(content=output, author="Fashion Agent").send()
148
+
149
+
150
+ @cl.on_chat_end
151
+ def on_chat_end():
152
+ print("Goodbye", cl.user_session.get("id"))
153
+
154
+
155
+ if __name__ == "__main__":
156
+ from chainlit.cli import run_chainlit
157
+
158
+ os.environ["CHAINLIT_PORT"] = CHAINLIT_PORT
159
+ run_chainlit(__file__)
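For orientation, a hypothetical example (all values invented) of the dictionary `MCPClient.execute_tool` returns for the `retrieve_image` tool, showing the parsing that `on_message` above performs on `result['result'][0].text`:

```python
import json
from types import SimpleNamespace

# Hypothetical shape of MCPClient.execute_tool's return value for "retrieve_image"
# (values invented): "result" holds the MCP content blocks, and [0].text carries
# the JSON payload that on_message parses above.
result = {
    "result": [SimpleNamespace(text='{"image_path": "images/12345.jpg", "similarity": 0.87}')],
    "server": "fashion_vlm",
}
payload = json.loads(result["result"][0].text)
print(payload["image_path"], payload["similarity"])  # -> images/12345.jpg 0.87
```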
mcp_client.py ADDED
@@ -0,0 +1,110 @@
1
+ from typing import Dict, Any, List, Optional
2
+ from contextlib import AsyncExitStack
3
+ import json
4
+ import aiohttp
5
+ import asyncio
6
+
7
+ from mcp import ClientSession, StdioServerParameters
8
+ from mcp.client.stdio import stdio_client
9
+ from mcp.client.sse import sse_client
10
+
11
+
12
+ class MCPClient:
13
+ def __init__(self, config_path: str, user_id: Optional[str] = None):
14
+ """
15
+ Initialize MCPClient from a JSON config file (e.g. mcp_server_config.json).
16
+ Each entry under 'mcpServers' selects a transport: 'stdio' entries provide 'command', 'args', and optional 'env'; 'sse' entries provide a 'url'.
17
+ """
18
+ self.user_id = user_id
19
+ with open(config_path, 'r') as f:
20
+ self.server_configs = json.load(f)['mcpServers']
21
+ self.sessions: Dict[str, Any] = {} # one ClientSession per connected server (stdio or SSE)
22
+ self.exit_stack = AsyncExitStack()
23
+
24
+ async def connect_to_servers(self):
25
+ """Connect to all configured MCP servers based on their transport type."""
26
+ for server_name, config in self.server_configs.items():
27
+ transport = config.get("transport", "stdio") # 默认使用 stdio
28
+ print(f"Connecting to {server_name} ({transport})...")
29
+ if transport == "stdio":
30
+ command = config.get("command")
31
+ args = config.get("args", [])
32
+ env = config.get("env", None)
33
+ if not command:
34
+ raise ValueError(f"No command specified for server {server_name}")
35
+
36
+ server_params = StdioServerParameters(command=command, args=args, env=env)
37
+ stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
38
+ stdio, write = stdio_transport
39
+ session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
40
+ await session.initialize()
41
+ self.sessions[server_name] = session
42
+ # self.stdio_transports[server_name] = (stdio, write)
43
+ elif transport == "sse":
44
+ server_url = config.get("url", "")
45
+ if not server_url:
46
+ raise ValueError(f"No base_url specified for server {server_name}")
47
+
48
+ # Establish the SSE connection
49
+ streams_context = sse_client(url=f"{server_url}/sse")
50
+ streams = await self.exit_stack.enter_async_context(streams_context)
51
+ session_context = ClientSession(*streams)
52
+ session = await self.exit_stack.enter_async_context(session_context)
53
+
54
+ # Initialize the session
55
+ await session.initialize()
56
+ self.sessions[server_name] = session
57
+ # self.sse_contexts[server_name] = (streams_context, session_context)
58
+
59
+ # Verify the connection by listing tools
60
+ print(f"Initialized SSE client for {server_name}...")
61
+ print("Listing tools...")
62
+ response = await session.list_tools()
63
+ tools = response.tools
64
+ print(f"Connected to {server_name} with tools:", [tool.name for tool in tools])
65
+ else:
66
+ raise ValueError(f"Unsupported transport type '{transport}' for {server_name}")
67
+
68
+ async def get_tools(self) -> List[Dict[str, Any]]:
69
+ """
70
+ Fetch the list of available tools from all connected MCP servers.
71
+ Returns a list of tool definitions with name, description, and inputSchema.
72
+ """
73
+ all_tools = []
74
+ for server_name, session in self.sessions.items():
75
+ response = await session.list_tools()
76
+ for tool in response.tools:
77
+ if not self.user_id and tool.name == 'personalized_fashion_recommend':
78
+ continue
79
+ all_tools.append(
80
+ {
81
+ "type": "function",
82
+ "function": {
83
+ "name": tool.name,
84
+ "description": tool.description,
85
+ "parameters": tool.inputSchema
86
+ }
87
+ }
88
+ )
89
+ return all_tools
90
+
91
+ async def execute_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
92
+ """
93
+ Execute a tool with the given parameters on the appropriate server.
94
+ """
95
+ # Find which server has this tool
96
+ for server_name, session in self.sessions.items():
97
+ response = await session.list_tools()
98
+ for tool in response.tools:
99
+ if tool.name == tool_name:
100
+ # Execute the tool on the correct server
101
+ result = await session.call_tool(tool_name, params)
102
+ return {
103
+ "result": result.content,
104
+ "server": server_name
105
+ }
106
+ raise Exception(f"Tool {tool_name} not found on any connected server")
107
+
108
+ async def close(self):
109
+ """Close all server connections."""
110
+ await self.exit_stack.aclose()
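A minimal standalone usage sketch (not part of this commit) for `MCPClient`, assuming the `mcp_server_config.json` added below and servers that are actually reachable:

```python
import asyncio

from mcp_client import MCPClient


async def main():
    # Connect to every server declared in mcp_server_config.json, list the tools
    # in OpenAI function-calling format, then close the connections.
    client = MCPClient("mcp_server_config.json", user_id=None)
    try:
        await client.connect_to_servers()
        tools = await client.get_tools()
        print([tool["function"]["name"] for tool in tools])
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())
```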
mcp_server_config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "mcpServers": {
3
+ "fashion_vlm": {
4
+ "transport": "sse",
5
+ "url": "http://localhost:8000"
6
+ },
7
+ "virtual_try_on": {
8
+ "command": "python",
9
+ "args": ["mcp_servers/virtual_try_on.py"],
10
+ "env": {
11
+ "HTTP_PROXY": "http://127.0.0.1:10809",
12
+ "HTTPS_PROXY": "http://127.0.0.1:10809"
13
+ }
14
+ }
15
+ }
16
+ }
mcp_servers/fashion_vlm/__init__.py ADDED
File without changes
mcp_servers/fashion_vlm/fashion_vlm_infer.py ADDED
@@ -0,0 +1,157 @@
1
+ from typing import List
2
+ import os
3
+ import datetime
4
+ from omegaconf import OmegaConf
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import torch
8
+ from torch import tensor
9
+ from torchvision import transforms
10
+ from PIL import Image
11
+ from models import Showo, MAGVITv2, CLIPVisionTower, get_mask_chedule
12
+ from prompting_utils import (UniversalPrompting,
13
+ create_attention_mask_for_mmu,
14
+ create_attention_mask_predict_next)
15
+ from transformers import AutoTokenizer, CLIPImageProcessor
16
+
17
+
18
+ def image_transform(image, resolution=256, normalize=True):
19
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
20
+ image = transforms.CenterCrop((resolution, resolution))(image)
21
+ image = transforms.ToTensor()(image)
22
+ if normalize:
23
+ image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
24
+ return image
25
+
26
+
27
+ class FashionVLM:
28
+ def __init__(self, temperature, top_k, max_new_tokens, fashion_vlm_name,
29
+ batch_size=3, guidance_scale=5, generation_temperature=1.0, generation_timesteps=50,
30
+ save_dir="generated_images"):
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ self.temperature = temperature # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
33
+ self.top_k = top_k # retain only the top_k most likely tokens, clamp others to have 0 probability
34
+ self.max_new_tokens = max_new_tokens
35
+ self.fashion_vlm_name = fashion_vlm_name
36
+
37
+ # Parameters for text-to-image generation
38
+ self.save_dir = save_dir
39
+ self.batch_size = batch_size
40
+ self.guidance_scale = guidance_scale
41
+ self.generation_temperature = generation_temperature
42
+ self.generation_timesteps = generation_timesteps
43
+
44
+ self._init_models()
45
+
46
+ def _init_models(self):
47
+ # Initialize UniversalPrompting
48
+ self.tokenizer = AutoTokenizer.from_pretrained(
49
+ "microsoft/phi-1_5",
50
+ padding_side="left"
51
+ )
52
+ self.uni_prompting = UniversalPrompting(
53
+ self.tokenizer,
54
+ max_text_len=381,
55
+ special_tokens=(
56
+ "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"
57
+ ),
58
+ ignore_id=-100, cond_dropout_prob=0.1
59
+ )
60
+
61
+ # Initialize the VQ model
62
+ self.vq_model = MAGVITv2.from_pretrained("showlab/magvitv2").to(self.device)
63
+ self.vq_model.requires_grad_(False)
64
+ self.vq_model.eval()
65
+
66
+ self.model = Showo.from_pretrained(self.fashion_vlm_name).to(self.device)
67
+ self.model.eval()
68
+
69
+ def mmu_infer_tensor(self, image: tensor, prompt: tensor):
70
+ """
71
+ Image size: batch * 3 * 256 * 256
72
+ """
73
+ image = image.to(self.device)
74
+ prompt = prompt.to(self.device)
75
+ # pixel_values = self.clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0]
76
+ image_tokens = self.vq_model.get_code(image) + len(self.uni_prompting.text_tokenizer)
77
+
78
+ input_ids = torch.cat([
79
+ (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|mmu|>']).to(self.device),
80
+ (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|soi|>']).to(self.device),
81
+ image_tokens,
82
+ (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|eoi|>']).to(self.device),
83
+ (torch.ones(prompt.shape[0], 1) * self.uni_prompting.sptids_dict['<|sot|>']).to(self.device),
84
+ prompt
85
+ ], dim=1).long()
86
+
87
+ attention_mask = create_attention_mask_for_mmu(
88
+ input_ids.to(self.device),
89
+ eoi_id=int(self.uni_prompting.sptids_dict['<|eoi|>'])
90
+ )
91
+
92
+ cont_toks_list = self.model.mmu_generate(
93
+ input_ids, attention_mask=attention_mask,
94
+ max_new_tokens=self.max_new_tokens, top_k=self.top_k,
95
+ eot_token=self.uni_prompting.sptids_dict['<|eot|>']
96
+ )
97
+
98
+ cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
99
+
100
+ text = self.uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
101
+ return text
102
+
103
+ def t2i_infer(self, prompts: List[str]):
104
+ output_path = []
105
+ for step in tqdm(range(0, len(prompts), self.batch_size)):
106
+ batch_prompt = prompts[step:step + self.batch_size]
107
+ image_tokens = torch.ones((len(batch_prompt), self.model.config.num_vq_tokens),
108
+ dtype=torch.long, device=self.device) * self.model.config.mask_token_id
109
+ input_ids, _ = self.uni_prompting((batch_prompt, image_tokens), 't2i_gen')
110
+ if self.guidance_scale > 0:
111
+ uncond_input_ids, _ = self.uni_prompting(([''] * len(batch_prompt), image_tokens), 't2i_gen')
112
+ attention_mask = create_attention_mask_predict_next(
113
+ torch.cat([input_ids, uncond_input_ids], dim=0),
114
+ pad_id=int(self.uni_prompting.sptids_dict['<|pad|>']),
115
+ soi_id=int(self.uni_prompting.sptids_dict['<|soi|>']),
116
+ eoi_id=int(self.uni_prompting.sptids_dict['<|eoi|>']),
117
+ rm_pad_in_image=True
118
+ )
119
+ else:
120
+ uncond_input_ids = None
121
+ attention_mask = create_attention_mask_predict_next(
122
+ input_ids,
123
+ pad_id=int(self.uni_prompting.sptids_dict['<|pad|>']),
124
+ soi_id=int(self.uni_prompting.sptids_dict['<|soi|>']),
125
+ eoi_id=int(self.uni_prompting.sptids_dict['<|eoi|>']),
126
+ rm_pad_in_image=True
127
+ )
128
+
129
+ mask_schedule = get_mask_chedule("cosine")
130
+ with torch.no_grad():
131
+ gen_token_ids = self.model.t2i_generate(
132
+ input_ids=input_ids,
133
+ uncond_input_ids=uncond_input_ids,
134
+ attention_mask=attention_mask,
135
+ guidance_scale=self.guidance_scale,
136
+ temperature=self.generation_temperature,
137
+ timesteps=self.generation_timesteps,
138
+ noise_schedule=mask_schedule,
139
+ )
140
+
141
+ gen_token_ids = torch.clamp(gen_token_ids, max=self.model.config.codebook_size - 1, min=0)
142
+ images = self.vq_model.decode_code(gen_token_ids)
143
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
144
+ images *= 255.0
145
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
146
+
147
+ # Save the generated images
148
+ for idx, image in enumerate(images, start=1):
149
+ image = Image.fromarray(image)
150
+ # Build a unique filename from a timestamp and the index
151
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
152
+ image_filename = f"{timestamp}_{step + idx}.jpg"
153
+ image_path = os.path.join(self.save_dir, image_filename)
154
+ image.save(image_path)
155
+ output_path.append(image_path)
156
+
157
+ return output_path
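A hedged single-image inference sketch for `FashionVLM` (not part of this commit), mirroring how `get_recommendation` in `main.py` below builds the prompt; the image path is a placeholder and the checkpoints must be available locally or from the Hub:

```python
import torch
from PIL import Image
from transformers import AutoTokenizer

from fashion_vlm_infer import FashionVLM, image_transform

# Placeholder image path; model names follow the ones used in main.py below.
vlm = FashionVLM(temperature=0.8, top_k=1, max_new_tokens=1000,
                 fashion_vlm_name="Anony100/FashionVLM")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", padding_side="left")

image = image_transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
prompt_ids = tokenizer(["USER: \nRecommend a matching outfit. ASSISTANT:"])["input_ids"][0]
prompt = torch.tensor(prompt_ids).unsqueeze(0)

print(vlm.mmu_infer_tensor(image, prompt)[0])  # decoded recommendation text
```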
mcp_servers/fashion_vlm/main.py ADDED
@@ -0,0 +1,592 @@
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import numpy as np
4
+ from itertools import combinations
5
+ from typing import Dict, Any, List, Optional
6
+ import pandas as pd
7
+
8
+ from scipy import sparse
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+ import pickle
11
+ import uvicorn
12
+ import torch
13
+ from PIL import Image
14
+ from torchvision import transforms
15
+ from transformers import AutoTokenizer, CLIPProcessor, CLIPModel
16
+ from mcp.server.fastmcp import FastMCP
17
+ from mcp.server.sse import SseServerTransport
18
+ from starlette.routing import Route, Mount
19
+ from starlette.applications import Starlette
20
+ from openai import AsyncOpenAI
21
+
22
+ from fashion_vlm_infer import FashionVLM
23
+ from prompting_utils import UniversalPrompting
24
+
25
+
26
+ # Load environment variables
27
+ load_dotenv()
28
+ FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT", "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec")
29
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
30
+ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
31
+ openai = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
32
+ VALID_CATEGORIES = [
33
+ 'Pants', 'Coats', 'Cross-body bags', 'Shirts', 'Hats & caps', 'Sneakers', 'Jeans', 'Boots', 'Dresses', 'Sandals',
34
+ 'T-shirts & vests', 'Knitwear', 'Skirts', 'Earrings', 'Hats', 'Sweaters & knitwear', 'Loafers', 'Ballet flats',
35
+ 'Espadrilles', 'Tote bags', 'Shoulder bags', 'Slides & flip flops', 'Pumps', 'Necklaces', 'Polo shirts', 'Suits',
36
+ 'Oxford shoes', 'Bracelets', 'Jackets', 'Tops', 'Rings', 'Mules', 'Luggage & holdalls', 'Brogues', 'Activewear',
37
+ 'Belts', 'Derby shoes', 'Mini bags', 'Watches', 'Backpacks', 'Denim', 'Laptop bags & briefcases', 'Clutch bags',
38
+ 'Clutches', 'Lingerie & Nightwear', 'Skiwear', 'Sunglasses', 'Ties & bow ties', 'Shorts', 'Scarves', 'Messenger bags'
39
+ ]
40
+
41
+
42
+ ###################################
43
+ #########Loading Data##############
44
+ ###################################
45
+ # Load item metadata
46
+ items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet").set_index("item_id")
47
+ outfits_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/outfits_lite.parquet").set_index("outfit_id")
48
+ users_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/users_lite.parquet").set_index("user_id")
49
+ image_paths = items_df["path"].to_dict()
50
+
51
+ ###################################
52
+ #########Loading Model#############
53
+ ###################################
54
+ # Load CLIP model and processor
55
+ print("Loading CLIP Model")
56
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
57
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
58
+ clip_model.eval()
59
+
60
+ print("Loading Fashion VLM params")
61
+
62
+ fashion_vlm = FashionVLM(
63
+ max_new_tokens=1000,
64
+ temperature=0.8,
65
+ top_k=1,
66
+ fashion_vlm_name='Anony100/FashionVLM',
67
+ save_dir="/mnt/d/PostDoc/fifth paper/code/FashionM3/generated_images"
68
+ )
69
+
70
+ resolution = 512
71
+ image_transform = transforms.Compose([
72
+ transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC),
73
+ transforms.CenterCrop((resolution, resolution)),
74
+ transforms.ToTensor(),
75
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
76
+ ])
77
+
78
+ print("Loading tokenizer")
79
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-1_5', padding_side="left")
80
+ uni_prompting = UniversalPrompting(
81
+ tokenizer,
82
+ max_text_len=128,
83
+ special_tokens=(
84
+ "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"
85
+ ),
86
+ ignore_id=-100, cond_dropout_prob=0.1
87
+ )
88
+
89
+
90
+ class InteractionDataManager:
91
+ def __init__(self, users_df, outfits_df, items_df):
92
+ """
93
+ Initialize the manager: load the dataframes and build lookup structures.
94
+
95
+ Args:
96
+ - users_df: user metadata DataFrame, indexed by user_id
97
+ - outfits_df: outfit metadata DataFrame, indexed by outfit_id
98
+ - items_df: item metadata DataFrame, indexed by item_id
99
+ """
100
+ self.users_df = users_df
101
+ self.outfits_df = outfits_df
102
+ self.items_df = items_df
103
+
104
+ # Build id <-> index mappings
105
+ self.item_id_to_index = {item_id: index for index, item_id in enumerate(self.items_df.index)}
106
+ self.index_to_item_id = {index: item_id for index, item_id in enumerate(self.items_df.index)}
107
+ self.user_id_to_index = {user_id: index for index, user_id in enumerate(self.users_df.index)}
108
+ self.index_to_user_id = {index: user_id for index, user_id in enumerate(self.users_df.index)}
109
+ self.outfit_ids_dict = self.outfits_df['item_ids'].to_dict() # get outfit's item ids from outfit id
110
+ self.item_category_dict = self.items_df['category'].to_dict() # get item's category from item id
111
+ self.item_subcategory_dict = self.items_df['subcategory'].to_dict() # get item's subcategory from item id
112
+ self.n_items = len(self.items_df)
113
+ self.n_users = len(self.users_df)
114
+
115
+ self.user_outfit_pairs = []
116
+ outfit_set = set(self.outfits_df.index)
117
+ for uid, user in self.users_df.iterrows():
118
+ oids = user.outfit_ids.split(",")
119
+ self.user_outfit_pairs.extend([(uid, oid) for oid in oids if oid in outfit_set])
120
+
121
+ # Precompute subcategory -> item-id sets (via groupby)
122
+ self.subcategory_to_items = self.items_df.groupby('subcategory').apply(lambda x: set(x.index)).to_dict()
123
+
124
+ # Precompute subcategory -> item-index sets for fast lookup
125
+ self.subcategory_to_indices = {}
126
+ for subcategory, item_ids in self.subcategory_to_items.items():
127
+ self.subcategory_to_indices[subcategory] = set([self.item_id_to_index[item_id]
128
+ for item_id in item_ids
129
+ if item_id in self.item_id_to_index])
130
+
131
+ item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/item_matrix.npz'
132
+ try:
133
+ self.load_matrix('item', item_interaction_matrix_path)
134
+ except FileNotFoundError:
135
+ self.build_item_interaction_matrix()
136
+ self.save_matrix('item', item_interaction_matrix_path)
137
+
138
+ user_item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/user_item_matrix.npz'
139
+ try:
140
+ self.load_matrix('user_item', user_item_interaction_matrix_path)
141
+ except FileNotFoundError:
142
+ self.build_user_item_interaction_matrix()
143
+ self.save_matrix('user_item', user_item_interaction_matrix_path)
144
+
145
+ # Load per-item CLIP features
146
+ with open(f"{FASHION_DATA_ROOT}/meta/clip_features.pkl", "rb") as f:
147
+ print("Loading Fashion Features...")
148
+ self.clip_features = pickle.load(f)
149
+ print("Loading Fashion Features Successfully")
150
+
151
+ # Prepare embeddings and item IDs
152
+ self.item_ids = list(self.clip_features.keys())
153
+ self.image_embeddings = np.array([self.clip_features[item_id]["image_embeds"] for item_id in self.item_ids])
154
+
155
+ def save_matrix(self, matrix_type, filepath):
156
+ """
157
+ Save a matrix to disk.
158
+
159
+ Args:
160
+ - matrix_type: 'item' or 'user_item', which matrix to save
161
+ - filepath: destination path (e.g. 'temp/item_matrix.npz')
162
+ """
163
+ if matrix_type == 'item':
164
+ matrix = self.item_interaction_matrix
165
+ elif matrix_type == 'user_item':
166
+ matrix = self.user_item_interaction_matrix
167
+ else:
168
+ raise ValueError("matrix_type must be 'item' or 'user_item'")
169
+
170
+ if matrix is None:
171
+ raise ValueError(f"{matrix_type} matrix has not been built yet.")
172
+
173
+ sparse.save_npz(filepath, matrix)
174
+ print(f"Saved {matrix_type} matrix to {filepath}")
175
+
176
+ def load_matrix(self, matrix_type, filepath):
177
+ """
178
+ Load a matrix from disk.
179
+
180
+ Args:
181
+ - matrix_type: 'item' or 'user_item', which matrix to load
182
+ - filepath: source path (e.g. 'temp/item_matrix.npz')
183
+ """
184
+ if not os.path.exists(filepath):
185
+ raise FileNotFoundError(f"File {filepath} does not exist.")
186
+
187
+ matrix = sparse.load_npz(filepath)
188
+ if matrix_type == 'item':
189
+ self.item_interaction_matrix = matrix
190
+ elif matrix_type == 'user_item':
191
+ self.user_item_interaction_matrix = matrix
192
+ else:
193
+ raise ValueError("matrix_type must be 'item' or 'user_item'")
194
+
195
+ print(f"Loaded {matrix_type} matrix from {filepath}")
196
+ return matrix
197
+
198
+ def build_item_interaction_matrix(self):
199
+ """构建 Item-Item 交互矩阵"""
200
+ # 初始化单品交互矩阵
201
+ self.item_interaction_matrix = sparse.lil_matrix((self.n_items, self.n_items), dtype=int)
202
+
203
+ for index, outfit in self.outfits_df.iterrows():
204
+ item_ids = outfit['item_ids'].split(',')
205
+ # Count co-occurrences of item pairs
206
+ for item_id1, item_id2 in combinations(item_ids, r=2):
207
+ if item_id1 in self.item_id_to_index and item_id2 in self.item_id_to_index:
208
+ idx1 = self.item_id_to_index[item_id1]
209
+ idx2 = self.item_id_to_index[item_id2]
210
+ self.item_interaction_matrix[idx1, idx2] += 1
211
+ self.item_interaction_matrix[idx2, idx1] += 1 # symmetric for unordered pairs
212
+
213
+ # Convert to CSR format
214
+ self.item_interaction_matrix = self.item_interaction_matrix.tocsr()
215
+ return self.item_interaction_matrix
216
+
217
+ def build_user_item_interaction_matrix(self):
218
+ """构建 User-Item 交互矩阵"""
219
+ # 初始化用户-单品交互矩阵
220
+ self.user_item_interaction_matrix = sparse.lil_matrix((self.n_users, self.n_items), dtype=int)
221
+
222
+ for uid, user in self.users_df.iterrows():
223
+ oids = user["outfit_ids"].split(",")
224
+ outfits = self.outfits_df.loc[self.outfits_df.index.isin(oids)]
225
+ for oid, outfit in outfits.iterrows():
226
+ item_ids = outfit['item_ids'].split(',')
227
+ # Count user-item interactions
228
+ for iid in item_ids:
229
+ if iid in self.item_id_to_index:
230
+ uidx = self.user_id_to_index[uid]
231
+ iidx = self.item_id_to_index[iid]
232
+ self.user_item_interaction_matrix[uidx, iidx] += 1
233
+
234
+ # Convert to CSR format
235
+ self.user_item_interaction_matrix = self.user_item_interaction_matrix.tocsr()
236
+ return self.user_item_interaction_matrix
237
+
238
+ def _process_interactions_for_category(
239
+ self,
240
+ matrix,
241
+ given_id,
242
+ category_indices,
243
+ id_to_index
244
+ ):
245
+ """
246
+ Collect the interactions of a single entity (user or item) with the target category.
247
+
248
+ Args:
249
+ - matrix: interaction matrix
250
+ - given_id: the given entity id (user or item)
251
+ - category_indices: set of item indices belonging to the target category
252
+
253
+ Returns:
254
+ - a list of interactions, each a dict with item_id, interaction_count, and score
255
+ """
256
+ interactions = []
257
+
258
+ given_index = id_to_index[given_id]
259
+ row = matrix[given_index]
260
+
261
+ # Extract the non-zero entries of this row
262
+ row_start = row.indptr[0]
263
+ row_end = row.indptr[1]
264
+ col_indices = row.indices[row_start:row_end]
265
+ data_values = row.data[row_start:row_end]
266
+
267
+ # Keep only items that belong to the target category
268
+ for col_idx, value in zip(col_indices, data_values):
269
+ # Check whether this column corresponds to an item in the target category
270
+ if col_idx in category_indices:
271
+ # Resolve the item id
272
+ output_id = self.index_to_item_id[col_idx]
273
+ interactions.append({
274
+ 'item_id': output_id,
275
+ 'interaction_count': int(value),
276
+ 'score': 0.0
277
+ })
278
+
279
+ return interactions
280
+
281
+ def get_item_category_interactions(
282
+ self,
283
+ target_category: str,
284
+ given_ids: List[str],
285
+ query_type='item', # item or user
286
+ top_k=None,
287
+ ):
288
+ """
289
+ Get all interactions of the given entities (users or items) with the target category.
290
+
291
+ Args:
292
+ - target_category: the subcategory to query
293
+ - given_ids: list of entity ids (user ids or item ids)
294
+ - query_type: type of the given entities, 'item' or 'user'
295
+ - top_k: return only the k items with the most interactions; None returns all
296
+
297
+ Returns:
298
+ - a list of interaction statistics for the target category, sorted by interaction count
299
+ """
300
+ if query_type == 'item':
301
+ matrix = self.item_interaction_matrix
302
+ id_to_index = self.item_id_to_index
303
+ elif query_type == 'user':
304
+ matrix = self.user_item_interaction_matrix
305
+ id_to_index = self.user_id_to_index
306
+ else:
307
+ print(f'query_type must be either item or user but got {query_type}')
308
+ return []
309
+
310
+ # Collect all interaction records
311
+ all_interactions = []
312
+ category = target_category
313
+ category_indices = self.subcategory_to_indices.get(category, set()) # all item indices in this subcategory
314
+
315
+ # Gather all interactions for each given entity
316
+ for given_id in given_ids:
317
+ interactions = self._process_interactions_for_category(
318
+ matrix, given_id, category_indices, id_to_index
319
+ )
320
+ # Append the interactions to the result list
321
+ all_interactions.extend(interactions)
322
+
323
+ # Merge interaction counts for the same item
324
+ item_interactions = {}
325
+ for interaction in all_interactions:
326
+ item_id = interaction['item_id']
327
+ count = interaction['interaction_count']
328
+
329
+ if item_id in item_interactions:
330
+ item_interactions[item_id] += count
331
+ else:
332
+ item_interactions[item_id] = count
333
+
334
+ # Convert to the output format
335
+ merged_interactions = [
336
+ {'item_id': item_id, 'interaction_count': count, 'score': 0.0}
337
+ for item_id, count in item_interactions.items()
338
+ ]
339
+
340
+ # Sort by interaction count
341
+ if merged_interactions:
342
+ merged_interactions.sort(key=lambda x: x['interaction_count'], reverse=True)
343
+
344
+ # Truncate to top-k
345
+ if top_k and merged_interactions:
346
+ merged_interactions = merged_interactions[:top_k]
347
+
348
+ # Return the merged results
349
+ return merged_interactions
350
+
351
+ def rank_by_similarity(self, item_interactions, user_interactions, beta=2.0):
352
+ """
353
+ Score each of the user's interacted items by similarity to the given item interactions and sort them in descending order.
354
+ """
355
+ def get_combined_features(feature_dict):
356
+ return (feature_dict['image_embeds'] + feature_dict['text_embeds']) / 2
357
+
358
+ if not item_interactions:
359
+ return user_interactions
360
+
361
+ item_feature_list = []
362
+ for item in item_interactions:
363
+ item_id = item['item_id']
364
+ if item_id not in self.clip_features:
365
+ raise ValueError(f"Didn't find clip feature of item with id: {item_id}")
366
+
367
+ item_features = get_combined_features(self.clip_features[item_id])
368
+ item_feature_list.append(item_features)
369
+
370
+ weights = np.array([x['interaction_count'] for x in item_interactions], dtype=np.float32)
371
+ weights = weights / np.sum(weights)
372
+ item_feature = np.sum(np.stack(item_feature_list, axis=0) * weights[:, np.newaxis], axis=0).reshape(1, -1)
373
+
374
+ max_count = max((user_item.get('interaction_count', 1) for user_item in user_interactions), default=1)
375
+ for user_item in user_interactions:
376
+ user_item_id = user_item['item_id']
377
+ if user_item_id not in self.clip_features:
378
+ raise ValueError(f"Didn't find clip feature of item with id: {user_item_id}")
379
+
380
+ user_item_features = get_combined_features(self.clip_features[user_item_id]).reshape(1, -1)
381
+ similarity = cosine_similarity(user_item_features, item_feature).item()
382
+ interaction_count = user_item['interaction_count']
383
+ count_factor = (interaction_count / max_count) * beta + 1
384
+ user_item['score'] = float(similarity) * count_factor
385
+
386
+ user_interactions.sort(key=lambda x: x.get('score', 0), reverse=True)
387
+ return user_interactions
388
+
389
+
390
+ data_manager = InteractionDataManager(users_df, outfits_df, items_df)
391
+ mcp = FastMCP('fashion-vlm-server')
392
+
393
+
394
+ async def compute_text_embedding(text: str) -> np.ndarray:
395
+ inputs = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True)
396
+ with torch.no_grad():
397
+ text_embedding = clip_model.get_text_features(**inputs).numpy()
398
+ return text_embedding / np.linalg.norm(text_embedding, axis=1, keepdims=True)
399
+
400
+
401
+ async def find_most_similar_image(text_embedding: np.ndarray) -> Dict[str, Any]:
402
+ similarities = np.dot(data_manager.image_embeddings, text_embedding.T).flatten()
403
+ most_similar_idx = np.argmax(similarities)
404
+ most_similar_item_id = data_manager.item_ids[most_similar_idx]
405
+ return {
406
+ "image_path": image_paths[most_similar_item_id],
407
+ "similarity": float(similarities[most_similar_idx])
408
+ }
409
+
410
+
411
+ @mcp.tool()
412
+ async def retrieve_image(text: str) -> Dict[str, Any]:
413
+ """Search for the most similar fashion image based on a text description.
414
+
415
+ Args:
416
+ text (str): Text description of the fashion item to search.
417
+ """
418
+ print(f"Searching for {text}")
419
+ text_embedding = await compute_text_embedding(text)
420
+ return await find_most_similar_image(text_embedding)
421
+
422
+
423
+ def get_recommendation(query, image_path):
424
+ image = Image.open(image_path).convert("RGB")
425
+ image = image_transform(image).unsqueeze(0)
426
+ prompt = uni_prompting.text_tokenizer(['USER: \n' + query + ' ASSISTANT:'])['input_ids'][0]
427
+ prompt = torch.tensor(prompt).unsqueeze(0)
428
+ results = fashion_vlm.mmu_infer_tensor(image, prompt)
429
+ response = results[0]
430
+ return response
431
+
432
+
433
+ @mcp.tool()
434
+ async def fashion_recommend(query: str, image_path: str, target_category: str, user_id: Optional[str], list_of_items: List[str]) -> Dict[str, str]:
435
+ """Generate fashion recommendations based on a user's query and uploaded image.
436
+
437
+ This function processes the recommendation in the following steps:
438
+ 1. Retrieves the user's interaction history for the specified target category using user_id, target_category, and list_of_items.
439
+ 2. Summarizes the user's preferences for the target category by analyzing descriptions of previously interacted fashion items via a language model.
440
+ 3. Appends the summarized preference (as a single sentence) to the query and processes it with the uploaded image using a Fashion Vision-Language Model (VLM).
441
+ 4. Returns the personalized recommendation along with the derived user preference.
442
+
443
+ The target_category is inferred from the query (e.g., "I want a skirt ..." implies "Skirts") and must belong to a predefined list of valid categories.
444
+
445
+ Args:
446
+ query (str): A complete sentence explicitly stating the user's desired fashion item (e.g., "I want a skirt for summer"). Must be in English.
447
+ image_path (str): File path to the user-uploaded image, provided via the prompt.
448
+ target_category (str): The specific fashion category of interest, derived from the query (e.g., "Skirts"). Must be in valid categories.
449
+ user_id (str): Unique identifier for the user, provided via the prompt.
450
+ list_of_items (List[str]): List of item IDs used to filter the user's interaction history, provided via the prompt.
451
+
452
+ Returns:
453
+ Dict[str, str]: A dictionary containing:
454
+ - "recommendation": The personalized fashion recommendation text.
455
+ - "user_preference": The summarized user preference sentence.
456
+
457
+ Valid Categories:
458
+ ['Pants', 'Coats', 'Cross-body bags', 'Shirts', 'Hats & caps', 'Sneakers', 'Jeans', 'Boots', 'Dresses', 'Sandals',
459
+ 'T-shirts & vests', 'Knitwear', 'Skirts', 'Earrings', 'Hats', 'Sweaters & knitwear', 'Loafers', 'Ballet flats',
460
+ 'Espadrilles', 'Tote bags', 'Shoulder bags', 'Slides & flip flops', 'Pumps', 'Necklaces', 'Polo shirts', 'Suits',
461
+ 'Oxford shoes', 'Bracelets', 'Jackets', 'Tops', 'Rings', 'Mules', 'Luggage & holdalls', 'Brogues', 'Activewear',
462
+ 'Belts', 'Derby shoes', 'Mini bags', 'Watches', 'Backpacks', 'Denim', 'Laptop bags & briefcases', 'Clutch bags',
463
+ 'Clutches', 'Lingerie & Nightwear', 'Skiwear', 'Sunglasses', 'Ties & bow ties', 'Shorts', 'Scarves', 'Messenger bags']
464
+ """
465
+ def get_item(item_id: str) -> pd.Series:
466
+ return data_manager.items_df.loc[item_id]
467
+
468
+ # If no image uploaded, we should use fashion_recommend_without_image
469
+ if image_path == "":
470
+ recommendation = await fashion_recommend_without_image(query)
471
+ return {
472
+ "recommendation": recommendation,
473
+ "user_preference": ""
474
+ }
475
+
476
+ # If no user_id provided or user_id not found in database
477
+ if not user_id or user_id not in data_manager.user_id_to_index.keys():
478
+ return {
479
+ "recommendation": get_recommendation(query, image_path),
480
+ "user_preference": ""
481
+ }
482
+
483
+ user_preference = ""
484
+ if target_category in VALID_CATEGORIES:
485
+ user_interaction_result = data_manager.get_item_category_interactions(
486
+ target_category, [user_id], query_type='user'
487
+ )
488
+
489
+ if len(list_of_items) != 0:
490
+ item_interaction_result = data_manager.get_item_category_interactions(
491
+ target_category, list_of_items, query_type='item'
492
+ )
493
+ else:
494
+ item_interaction_result = []
495
+
496
+ descriptions_for_summary = []
497
+ historical_image_path = []
498
+
499
+ if len(user_interaction_result) > 0:
500
+ user_interaction_result = data_manager.rank_by_similarity(
501
+ item_interaction_result,
502
+ user_interaction_result
503
+ )
504
+ for x in user_interaction_result[:5]:
505
+ item = get_item(x['item_id'])
506
+ descriptions_for_summary.append(item['gen_description'])
507
+ historical_image_path.append(os.path.abspath(item['path']))
508
+
509
+ if descriptions_for_summary:
510
+ user_message = f"Summary user's preference of {target_category} based on following descriptions of fashion items that user brought previously:"
511
+ for x in descriptions_for_summary:
512
+ user_message += f"\n{x}"
513
+ # Get summary using OpenAI API call
514
+ response = await openai.chat.completions.create(
515
+ model="gpt-4o-mini",
516
+ messages=[
517
+ {"role": "system", "content": f"You are a user preference summary assistant. Your response is limited in one sentence, staring at 'I prefer ...'"},
518
+ {"role": "user", "content": user_message}
519
+ ],
520
+ max_tokens=1000,
521
+ )
522
+ user_preference = response.choices[0].message.content
523
+ query += user_preference
524
+
525
+ return {
526
+ "recommendation": get_recommendation(query, image_path),
527
+ "user_preference": user_preference
528
+ }
529
+
530
+
531
+ @mcp.tool()
532
+ async def fashion_recommend_without_image(query: str) -> str:
533
+ """Recommend fashion items sorely based on user's query.
534
+ Output texts of fashion recommendation from model.
535
+
536
+ Args:
537
+ query (str): User's fashion related query including their recommendation request.
538
+ """
539
+ response = await openai.chat.completions.create(
540
+ model="gpt-4o-mini",
541
+ messages=[
542
+ {"role": "system", "content": "You are a fashion stylist. You should answer user's fashion-related question, especially about fashion recommendation."},
543
+ {"role": "user", "content": query}
544
+ ],
545
+ max_tokens=500,
546
+ )
547
+ return response.choices[0].message.content
548
+
549
+
550
+ @mcp.tool()
551
+ async def image_generate(text: str) -> str:
552
+ """"Generate image based on description. Output is path that saves generated image.
553
+
554
+ Args:
555
+ text (str): Descriptive text from user. Used for fashion image generation. English ONLY!
556
+ """
557
+ output_path = fashion_vlm.t2i_infer([text])[0]
558
+ output_path = os.path.abspath(output_path)
559
+ print(f"Generated image saved at {output_path}")
560
+ return output_path
561
+
562
+
563
+ # Get the internal MCP Server object
564
+ mcp_server = mcp._mcp_server
565
+ sse_transport = SseServerTransport("/messages/")
566
+
567
+
568
+ # Handle SSE connections
569
+ async def handle_sse(request):
570
+ print("Handling SSE connection")
571
+ async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
572
+ read_stream, write_stream = streams
573
+ await mcp_server.run(
574
+ read_stream,
575
+ write_stream,
576
+ mcp_server.create_initialization_options(),
577
+ )
578
+
579
+ # Define routes
580
+ routes = [
581
+ Route("/sse", endpoint=handle_sse),
582
+ Mount("/messages/", app=sse_transport.handle_post_message),
583
+ ]
584
+
585
+ # Create the Starlette application
586
+ starlette_app = Starlette(routes=routes)
587
+
588
+
589
+ if __name__ == "__main__":
590
+ print("Starting Fashion VLM server with HTTP and SSE...")
591
+ uvicorn.run(starlette_app, host="0.0.0.0", port=8000)
592
+
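For illustration, a hypothetical set of arguments the chat agent might pass to the `fashion_recommend` tool above (all ids and paths invented); in `chainlit_app.py` this dict comes from `json.loads(tool_call.function.arguments)` and is forwarded through `MCPClient.execute_tool`:

```python
# Hypothetical tool-call arguments for fashion_recommend (all values invented).
args = {
    "query": "I want a skirt for summer",
    "image_path": "/abs/path/to/uploaded.jpg",  # "" falls back to fashion_recommend_without_image
    "target_category": "Skirts",                # must be one of VALID_CATEGORIES
    "user_id": "u_000123",
    "list_of_items": ["item_0001", "item_0002"],
}
# e.g. result = await client.execute_tool("fashion_recommend", args)
```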
mcp_servers/fashion_vlm/models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .modeling_showo import Showo
2
+ from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
3
+ from .sampling import *
4
+ from .clip_encoder import CLIPVisionTower
mcp_servers/fashion_vlm/models/clip_encoder.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+ class CLIPVisionTower(nn.Module):
7
+ def __init__(self, vision_tower):
8
+ super().__init__()
9
+
10
+ self.is_loaded = False
11
+
12
+ self.vision_tower_name = vision_tower
13
+ self.select_layer = -2
14
+ self.select_feature = "patch"
15
+ self.load_model()
16
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
17
+
18
+ def load_model(self, device_map=None):
19
+ if self.is_loaded:
20
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
21
+ return
22
+
23
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
24
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
25
+ self.vision_tower.requires_grad_(False)
26
+
27
+ self.is_loaded = True
28
+
29
+ def feature_select(self, image_forward_outs):
30
+ image_features = image_forward_outs.hidden_states[self.select_layer]
31
+ if self.select_feature == 'patch':
32
+ image_features = image_features[:, 1:]
33
+ elif self.select_feature == 'cls_patch':
34
+ image_features = image_features
35
+ else:
36
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
37
+ return image_features
38
+
39
+ @torch.no_grad()
40
+ def forward(self, images):
41
+ if type(images) is list:
42
+ image_features = []
43
+ for image in images:
44
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
45
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
46
+ image_features.append(image_feature)
47
+ else:
48
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
49
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
50
+
51
+ return image_features
52
+
53
+ @property
54
+ def dummy_feature(self):
55
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56
+
57
+ @property
58
+ def dtype(self):
59
+ return self.vision_tower.dtype
60
+
61
+ @property
62
+ def device(self):
63
+ return self.vision_tower.device
64
+
65
+ @property
66
+ def config(self):
67
+ if self.is_loaded:
68
+ return self.vision_tower.config
69
+ else:
70
+ return self.cfg_only
71
+
72
+ @property
73
+ def hidden_size(self):
74
+ return self.config.hidden_size
75
+
76
+ @property
77
+ def num_patches_per_side(self):
78
+ return self.config.image_size // self.config.patch_size
79
+
80
+ @property
81
+ def num_patches(self):
82
+ return (self.config.image_size // self.config.patch_size) ** 2
83
+
84
+
85
+ class CLIPVisionTowerS2(CLIPVisionTower):
86
+ def __init__(self, vision_tower, args, delay_load=False):
87
+ super().__init__(vision_tower, args, delay_load)
88
+
89
+ self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
90
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
91
+ self.s2_scales.sort()
92
+ self.s2_split_size = self.s2_scales[0]
93
+ self.s2_image_size = self.s2_scales[-1]
94
+
95
+ try:
96
+ from s2wrapper import forward as multiscale_forward
97
+ except ImportError:
98
+ raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
99
+ self.multiscale_forward = multiscale_forward
100
+
101
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
102
+ if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
103
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
104
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
105
+
106
+ def load_model(self, device_map=None):
107
+ if self.is_loaded:
108
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
109
+ return
110
+
111
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
112
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
113
+ self.vision_tower.requires_grad_(False)
114
+
115
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
116
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
117
+
118
+ self.is_loaded = True
119
+
120
+ @torch.no_grad()
121
+ def forward_feature(self, images):
122
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
123
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
124
+ return image_features
125
+
126
+ @torch.no_grad()
127
+ def forward(self, images):
128
+ if type(images) is list:
129
+ image_features = []
130
+ for image in images:
131
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
132
+ image_features.append(image_feature)
133
+ else:
134
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
135
+
136
+ return image_features
137
+
138
+ @property
139
+ def hidden_size(self):
140
+ return self.config.hidden_size * len(self.s2_scales)
mcp_servers/fashion_vlm/models/common_modules.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
3
+ """
4
+
5
+ import math
6
+ from typing import Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+ from einops.layers.torch import Rearrange
14
+
15
+
16
+ def nonlinearity(x):
17
+ # swish
18
+ return x * torch.sigmoid(x)
19
+
20
+
21
+ def Normalize(in_channels):
22
+ return torch.nn.GroupNorm(
23
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
24
+ )
25
+
26
+
27
+ class Upsample(nn.Module):
28
+ def __init__(self, in_channels, with_conv):
29
+ super().__init__()
30
+ self.with_conv = with_conv
31
+ if self.with_conv:
32
+ self.conv = torch.nn.Conv2d(
33
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
34
+ )
35
+
36
+ def forward(self, x):
37
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
38
+ if self.with_conv:
39
+ x = self.conv(x)
40
+ return x
41
+
42
+
43
+ class DepthToSpaceUpsample(nn.Module):
44
+ def __init__(
45
+ self,
46
+ in_channels,
47
+ ):
48
+ super().__init__()
49
+ conv = nn.Conv2d(in_channels, in_channels * 4, 1)
50
+
51
+ self.net = nn.Sequential(
52
+ conv,
53
+ nn.SiLU(),
54
+ Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
55
+ )
56
+
57
+ self.init_conv_(conv)
58
+
59
+ def init_conv_(self, conv):
60
+ o, i, h, w = conv.weight.shape
61
+ conv_weight = torch.empty(o // 4, i, h, w)
62
+ nn.init.kaiming_uniform_(conv_weight)
63
+ conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
64
+
65
+ conv.weight.data.copy_(conv_weight)
66
+ nn.init.zeros_(conv.bias.data)
67
+
68
+ def forward(self, x):
69
+ out = self.net(x)
70
+ return out
71
+
72
+
73
+ class Downsample(nn.Module):
74
+ def __init__(self, in_channels, with_conv):
75
+ super().__init__()
76
+ self.with_conv = with_conv
77
+ if self.with_conv:
78
+ # no asymmetric padding in torch conv, must do it ourselves
79
+ self.conv = torch.nn.Conv2d(
80
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
81
+ )
82
+
83
+ def forward(self, x):
84
+ if self.with_conv:
85
+ pad = (0, 1, 0, 1)
86
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
87
+ x = self.conv(x)
88
+ else:
89
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
90
+ return x
91
+
92
+
93
+ def unpack_time(t, batch):
94
+ _, c, w, h = t.size()
95
+ out = torch.reshape(t, [batch, -1, c, w, h])
96
+ out = rearrange(out, "b t c h w -> b c t h w")
97
+ return out
98
+
99
+
100
+ def pack_time(t):
101
+ out = rearrange(t, "b c t h w -> b t c h w")
102
+ _, _, c, w, h = out.size()
103
+ return torch.reshape(out, [-1, c, w, h])
104
+
105
+
106
+ class TimeDownsample2x(nn.Module):
107
+ def __init__(
108
+ self,
109
+ dim,
110
+ dim_out=None,
111
+ kernel_size=3,
112
+ ):
113
+ super().__init__()
114
+ if dim_out is None:
115
+ dim_out = dim
116
+ self.time_causal_padding = (kernel_size - 1, 0)
117
+ self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
118
+
119
+ def forward(self, x):
120
+ x = rearrange(x, "b c t h w -> b h w c t")
121
+ b, h, w, c, t = x.size()
122
+ x = torch.reshape(x, [-1, c, t])
123
+
124
+ x = F.pad(x, self.time_causal_padding)
125
+ out = self.conv(x)
126
+
127
+ out = torch.reshape(out, [b, h, w, c, t])
128
+ out = rearrange(out, "b h w c t -> b c t h w")
129
+ # out = rearrange(out, "b h w c t -> b c t h w")  # duplicate of the line above; disabled
130
+ return out
131
+
132
+
133
+ class TimeUpsample2x(nn.Module):
134
+ def __init__(self, dim, dim_out=None):
135
+ super().__init__()
136
+ if dim_out is None:
137
+ dim_out = dim
138
+ conv = nn.Conv1d(dim, dim_out * 2, 1)
139
+
140
+ self.net = nn.Sequential(
141
+ nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
142
+ )
143
+
144
+ self.init_conv_(conv)
145
+
146
+ def init_conv_(self, conv):
147
+ o, i, t = conv.weight.shape
148
+ conv_weight = torch.empty(o // 2, i, t)
149
+ nn.init.kaiming_uniform_(conv_weight)
150
+ conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
151
+
152
+ conv.weight.data.copy_(conv_weight)
153
+ nn.init.zeros_(conv.bias.data)
154
+
155
+ def forward(self, x):
156
+ x = rearrange(x, "b c t h w -> b h w c t")
157
+ b, h, w, c, t = x.size()
158
+ x = torch.reshape(x, [-1, c, t])
159
+
160
+ out = self.net(x)
161
+ out = out[:, :, 1:].contiguous()
162
+
163
+ out = torch.reshape(out, [b, h, w, c, t])
164
+ out = rearrange(out, "b h w c t -> b c t h w")
165
+ return out
166
+
167
+
168
+ class AttnBlock(nn.Module):
169
+ def __init__(self, in_channels):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+
173
+ self.norm = Normalize(in_channels)
174
+ self.q = torch.nn.Conv2d(
175
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
176
+ )
177
+ self.k = torch.nn.Conv2d(
178
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
179
+ )
180
+ self.v = torch.nn.Conv2d(
181
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
182
+ )
183
+ self.proj_out = torch.nn.Conv2d(
184
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
185
+ )
186
+
187
+ def forward(self, x):
188
+ h_ = x
189
+ h_ = self.norm(h_)
190
+ q = self.q(h_)
191
+ k = self.k(h_)
192
+ v = self.v(h_)
193
+
194
+ # compute attention
195
+ b, c, h, w = q.shape
196
+ q = q.reshape(b, c, h * w)
197
+ q = q.permute(0, 2, 1) # b,hw,c
198
+ k = k.reshape(b, c, h * w) # b,c,hw
199
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
200
+ w_ = w_ * (int(c) ** (-0.5))
201
+ w_ = torch.nn.functional.softmax(w_, dim=2)
202
+
203
+ # attend to values
204
+ v = v.reshape(b, c, h * w)
205
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
206
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
207
+ h_ = h_.reshape(b, c, h, w)
208
+
209
+ h_ = self.proj_out(h_)
210
+
211
+ return x + h_
212
+
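The bmm/softmax arithmetic in AttnBlock is single-head dot-product attention over the h*w spatial positions. A standalone sketch checking it against torch.nn.functional.scaled_dot_product_attention (PyTorch 2.x; illustration only, not part of the commit):

```python
import torch
import torch.nn.functional as F

b, c, h, w = 2, 16, 8, 8
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

# manual path, mirroring AttnBlock.forward
q_ = q.reshape(b, c, h * w).permute(0, 2, 1)                      # (b, hw, c)
k_ = k.reshape(b, c, h * w)                                       # (b, c, hw)
w_ = torch.softmax(torch.bmm(q_, k_) * c ** -0.5, dim=2)          # (b, hw, hw)
out_manual = torch.bmm(v.reshape(b, c, h * w), w_.permute(0, 2, 1)).reshape(b, c, h, w)

# reference path
out_ref = F.scaled_dot_product_attention(
    q_, k_.permute(0, 2, 1), v.reshape(b, c, h * w).permute(0, 2, 1)
).permute(0, 2, 1).reshape(b, c, h, w)

assert torch.allclose(out_manual, out_ref, atol=1e-5)
```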
213
+
214
+ class TimeAttention(AttnBlock):
215
+ def forward(self, x, *args, **kwargs):
216
+ x = rearrange(x, "b c t h w -> b h w t c")
217
+ b, h, w, t, c = x.size()
218
+ x = torch.reshape(x, (-1, t, c))
219
+
220
+ x = super().forward(x, *args, **kwargs)
221
+
222
+ x = torch.reshape(x, [b, h, w, t, c])
223
+ return rearrange(x, "b h w t c -> b c t h w")
224
+
225
+
226
+ class Residual(nn.Module):
227
+ def __init__(self, fn: nn.Module):
228
+ super().__init__()
229
+ self.fn = fn
230
+
231
+ def forward(self, x, **kwargs):
232
+ return self.fn(x, **kwargs) + x
233
+
234
+
235
+ def cast_tuple(t, length=1):
236
+ return t if isinstance(t, tuple) else ((t,) * length)
237
+
238
+
239
+ class CausalConv3d(nn.Module):
240
+ def __init__(
241
+ self,
242
+ chan_in,
243
+ chan_out,
244
+ kernel_size: Union[int, Tuple[int, int, int]],
245
+ pad_mode="constant",
246
+ **kwargs
247
+ ):
248
+ super().__init__()
249
+ kernel_size = cast_tuple(kernel_size, 3)
250
+
251
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
252
+
253
+ dilation = kwargs.pop("dilation", 1)
254
+ stride = kwargs.pop("stride", 1)
255
+
256
+ self.pad_mode = pad_mode
257
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
258
+ height_pad = height_kernel_size // 2
259
+ width_pad = width_kernel_size // 2
260
+
261
+ self.time_pad = time_pad
262
+ self.time_causal_padding = (
263
+ width_pad,
264
+ width_pad,
265
+ height_pad,
266
+ height_pad,
267
+ time_pad,
268
+ 0,
269
+ )
270
+
271
+ stride = (stride, 1, 1)
272
+ dilation = (dilation, 1, 1)
273
+ self.conv = nn.Conv3d(
274
+ chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
275
+ )
276
+
277
+ def forward(self, x):
278
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
279
+
280
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
281
+ return self.conv(x)
282
+
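For a 3x3x3 kernel the padding computed in CausalConv3d is two frames on the left of the time axis and one pixel on each spatial side, which preserves the input (T, H, W) while keeping the convolution causal in time. A standalone sketch of that padding arithmetic with a plain Conv3d (illustration only):

```python
import torch
import torch.nn.functional as F

conv = torch.nn.Conv3d(4, 4, kernel_size=3)          # no built-in padding
x = torch.randn(1, 4, 5, 16, 16)                      # (B, C, T, H, W)
pad = (1, 1, 1, 1, 2, 0)                               # (w, w, h, h, time_left, time_right)
y = conv(F.pad(x, pad, mode="constant"))
assert y.shape == x.shape
```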
283
+
284
+ def ResnetBlockCausal3D(
285
+ dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
286
+ ):
287
+ net = nn.Sequential(
288
+ Normalize(dim),
289
+ nn.SiLU(),
290
+ CausalConv3d(dim, dim, kernel_size, pad_mode),
291
+ Normalize(dim),
292
+ nn.SiLU(),
293
+ CausalConv3d(dim, dim, kernel_size, pad_mode),
294
+ )
295
+ return Residual(net)
296
+
297
+
298
+ class ResnetBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ *,
302
+ in_channels,
303
+ out_channels=None,
304
+ conv_shortcut=False,
305
+ dropout,
306
+ temb_channels=512
307
+ ):
308
+ super().__init__()
309
+ self.in_channels = in_channels
310
+ out_channels = in_channels if out_channels is None else out_channels
311
+ self.out_channels = out_channels
312
+ self.use_conv_shortcut = conv_shortcut
313
+
314
+ self.norm1 = Normalize(in_channels)
315
+ self.conv1 = torch.nn.Conv2d(
316
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
317
+ )
318
+ if temb_channels > 0:
319
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
320
+ else:
321
+ self.temb_proj = None
322
+ self.norm2 = Normalize(out_channels)
323
+ self.dropout = torch.nn.Dropout(dropout)
324
+ self.conv2 = torch.nn.Conv2d(
325
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
326
+ )
327
+ if self.in_channels != self.out_channels:
328
+ if self.use_conv_shortcut:
329
+ self.conv_shortcut = torch.nn.Conv2d(
330
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
331
+ )
332
+ else:
333
+ self.nin_shortcut = torch.nn.Conv2d(
334
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
335
+ )
336
+
337
+ def forward(self, x, temb):
338
+ h = x
339
+ h = self.norm1(h)
340
+ h = nonlinearity(h)
341
+ h = self.conv1(h)
342
+
343
+ if temb is not None:
344
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
345
+
346
+ h = self.norm2(h)
347
+ h = nonlinearity(h)
348
+ h = self.dropout(h)
349
+ h = self.conv2(h)
350
+
351
+ if self.in_channels != self.out_channels:
352
+ if self.use_conv_shortcut:
353
+ x = self.conv_shortcut(x)
354
+ else:
355
+ x = self.nin_shortcut(x)
356
+
357
+ return x + h
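End-of-file note: a minimal usage sketch of ResnetBlock, showing the 1x1 `nin_shortcut` aligning the residual when the channel count changes (assumptions: the package is importable under this path, and `Normalize`/`nonlinearity` are defined earlier in this file as in the standard VQGAN code):

```python
import torch
from mcp_servers.fashion_vlm.models.common_modules import ResnetBlock  # assumed import path

block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
x = torch.randn(1, 64, 32, 32)
y = block(x, temb=None)          # temb_channels=0, so no timestep embedding is used
assert y.shape == (1, 128, 32, 32)
```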
mcp_servers/fashion_vlm/models/misc.py ADDED
@@ -0,0 +1,53 @@
1
+ from omegaconf import OmegaConf
2
+ import torch
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Iterable,
8
+ List,
9
+ NamedTuple,
10
+ NewType,
11
+ Optional,
12
+ Sized,
13
+ Tuple,
14
+ Type,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+ try:
19
+ from typing import Literal
20
+ except ImportError:
21
+ from typing_extensions import Literal
22
+
23
+ # Tensor dtype
24
+ # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
25
+ from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
26
+
27
+ # Config type
28
+ from omegaconf import DictConfig
29
+
30
+ # PyTorch Tensor type
31
+ from torch import Tensor
32
+
33
+ # Runtime type checking decorator
34
+ from typeguard import typechecked as typechecker
35
+
36
+
37
+ def broadcast(tensor, src=0):
38
+ if not _distributed_available():
39
+ return tensor
40
+ else:
41
+ torch.distributed.broadcast(tensor, src=src)
42
+ return tensor
43
+
44
+ def _distributed_available():
45
+ return torch.distributed.is_available() and torch.distributed.is_initialized()
46
+
47
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
48
+ # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
49
+ if '--local-rank' in cfg:
50
+ del cfg['--local-rank']
51
+ # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
52
+ scfg = OmegaConf.structured(fields(**cfg))
53
+ return scfg
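A minimal sketch of what `parse_structured` does (the `--local-rank` workaround aside): it validates a plain dict against a dataclass schema via `OmegaConf.structured`. Standalone illustration with a hypothetical schema, not part of the commit:

```python
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class EncoderCfg:                 # hypothetical schema for illustration
    ch: int = 128
    z_channels: int = 13

scfg = OmegaConf.structured(EncoderCfg(**{"ch": 64}))   # what parse_structured reduces to
assert scfg.ch == 64 and scfg.z_channels == 13
```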
mcp_servers/fashion_vlm/models/modeling_magvitv2.py ADDED
@@ -0,0 +1,440 @@
1
+ from dataclasses import dataclass, field
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from .common_modules import *
6
+ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
7
+ from .misc import *
8
+ import math
9
+
10
+ class Updateable:
11
+ def do_update_step(
12
+ self, epoch: int, global_step: int, on_load_weights: bool = False
13
+ ):
14
+ for attr in self.__dir__():
15
+ if attr.startswith("_"):
16
+ continue
17
+ try:
18
+ module = getattr(self, attr)
19
+ except:
20
+ continue # ignore attributes like property, which can't be retrieved using getattr
21
+ if isinstance(module, Updateable):
22
+ module.do_update_step(
23
+ epoch, global_step, on_load_weights=on_load_weights
24
+ )
25
+ self.update_step(epoch, global_step, on_load_weights=on_load_weights)
26
+
27
+ def do_update_step_end(self, epoch: int, global_step: int):
28
+ for attr in self.__dir__():
29
+ if attr.startswith("_"):
30
+ continue
31
+ try:
32
+ module = getattr(self, attr)
33
+ except:
34
+ continue # ignore attributes like property, which can't be retrieved using getattr
35
+ if isinstance(module, Updateable):
36
+ module.do_update_step_end(epoch, global_step)
37
+ self.update_step_end(epoch, global_step)
38
+
39
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
40
+ # override this method to implement custom update logic
41
+ # if on_load_weights is True, you should be careful doing things related to model evaluations,
42
+ # as the models and tensors are not guaranteed to be on the same device
43
+ pass
44
+
45
+ def update_step_end(self, epoch: int, global_step: int):
46
+ pass
47
+
48
+ class VQGANEncoder(ModelMixin, ConfigMixin):
49
+ @dataclass
50
+ class Config:
51
+ ch: int = 128
52
+ ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
53
+ num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
54
+ attn_resolutions: List[int] = field(default_factory=lambda: [5])
55
+ dropout: float = 0.0
56
+ in_ch: int = 3
57
+ out_ch: int = 3
58
+ resolution: int = 256
59
+ z_channels: int = 13
60
+ double_z: bool = False
61
+
62
+ def __init__(self,
63
+ ch: int = 128,
64
+ ch_mult: List[int] = [1, 2, 2, 4, 4],
65
+ num_res_blocks: List[int] = [4, 3, 4, 3, 4],
66
+ attn_resolutions: List[int] = [5],
67
+ dropout: float = 0.0,
68
+ in_ch: int = 3,
69
+ out_ch: int = 3,
70
+ resolution: int = 256,
71
+ z_channels: int = 13,
72
+ double_z: bool = False):
73
+ super().__init__()
74
+ self.ch = ch
75
+ self.temb_ch = 0
76
+ self.num_resolutions = len(ch_mult)
77
+ self.num_res_blocks = num_res_blocks
78
+ self.resolution = resolution
79
+ self.in_ch = in_ch
80
+ # downsampling
81
+ self.conv_in = torch.nn.Conv2d(
82
+ self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
83
+ )
84
+
85
+ curr_res = self.resolution
86
+ in_ch_mult = (1,) + tuple(ch_mult)
87
+ self.down = nn.ModuleList()
88
+ for i_level in range(self.num_resolutions):
89
+ block = nn.ModuleList()
90
+ attn = nn.ModuleList()
91
+ block_in = self.ch * in_ch_mult[i_level]
92
+ block_out = self.ch * ch_mult[i_level]
93
+ for i_block in range(self.num_res_blocks[i_level]):
94
+ block.append(
95
+ ResnetBlock(
96
+ in_channels=block_in,
97
+ out_channels=block_out,
98
+ temb_channels=self.temb_ch,
99
+ dropout=dropout,
100
+ )
101
+ )
102
+ block_in = block_out
103
+ if curr_res in attn_resolutions:
104
+ attn.append(AttnBlock(block_in))
105
+ down = nn.Module()
106
+ down.block = block
107
+ down.attn = attn
108
+ if i_level != self.num_resolutions - 1:
109
+ down.downsample = Downsample(block_in, True)
110
+ curr_res = curr_res // 2
111
+ self.down.append(down)
112
+
113
+ # middle
114
+ self.mid = nn.Module()
115
+ self.mid.block_1 = ResnetBlock(
116
+ in_channels=block_in,
117
+ out_channels=block_in,
118
+ temb_channels=self.temb_ch,
119
+ dropout=dropout,
120
+ )
121
+ self.mid.attn_1 = AttnBlock(block_in)
122
+ self.mid.block_2 = ResnetBlock(
123
+ in_channels=block_in,
124
+ out_channels=block_in,
125
+ temb_channels=self.temb_ch,
126
+ dropout=dropout,
127
+ )
128
+
129
+
130
+ self.norm_out = Normalize(block_in)
131
+ self.conv_out = torch.nn.Conv2d(
132
+ block_in,
133
+ 2 * z_channels if double_z else z_channels,
134
+ kernel_size=3,
135
+ stride=1,
136
+ padding=1,
137
+ )
138
+
139
+ self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
140
+ # for param in self.parameters():
141
+ # broadcast(param, src=0)
142
+
143
+ def forward(self, x):
144
+ # timestep embedding
145
+ temb = None
146
+
147
+ # downsampling
148
+ hs = [self.conv_in(x)]
149
+ for i_level in range(self.num_resolutions):
150
+ for i_block in range(self.num_res_blocks[i_level]):
151
+ h = self.down[i_level].block[i_block](hs[-1], temb)
152
+ if len(self.down[i_level].attn) > 0:
153
+ h = self.down[i_level].attn[i_block](h)
154
+ hs.append(h)
155
+ if i_level != self.num_resolutions - 1:
156
+ hs.append(self.down[i_level].downsample(hs[-1]))
157
+
158
+ # middle
159
+ h = hs[-1]
160
+ h = self.mid.block_1(h, temb)
161
+ h = self.mid.attn_1(h)
162
+ h = self.mid.block_2(h, temb)
163
+
164
+ # end
165
+ h = self.norm_out(h)
166
+ h = nonlinearity(h)
167
+ h = self.conv_out(h)
168
+ h = self.quant_conv(h)
169
+ return h
170
+
171
+
172
+ class LFQuantizer(nn.Module):
173
+ def __init__(self, num_codebook_entry: int = -1,
174
+ codebook_dim: int = 13,
175
+ beta: float = 0.25,
176
+ entropy_multiplier: float = 0.1,
177
+ commit_loss_multiplier: float = 0.1, ):
178
+ super().__init__()
179
+ self.codebook_size = 2 ** codebook_dim
180
+ print(
181
+ f"Look-up free quantizer with codebook size: {self.codebook_size}"
182
+ )
183
+ self.e_dim = codebook_dim
184
+ self.beta = beta
185
+
186
+ indices = torch.arange(self.codebook_size)
187
+
188
+ binary = (
189
+ indices.unsqueeze(1)
190
+ >> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
191
+ ) & 1
192
+
193
+ embedding = binary.float() * 2 - 1
194
+ self.register_buffer("embedding", embedding)
195
+ self.register_buffer(
196
+ "power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
197
+ )
198
+ self.commit_loss_multiplier = commit_loss_multiplier
199
+ self.entropy_multiplier = entropy_multiplier
200
+
201
+ def get_indices(self, z_q):
202
+ return (
203
+ (self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
204
+ .sum(1, keepdim=True)
205
+ .long()
206
+ )
207
+
208
+ def get_codebook_entry(self, indices, shape=None):
209
+ if shape is None:
210
+ h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
211
+ else:
212
+ h, w = shape
213
+ b, _ = indices.shape
214
+ indices = indices.reshape(-1)
215
+ z_q = self.embedding[indices]
216
+ z_q = z_q.view(b, h, w, -1)
217
+
218
+ # reshape back to match original input shape
219
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
220
+
221
+ return z_q
222
+
223
+ def forward(self, z, get_code=False):
224
+ """
225
+ Inputs the output of the encoder network z and maps it to a discrete
226
+ one-hot vector that is the index of the closest embedding vector e_j
227
+ z (continuous) -> z_q (discrete)
228
+ z.shape = (batch, channel, height, width)
229
+ quantization pipeline:
230
+ 1. get encoder input (B,C,H,W)
231
+ 2. flatten input to (B*H*W,C)
232
+ """
233
+ if get_code:
234
+ return self.get_codebook_entry(z)
235
+
236
+ # reshape z -> (batch, height, width, channel) and flatten
237
+ z = z.permute(0, 2, 3, 1).contiguous()
238
+ z_flattened = z.view(-1, self.e_dim)
239
+ ge_zero = (z_flattened > 0).float()
240
+ ones = torch.ones_like(z_flattened)
241
+ z_q = ones * ge_zero + -ones * (1 - ge_zero)
242
+
243
+ # preserve gradients
244
+ z_q = z_flattened + (z_q - z_flattened).detach()
245
+
246
+ # compute entropy loss
247
+ CatDist = torch.distributions.categorical.Categorical
248
+ logit = torch.stack(
249
+ [
250
+ -(z_flattened - torch.ones_like(z_q)).pow(2),
251
+ -(z_flattened - torch.ones_like(z_q) * -1).pow(2),
252
+ ],
253
+ dim=-1,
254
+ )
255
+ cat_dist = CatDist(logits=logit)
256
+ entropy = cat_dist.entropy().mean()
257
+ mean_prob = cat_dist.probs.mean(0)
258
+ mean_entropy = CatDist(probs=mean_prob).entropy().mean()
259
+
260
+ # compute loss for embedding
261
+ commit_loss = torch.mean(
262
+ (z_q.detach() - z_flattened) ** 2
263
+ ) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)
264
+
265
+ # reshape back to match original input shape
266
+ z_q = z_q.view(z.shape)
267
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
268
+
269
+ return {
270
+ "z": z_q,
271
+ "quantizer_loss": commit_loss * self.commit_loss_multiplier,
272
+ "entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
273
+ "indices": self.get_indices(z_q),
274
+ }
275
+
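How the look-up-free quantizer above maps latents to token ids, in miniature: every channel is binarized to ±1, and the sign pattern read as a binary number is the code index, so `codebook_dim=13` yields 2**13 = 8192 codes without any learned embedding table. Standalone sketch with a 3-dim code for readability (illustration only):

```python
import torch

codebook_dim = 3
power_vals = 2 ** torch.arange(codebook_dim - 1, -1, -1)           # [4, 2, 1]

z = torch.tensor([[0.7, -0.2, 0.1]])                                # one latent vector
ones = torch.ones_like(z)
z_q = torch.where(z > 0, ones, -ones)                               # [[1., -1., 1.]]
index = (power_vals * (z_q > 0).float()).sum(dim=1).long()          # bits 101 -> tensor([5])
assert index.item() == 5

# decoding mirrors get_codebook_entry: index -> bits -> ±1
bits = (index.unsqueeze(1) >> torch.arange(codebook_dim - 1, -1, -1)) & 1
assert torch.equal(bits.float() * 2 - 1, z_q)
```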
276
+
277
+ class VQGANDecoder(ModelMixin, ConfigMixin):
278
+ def __init__(self, ch: int = 128,
279
+ ch_mult: List[int] = [1, 1, 2, 2, 4],
280
+ num_res_blocks: List[int] = [4, 4, 3, 4, 3],
281
+ attn_resolutions: List[int] = [5],
282
+ dropout: float = 0.0,
283
+ in_ch: int = 3,
284
+ out_ch: int = 3,
285
+ resolution: int = 256,
286
+ z_channels: int = 13,
287
+ double_z: bool = False):
288
+ super().__init__()
289
+ self.ch = ch
290
+ self.temb_ch = 0
291
+ self.num_resolutions = len(ch_mult)
292
+ self.num_res_blocks = num_res_blocks
293
+ self.resolution = resolution
294
+ self.in_ch = in_ch
295
+ self.give_pre_end = False
296
+
297
+ self.z_channels = z_channels
298
+ # compute in_ch_mult, block_in and curr_res at lowest res
299
+ in_ch_mult = (1,) + tuple(ch_mult)
300
+ block_in = ch * ch_mult[self.num_resolutions - 1]
301
+ curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
302
+ self.z_shape = (1, z_channels, curr_res, curr_res)
303
+ print(
304
+ "Working with z of shape {} = {} dimensions.".format(
305
+ self.z_shape, np.prod(self.z_shape)
306
+ )
307
+ )
308
+
309
+ # z to block_in
310
+ self.conv_in = torch.nn.Conv2d(
311
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
312
+ )
313
+
314
+ # middle
315
+ self.mid = nn.Module()
316
+ self.mid.block_1 = ResnetBlock(
317
+ in_channels=block_in,
318
+ out_channels=block_in,
319
+ temb_channels=self.temb_ch,
320
+ dropout=dropout,
321
+ )
322
+ self.mid.attn_1 = AttnBlock(block_in)
323
+ self.mid.block_2 = ResnetBlock(
324
+ in_channels=block_in,
325
+ out_channels=block_in,
326
+ temb_channels=self.temb_ch,
327
+ dropout=dropout,
328
+ )
329
+
330
+ # upsampling
331
+ self.up = nn.ModuleList()
332
+ for i_level in reversed(range(self.num_resolutions)):
333
+ block = nn.ModuleList()
334
+ attn = nn.ModuleList()
335
+ block_out = ch * ch_mult[i_level]
336
+ for i_block in range(self.num_res_blocks[i_level]):
337
+ block.append(
338
+ ResnetBlock(
339
+ in_channels=block_in,
340
+ out_channels=block_out,
341
+ temb_channels=self.temb_ch,
342
+ dropout=dropout,
343
+ )
344
+ )
345
+ block_in = block_out
346
+ if curr_res in attn_resolutions:
347
+ attn.append(AttnBlock(block_in))
348
+ up = nn.Module()
349
+ up.block = block
350
+ up.attn = attn
351
+ if i_level != 0:
352
+ up.upsample = Upsample(block_in, True)
353
+ curr_res = curr_res * 2
354
+ self.up.insert(0, up) # prepend to get consistent order
355
+
356
+ self.norm_out = Normalize(block_in)
357
+ self.conv_out = torch.nn.Conv2d(
358
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
359
+ )
360
+ self.post_quant_conv = torch.nn.Conv2d(
361
+ z_channels, z_channels, 1
362
+ )
363
+
364
+
365
+ def forward(self, z):
366
+ # assert z.shape[1:] == self.z_shape[1:]
367
+ self.last_z_shape = z.shape
368
+ # timestep embedding
369
+ temb = None
370
+ output = dict()
371
+ z = self.post_quant_conv(z)
372
+
373
+ # z to block_in
374
+ h = self.conv_in(z)
375
+
376
+ # middle
377
+ h = self.mid.block_1(h, temb)
378
+ h = self.mid.attn_1(h)
379
+ h = self.mid.block_2(h, temb)
380
+
381
+ # upsampling
382
+ for i_level in reversed(range(self.num_resolutions)):
383
+ for i_block in range(self.num_res_blocks[i_level]):
384
+ h = self.up[i_level].block[i_block](h, temb)
385
+ if len(self.up[i_level].attn) > 0:
386
+ h = self.up[i_level].attn[i_block](h)
387
+ if i_level != 0:
388
+ h = self.up[i_level].upsample(h)
389
+
390
+ # end
391
+ output["output"] = h
392
+ if self.give_pre_end:
393
+ return output
394
+
395
+ h = self.norm_out(h)
396
+ h = nonlinearity(h)
397
+ h = self.conv_out(h)
398
+ output["output"] = h
399
+ return output
400
+
401
+
402
+ class MAGVITv2(ModelMixin, ConfigMixin):
403
+ @register_to_config
404
+ def __init__(
405
+ self,
406
+ ):
407
+ super().__init__()
408
+
409
+ self.encoder = VQGANEncoder()
410
+ self.decoder = VQGANDecoder()
411
+ self.quantize = LFQuantizer()
412
+
413
+ def forward(self, pixel_values, return_loss=False):
414
+ pass
415
+
416
+ def encode(self, pixel_values, return_loss=False):
417
+ hidden_states = self.encoder(pixel_values)
418
+ quantized_states = self.quantize(hidden_states)['z']
419
+ codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
420
+ output = (quantized_states, codebook_indices)
421
+ return output
422
+
423
+ def get_code(self, pixel_values):
424
+ hidden_states = self.encoder(pixel_values)
425
+ codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)
426
+
427
+ return codebook_indices
428
+
429
+ def decode_code(self, codebook_indices, shape=None):
430
+ z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)
431
+
432
+ reconstructed_pixel_values = self.decoder(z_q)["output"]
433
+ return reconstructed_pixel_values
434
+
435
+
436
+ if __name__ == '__main__':
437
+ encoder = VQGANEncoder()
438
+ import ipdb
439
+ ipdb.set_trace()
440
+ print()
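End-of-file note: with the default configs above (256x256 input, five resolution levels, z_channels=13) the tokenizer produces a 16x16 grid of 256 tokens drawn from a 2**13 = 8192-entry code space. A usage sketch of the encode/decode round trip (assumptions: randomly initialized weights and an importable package path; this only checks shapes, the reconstruction is meaningless without a checkpoint):

```python
import torch
from mcp_servers.fashion_vlm.models.modeling_magvitv2 import MAGVITv2  # assumed import path

model = MAGVITv2().eval()
with torch.no_grad():
    pixels = torch.randn(1, 3, 256, 256)
    codes = model.get_code(pixels)                 # (1, 256) token ids in [0, 8191]
    recon = model.decode_code(codes)               # (1, 3, 256, 256)
assert codes.shape == (1, 256) and recon.shape == (1, 3, 256, 256)
```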
mcp_servers/fashion_vlm/models/modeling_showo.py ADDED
@@ -0,0 +1,237 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 NUS Show Lab, HuggingFace.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from transformers import AutoConfig
19
+ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
20
+ from .sampling import cosine_schedule, mask_by_random_topk
21
+ from .phi import PhiForCausalLM
22
+
23
+
24
+ class Showo(ModelMixin, ConfigMixin):
25
+ _supports_gradient_checkpointing = True
26
+
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ w_clip_vit,
31
+ vocab_size,
32
+ llm_vocab_size,
33
+ llm_model_path='',
34
+ codebook_size=8192,
35
+ num_vq_tokens=256,
36
+ load_from_showo=True,
37
+ **kwargs,
38
+ ):
39
+ super().__init__()
40
+ self.vocab_size = vocab_size
41
+ self.register_to_config(mask_token_id=vocab_size - 1)
42
+ if load_from_showo:
43
+ config = AutoConfig.from_pretrained(llm_model_path)
44
+ self.showo = PhiForCausalLM(config)
45
+ else:
46
+ self.showo = PhiForCausalLM.from_pretrained(llm_model_path, attn_implementation='sdpa')
47
+ self.showo.resize_token_embeddings(self.vocab_size)
48
+ self.output_size = self.vocab_size
49
+
50
+ if self.w_clip_vit:
51
+ self.mm_projector = torch.nn.Sequential(
52
+ torch.nn.Linear(1024, 2048),
53
+ torch.nn.GELU(),
54
+ torch.nn.Linear(2048, 2048)
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ input_ids,
60
+ input_embeddings=None,
61
+ attention_mask=None,
62
+ labels=None,
63
+ label_smoothing=0.0,
64
+ batch_size_t2i=0,
65
+ batch_size_lm=0,
66
+ batch_size_mmu=0,
67
+ max_seq_length=128,
68
+ labels_mask_text=None,
69
+ labels_mask_image=None,
70
+ **kwargs,
71
+ ):
72
+
73
+ if input_embeddings is None:
74
+ logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
75
+ else:
76
+ logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']
77
+
78
+ if labels is not None:
79
+ # 1. Mask token prediction (discrete diffusion) for image generation
80
+ # Note that max_seq_length here refers to the maximum number of text tokens, which can be a bit confusing.
81
+ loss_t2i = F.cross_entropy(
82
+ logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
83
+ labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
84
+ )
85
+
86
+ # 2. Next token prediction for language modeling
87
+ loss_lm = F.cross_entropy(
88
+ logits[batch_size_t2i:batch_size_t2i + batch_size_lm, :-1].contiguous().view(-1, self.output_size),
89
+ labels[batch_size_t2i:batch_size_t2i + batch_size_lm, 1:].contiguous().view(-1), ignore_index=-100,
90
+ )
91
+ # loss_lm = torch.tensor(0.0, device=logits.device)
92
+
93
+ # 3. Next token prediction for captioning/multimodal understanding
94
+ loss_mmu = F.cross_entropy(
95
+ logits[-batch_size_mmu:, :-1].contiguous().view(-1, self.output_size),
96
+ labels[-batch_size_mmu:, 1:].contiguous().view(-1), ignore_index=-100,
97
+ )
98
+
99
+ return logits, loss_t2i, loss_lm, loss_mmu
100
+
101
+ return logits
102
+
103
+ def t2i_generate(
104
+ self,
105
+ input_ids: torch.LongTensor = None,
106
+ uncond_input_ids: torch.LongTensor = None,
107
+ attention_mask=None,
108
+ temperature=1.0,
109
+ timesteps=18, # ideal number of steps is 18 in maskgit paper
110
+ guidance_scale=0,
111
+ noise_schedule=cosine_schedule,
112
+ generator: torch.Generator = None,
113
+ ):
114
+ # begin with all image token ids masked
115
+ mask_token_id = self.config.mask_token_id
116
+ num_vq_tokens = self.config.num_vq_tokens
117
+ num_new_special_tokens = 10
118
+ llm_vocab_size = self.config.llm_vocab_size
119
+ max_seq_length = 381
120
+
121
+ input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
122
+ input_ids_minus_lm_vocab_size = torch.where(
123
+ input_ids_minus_lm_vocab_size == mask_token_id,
124
+ mask_token_id,
125
+ input_ids_minus_lm_vocab_size - llm_vocab_size - num_new_special_tokens
126
+ )
127
+
128
+ # for classifier-free guidance
129
+ if uncond_input_ids is not None:
130
+ uncond_prefix = uncond_input_ids[:, :max_seq_length + 1]
131
+
132
+ for step in range(timesteps):
133
+ if uncond_input_ids is not None and guidance_scale > 0:
134
+ uncond_input_ids = torch.cat(
135
+ [uncond_prefix, input_ids[:, max_seq_length + 1:]], dim=1)
136
+ model_input = torch.cat([input_ids, uncond_input_ids])
137
+ cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
138
+ # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
139
+ # it seems that muse has a different cfg setting
140
+ logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
141
+ logits = logits[:, -(num_vq_tokens + 1):-1, llm_vocab_size + num_new_special_tokens:-1]
142
+ else:
143
+ logits = self(input_ids, attention_mask=attention_mask)
144
+ logits = logits[:, -(num_vq_tokens + 1):-1, llm_vocab_size + num_new_special_tokens:-1]
145
+
146
+ probs = logits.softmax(dim=-1)
147
+ sampled = probs.reshape(-1, logits.size(-1))
148
+ sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])
149
+
150
+ unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
151
+ sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
152
+ # Defines the mask ratio for the next round. The number to mask out is
153
+ # determined by mask_ratio * unknown_number_in_the_beginning.
154
+ ratio = 1.0 * (step + 1) / timesteps
155
+ mask_ratio = noise_schedule(torch.tensor(ratio))
156
+ # Computes the probabilities of each selected tokens.
157
+ selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
158
+ selected_probs = selected_probs.squeeze(-1)
159
+
160
+ # Ignores the tokens given in the input by overwriting their confidence.
161
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
162
+ # Gets mask lens for each sample in the batch according to the mask ratio.
163
+ mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
164
+ # Keeps at least one prediction unmasked in this round and also masks out at
+ # least one token for the next iteration
166
+ mask_len = torch.max(
167
+ torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
168
+ )
169
+ # Adds noise for randomness
170
+ temperature = temperature * (1.0 - ratio)
171
+ masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
172
+ # Masks tokens with lower confidence.
173
+ input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
174
+ sampled_ids + llm_vocab_size
175
+ + num_new_special_tokens)
176
+ input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
177
+
178
+ return sampled_ids
179
+
180
+ @torch.no_grad()
181
+ def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
182
+ """
183
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
184
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
185
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
186
+ """
187
+ try:
188
+ device = idx.device
189
+ except:
190
+ device = input_embeddings.device
191
+
192
+ result = []
193
+ for _ in range(max_new_tokens):
194
+ # if the sequence context is growing too long we must crop it at block_size
195
+ # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
196
+ # forward the model to get the logits for the index in the sequence
197
+ # logits, _ = self(idx_cond)
198
+ logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)
199
+
200
+ L = attention_mask.shape[-1]
201
+ attention_mask = attention_mask.squeeze()
202
+ attention_mask_a = torch.hstack(
203
+ [
204
+ attention_mask, # L, L
205
+ torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
206
+ ]
207
+ )
208
+ attention_mask_b = torch.vstack(
209
+ [
210
+ attention_mask_a, # L, L+1
211
+ torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
212
+ ]
213
+ )
214
+ attention_mask = attention_mask_b
215
+
216
+ # pluck the logits at the final step and scale by desired temperature
217
+ logits = logits[:, -1, :] / temperature
218
+ # optionally crop the logits to only the top k options
219
+ if top_k is not None:
220
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
221
+ logits[logits < v[:, [-1]]] = -float('Inf')
222
+ # apply softmax to convert logits to (normalized) probabilities
223
+ probs = F.softmax(logits, dim=-1)
224
+ # sample from the distribution
225
+ idx_next = torch.multinomial(probs, num_samples=1)
226
+ result.append(idx_next[0][0])
227
+ # append sampled index to the running sequence and continue
228
+ if self.config.w_clip_vit:
229
+ idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
230
+ input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
231
+ else:
232
+ idx = torch.cat((idx, idx_next), dim=1)
233
+
234
+ if eot_token is not None and idx_next.cpu() == eot_token:
235
+ break
236
+
237
+ return result
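End-of-file note on `t2i_generate`: it is MaskGIT-style iterative decoding — at each step the noise schedule gives the fraction of image tokens to re-mask, and only the lowest-confidence predictions are masked again. A standalone sketch of that schedule arithmetic (assumption: the cosine schedule is cos(pi/2 * r), which is what `.sampling.cosine_schedule` is expected to compute; the min/max clamping done above is omitted here):

```python
import math
import torch

def cosine_schedule(r: torch.Tensor) -> torch.Tensor:   # assumed form of .sampling.cosine_schedule
    return torch.cos(r * math.pi / 2)

num_vq_tokens, timesteps = 256, 18
for step in range(timesteps):
    ratio = (step + 1) / timesteps
    mask_ratio = cosine_schedule(torch.tensor(ratio))
    mask_len = int((num_vq_tokens * mask_ratio).floor())  # tokens re-masked before the next step
    if step in (0, 8, 17):
        print(step, mask_len)   # 255 -> 181 -> 0: most tokens stay masked early, none at the end
```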
mcp_servers/fashion_vlm/models/modeling_utils.py ADDED
@@ -0,0 +1,1207 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import json
20
+ import os
21
+ import re
22
+ from collections import OrderedDict
23
+ from functools import partial
24
+ from pathlib import Path
25
+ from typing import Any, Callable, List, Optional, Tuple, Union
26
+
27
+ import safetensors
28
+ import torch
29
+ from huggingface_hub import create_repo, split_torch_state_dict_into_shards
30
+ from huggingface_hub.utils import validate_hf_hub_args
31
+ from torch import Tensor, nn
32
+
33
+ from diffusers import __version__
34
+ from diffusers.utils import (
35
+ FLAX_WEIGHTS_NAME,
36
+ SAFE_WEIGHTS_INDEX_NAME,
37
+ WEIGHTS_INDEX_NAME,
38
+ _add_variant,
39
+ _get_checkpoint_shard_files,
40
+ _get_model_file,
41
+ deprecate,
42
+ is_accelerate_available,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+
47
+ CONFIG_NAME = "config.json"
48
+ WEIGHTS_NAME = "pytorch_model.bin"
49
+ SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
50
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
51
+
52
+ from diffusers.utils.hub_utils import (
53
+ PushToHubMixin,
54
+ load_or_create_model_card,
55
+ populate_model_card,
56
+ )
57
+ from diffusers.models.model_loading_utils import (
58
+ _determine_device_map,
59
+ _fetch_index_file,
60
+ _load_state_dict_into_model,
61
+ load_model_dict_into_meta,
62
+ load_state_dict,
63
+ )
64
+
65
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
66
+
67
+ logger = logging.get_logger(__name__)
68
+
69
+ _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
70
+
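A tiny illustration of what `_REGEX_SHARD` matches: the stem of sharded weight filenames such as `pytorch_model-00001-of-00005` (extension stripped), which `save_pretrained` below uses to clean up stale shards. Standalone sketch:

```python
import re

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
assert _REGEX_SHARD.fullmatch("pytorch_model-00001-of-00005") is not None
assert _REGEX_SHARD.fullmatch("pytorch_model.safetensors") is None
```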
71
+
72
+ if is_torch_version(">=", "1.9.0"):
73
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
74
+ else:
75
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
76
+
77
+
78
+ if is_accelerate_available():
79
+ import accelerate
80
+
81
+
82
+ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
83
+ try:
84
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
85
+ return next(parameters_and_buffers).device
86
+ except StopIteration:
87
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
88
+
89
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
90
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
91
+ return tuples
92
+
93
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
94
+ first_tuple = next(gen)
95
+ return first_tuple[1].device
96
+
97
+
98
+ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
99
+ try:
100
+ params = tuple(parameter.parameters())
101
+ if len(params) > 0:
102
+ return params[0].dtype
103
+
104
+ buffers = tuple(parameter.buffers())
105
+ if len(buffers) > 0:
106
+ return buffers[0].dtype
107
+
108
+ except StopIteration:
109
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
110
+
111
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
112
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
113
+ return tuples
114
+
115
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
116
+ first_tuple = next(gen)
117
+ return first_tuple[1].dtype
118
+
119
+
120
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
121
+ r"""
122
+ Base class for all models.
123
+
124
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
125
+ saving models.
126
+
127
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
128
+ """
129
+
130
+ config_name = CONFIG_NAME
131
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
132
+ _supports_gradient_checkpointing = False
133
+ _keys_to_ignore_on_load_unexpected = None
134
+ _no_split_modules = None
135
+
136
+ def __init__(self):
137
+ super().__init__()
138
+
139
+ def __getattr__(self, name: str) -> Any:
140
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
141
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
142
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
143
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
144
+ """
145
+
146
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
147
+ is_attribute = name in self.__dict__
148
+
149
+ if is_in_config and not is_attribute:
150
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
151
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
152
+ return self._internal_dict[name]
153
+
154
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
155
+ return super().__getattr__(name)
156
+
157
+ @property
158
+ def is_gradient_checkpointing(self) -> bool:
159
+ """
160
+ Whether gradient checkpointing is activated for this model or not.
161
+ """
162
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
163
+
164
+ def enable_gradient_checkpointing(self) -> None:
165
+ """
166
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
167
+ *checkpoint activations* in other frameworks).
168
+ """
169
+ if not self._supports_gradient_checkpointing:
170
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
171
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
172
+
173
+ def disable_gradient_checkpointing(self) -> None:
174
+ """
175
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
176
+ *checkpoint activations* in other frameworks).
177
+ """
178
+ if self._supports_gradient_checkpointing:
179
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
180
+
181
+ def set_use_npu_flash_attention(self, valid: bool) -> None:
182
+ r"""
183
+ Set the switch for the npu flash attention.
184
+ """
185
+
186
+ def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
187
+ if hasattr(module, "set_use_npu_flash_attention"):
188
+ module.set_use_npu_flash_attention(valid)
189
+
190
+ for child in module.children():
191
+ fn_recursive_set_npu_flash_attention(child)
192
+
193
+ for module in self.children():
194
+ if isinstance(module, torch.nn.Module):
195
+ fn_recursive_set_npu_flash_attention(module)
196
+
197
+ def enable_npu_flash_attention(self) -> None:
198
+ r"""
199
+ Enable npu flash attention from torch_npu
200
+
201
+ """
202
+ self.set_use_npu_flash_attention(True)
203
+
204
+ def disable_npu_flash_attention(self) -> None:
205
+ r"""
206
+ disable npu flash attention from torch_npu
207
+
208
+ """
209
+ self.set_use_npu_flash_attention(False)
210
+
211
+ def set_use_memory_efficient_attention_xformers(
212
+ self, valid: bool, attention_op: Optional[Callable] = None
213
+ ) -> None:
214
+ # Recursively walk through all the children.
215
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
216
+ # gets the message
217
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
218
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
219
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
220
+
221
+ for child in module.children():
222
+ fn_recursive_set_mem_eff(child)
223
+
224
+ for module in self.children():
225
+ if isinstance(module, torch.nn.Module):
226
+ fn_recursive_set_mem_eff(module)
227
+
228
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
229
+ r"""
230
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
231
+
232
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
233
+ inference. Speed up during training is not guaranteed.
234
+
235
+ <Tip warning={true}>
236
+
237
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
238
+ precedent.
239
+
240
+ </Tip>
241
+
242
+ Parameters:
243
+ attention_op (`Callable`, *optional*):
244
+ Override the default `None` operator for use as `op` argument to the
245
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
246
+ function of xFormers.
247
+
248
+ Examples:
249
+
250
+ ```py
251
+ >>> import torch
252
+ >>> from diffusers import UNet2DConditionModel
253
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
254
+
255
+ >>> model = UNet2DConditionModel.from_pretrained(
256
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
257
+ ... )
258
+ >>> model = model.to("cuda")
259
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
260
+ ```
261
+ """
262
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
263
+
264
+ def disable_xformers_memory_efficient_attention(self) -> None:
265
+ r"""
266
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
267
+ """
268
+ self.set_use_memory_efficient_attention_xformers(False)
269
+
270
+ def save_pretrained(
271
+ self,
272
+ save_directory: Union[str, os.PathLike],
273
+ is_main_process: bool = True,
274
+ save_function: Optional[Callable] = None,
275
+ safe_serialization: bool = True,
276
+ variant: Optional[str] = None,
277
+ max_shard_size: Union[int, str] = "10GB",
278
+ push_to_hub: bool = False,
279
+ **kwargs,
280
+ ):
281
+ """
282
+ Save a model and its configuration file to a directory so that it can be reloaded using the
283
+ [`~models.ModelMixin.from_pretrained`] class method.
284
+
285
+ Arguments:
286
+ save_directory (`str` or `os.PathLike`):
287
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
288
+ is_main_process (`bool`, *optional*, defaults to `True`):
289
+ Whether the process calling this is the main process or not. Useful during distributed training and you
290
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
291
+ process to avoid race conditions.
292
+ save_function (`Callable`):
293
+ The function to use to save the state dictionary. Useful during distributed training when you need to
294
+ replace `torch.save` with another method. Can be configured with the environment variable
295
+ `DIFFUSERS_SAVE_MODE`.
296
+ safe_serialization (`bool`, *optional*, defaults to `True`):
297
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
298
+ variant (`str`, *optional*):
299
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
300
+ max_shard_size (`int` or `str`, defaults to `"10GB"`):
301
+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
302
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
303
+ If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
304
+ period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
305
+ This is to establish a common default size for this argument across different libraries in the Hugging
306
+ Face ecosystem (`transformers`, and `accelerate`, for example).
307
+ push_to_hub (`bool`, *optional*, defaults to `False`):
308
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
309
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
310
+ namespace).
311
+ kwargs (`Dict[str, Any]`, *optional*):
312
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
313
+ """
314
+ if os.path.isfile(save_directory):
315
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
316
+ return
317
+
318
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
319
+ weights_name = _add_variant(weights_name, variant)
320
+ weight_name_split = weights_name.split(".")
321
+ if len(weight_name_split) in [2, 3]:
322
+ weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
323
+ else:
324
+ raise ValueError(f"Invalid {weights_name} provided.")
325
+
326
+ os.makedirs(save_directory, exist_ok=True)
327
+
328
+ if push_to_hub:
329
+ commit_message = kwargs.pop("commit_message", None)
330
+ private = kwargs.pop("private", False)
331
+ create_pr = kwargs.pop("create_pr", False)
332
+ token = kwargs.pop("token", None)
333
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
334
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
335
+
336
+ # Only save the model itself if we are using distributed training
337
+ model_to_save = self
338
+
339
+ # Attach architecture to the config
340
+ # Save the config
341
+ if is_main_process:
342
+ model_to_save.save_config(save_directory)
343
+
344
+ # Save the model
345
+ state_dict = model_to_save.state_dict()
346
+
347
+ # Save the model
348
+ state_dict_split = split_torch_state_dict_into_shards(
349
+ state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
350
+ )
351
+
352
+ # Clean the folder from a previous save
353
+ if is_main_process:
354
+ for filename in os.listdir(save_directory):
355
+ if filename in state_dict_split.filename_to_tensors.keys():
356
+ continue
357
+ full_filename = os.path.join(save_directory, filename)
358
+ if not os.path.isfile(full_filename):
359
+ continue
360
+ weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
361
+ weights_without_ext = weights_without_ext.replace("{suffix}", "")
362
+ filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
363
+ # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
364
+ if (
365
+ filename.startswith(weights_without_ext)
366
+ and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
367
+ ):
368
+ os.remove(full_filename)
369
+
370
+ for filename, tensors in state_dict_split.filename_to_tensors.items():
371
+ shard = {tensor: state_dict[tensor] for tensor in tensors}
372
+ filepath = os.path.join(save_directory, filename)
373
+ if safe_serialization:
374
+ # At some point we will need to deal better with save_function (used for TPU and other distributed
375
+ # joyfulness), but for now this is enough.
376
+ safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
377
+ else:
378
+ torch.save(shard, filepath)
379
+
380
+ if state_dict_split.is_sharded:
381
+ index = {
382
+ "metadata": state_dict_split.metadata,
383
+ "weight_map": state_dict_split.tensor_to_filename,
384
+ }
385
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
386
+ save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
387
+ # Save the index as well
388
+ with open(save_index_file, "w", encoding="utf-8") as f:
389
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
390
+ f.write(content)
391
+ logger.info(
392
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
393
+ f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
394
+ f"index located at {save_index_file}."
395
+ )
396
+ else:
397
+ path_to_weights = os.path.join(save_directory, weights_name)
398
+ logger.info(f"Model weights saved in {path_to_weights}")
399
+
400
+ if push_to_hub:
401
+ # Create a new empty model card and eventually tag it
402
+ model_card = load_or_create_model_card(repo_id, token=token)
403
+ model_card = populate_model_card(model_card)
404
+ model_card.save(Path(save_directory, "README.md").as_posix())
405
+
406
+ self._upload_folder(
407
+ save_directory,
408
+ repo_id,
409
+ token=token,
410
+ commit_message=commit_message,
411
+ create_pr=create_pr,
412
+ )
413
+
414
+ @classmethod
415
+ @validate_hf_hub_args
416
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
417
+ r"""
418
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
419
+
420
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
421
+ train the model, set it back in training mode with `model.train()`.
422
+
423
+ Parameters:
424
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
425
+ Can be either:
426
+
427
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
428
+ the Hub.
429
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
430
+ with [`~ModelMixin.save_pretrained`].
431
+
432
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
433
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
434
+ is not used.
435
+ torch_dtype (`str` or `torch.dtype`, *optional*):
436
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
437
+ dtype is automatically derived from the model's weights.
438
+ force_download (`bool`, *optional*, defaults to `False`):
439
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
440
+ cached versions if they exist.
441
+ proxies (`Dict[str, str]`, *optional*):
442
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
443
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
444
+ output_loading_info (`bool`, *optional*, defaults to `False`):
445
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
446
+ local_files_only(`bool`, *optional*, defaults to `False`):
447
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
448
+ won't be downloaded from the Hub.
449
+ token (`str` or *bool*, *optional*):
450
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
451
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
452
+ revision (`str`, *optional*, defaults to `"main"`):
453
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
454
+ allowed by Git.
455
+ from_flax (`bool`, *optional*, defaults to `False`):
456
+ Load the model weights from a Flax checkpoint save file.
457
+ subfolder (`str`, *optional*, defaults to `""`):
458
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
459
+ mirror (`str`, *optional*):
460
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
461
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
462
+ information.
463
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
464
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
465
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
466
+ same device. Defaults to `None`, meaning that the model will be loaded on CPU.
467
+
468
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
469
+ more information about each option see [designing a device
470
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
471
+ max_memory (`Dict`, *optional*):
472
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
473
+ each GPU and the available CPU RAM if unset.
474
+ offload_folder (`str` or `os.PathLike`, *optional*):
475
+ The path to offload weights if `device_map` contains the value `"disk"`.
476
+ offload_state_dict (`bool`, *optional*):
477
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
478
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
479
+ when there is some disk offload.
480
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
481
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
482
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
483
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
484
+ argument to `True` will raise an error.
485
+ variant (`str`, *optional*):
486
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
487
+ loading `from_flax`.
488
+ use_safetensors (`bool`, *optional*, defaults to `None`):
489
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
490
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
491
+ weights. If set to `False`, `safetensors` weights are not loaded.
492
+
493
+ <Tip>
494
+
495
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
496
+ `huggingface-cli login`. You can also activate the special
497
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
498
+ firewalled environment.
499
+
500
+ </Tip>
501
+
502
+ Example:
503
+
504
+ ```py
505
+ from diffusers import UNet2DConditionModel
506
+
507
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
508
+ ```
509
+
510
+ If you get the error message below, you need to finetune the weights for your downstream task:
511
+
512
+ ```bash
513
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
514
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
515
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
516
+ ```
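As a further, hypothetical sketch of the loading arguments documented above, the snippet below combines `torch_dtype`, `variant`, and `low_cpu_mem_usage` with the same checkpoint used in the example; whether the repository actually ships an `fp16` variant is an assumption here.

```py
import torch
from diffusers import UNet2DConditionModel

# Load the UNet in half precision from the fp16 variant files (if the repo provides them),
# keeping peak CPU memory low during loading (requires `accelerate`).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
    variant="fp16",
    low_cpu_mem_usage=True,
)
```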
517
+ """
518
+ cache_dir = kwargs.pop("cache_dir", None)
519
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
520
+ force_download = kwargs.pop("force_download", False)
521
+ from_flax = kwargs.pop("from_flax", False)
522
+ proxies = kwargs.pop("proxies", None)
523
+ output_loading_info = kwargs.pop("output_loading_info", False)
524
+ local_files_only = kwargs.pop("local_files_only", None)
525
+ token = kwargs.pop("token", None)
526
+ revision = kwargs.pop("revision", None)
527
+ torch_dtype = kwargs.pop("torch_dtype", None)
528
+ subfolder = kwargs.pop("subfolder", None)
529
+ device_map = kwargs.pop("device_map", None)
530
+ max_memory = kwargs.pop("max_memory", None)
531
+ offload_folder = kwargs.pop("offload_folder", None)
532
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
533
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
534
+ variant = kwargs.pop("variant", None)
535
+ use_safetensors = kwargs.pop("use_safetensors", None)
536
+
537
+ allow_pickle = False
538
+ if use_safetensors is None:
539
+ use_safetensors = True
540
+ allow_pickle = True
541
+
542
+ if low_cpu_mem_usage and not is_accelerate_available():
543
+ low_cpu_mem_usage = False
544
+ logger.warning(
545
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
546
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
547
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
548
+ " install accelerate\n```\n."
549
+ )
550
+
551
+ if device_map is not None and not is_accelerate_available():
552
+ raise NotImplementedError(
553
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
554
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
555
+ )
556
+
557
+ # Check if we can handle device_map and dispatching the weights
558
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
559
+ raise NotImplementedError(
560
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
561
+ " `device_map=None`."
562
+ )
563
+
564
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
565
+ raise NotImplementedError(
566
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
567
+ " `low_cpu_mem_usage=False`."
568
+ )
569
+
570
+ if low_cpu_mem_usage is False and device_map is not None:
571
+ raise ValueError(
572
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
573
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
574
+ )
575
+
576
+ # change device_map into a map if we passed an int, a str or a torch.device
577
+ if isinstance(device_map, torch.device):
578
+ device_map = {"": device_map}
579
+ elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
580
+ try:
581
+ device_map = {"": torch.device(device_map)}
582
+ except RuntimeError:
583
+ raise ValueError(
584
+ "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
585
+ f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
586
+ )
587
+ elif isinstance(device_map, int):
588
+ if device_map < 0:
589
+ raise ValueError(
590
+ "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
591
+ )
592
+ else:
593
+ device_map = {"": device_map}
594
+
595
+ if device_map is not None:
596
+ if low_cpu_mem_usage is None:
597
+ low_cpu_mem_usage = True
598
+ elif not low_cpu_mem_usage:
599
+ raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
600
+
601
+ if low_cpu_mem_usage:
602
+ if device_map is not None and not is_torch_version(">=", "1.10"):
603
+ # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
604
+ raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
605
+
606
+ # Load config if we don't provide a configuration
607
+ config_path = pretrained_model_name_or_path
608
+
609
+ user_agent = {
610
+ "diffusers": __version__,
611
+ "file_type": "model",
612
+ "framework": "pytorch",
613
+ }
614
+
615
+ # load config
616
+ config, unused_kwargs, commit_hash = cls.load_config(
617
+ config_path,
618
+ cache_dir=cache_dir,
619
+ return_unused_kwargs=True,
620
+ return_commit_hash=True,
621
+ force_download=force_download,
622
+ proxies=proxies,
623
+ local_files_only=local_files_only,
624
+ token=token,
625
+ revision=revision,
626
+ subfolder=subfolder,
627
+ user_agent=user_agent,
628
+ **kwargs,
629
+ )
630
+
631
+ # Determine if we're loading from a directory of sharded checkpoints.
632
+ is_sharded = False
633
+ index_file = None
634
+ is_local = os.path.isdir(pretrained_model_name_or_path)
635
+ index_file = _fetch_index_file(
636
+ is_local=is_local,
637
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
638
+ subfolder=subfolder or "",
639
+ use_safetensors=use_safetensors,
640
+ cache_dir=cache_dir,
641
+ variant=variant,
642
+ force_download=force_download,
643
+ proxies=proxies,
644
+ local_files_only=local_files_only,
645
+ token=token,
646
+ revision=revision,
647
+ user_agent=user_agent,
648
+ commit_hash=commit_hash,
649
+ )
650
+ if index_file is not None and index_file.is_file():
651
+ is_sharded = True
652
+
653
+ if is_sharded and from_flax:
654
+ raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
655
+
656
+ # load model
657
+ model_file = None
658
+ if from_flax:
659
+ model_file = _get_model_file(
660
+ pretrained_model_name_or_path,
661
+ weights_name=FLAX_WEIGHTS_NAME,
662
+ cache_dir=cache_dir,
663
+ force_download=force_download,
664
+ proxies=proxies,
665
+ local_files_only=local_files_only,
666
+ token=token,
667
+ revision=revision,
668
+ subfolder=subfolder,
669
+ user_agent=user_agent,
670
+ commit_hash=commit_hash,
671
+ )
672
+ model = cls.from_config(config, **unused_kwargs)
673
+
674
+ # Convert the weights
675
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
676
+
677
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
678
+ else:
679
+ if is_sharded:
680
+ sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
681
+ pretrained_model_name_or_path,
682
+ index_file,
683
+ cache_dir=cache_dir,
684
+ proxies=proxies,
685
+ local_files_only=local_files_only,
686
+ token=token,
687
+ user_agent=user_agent,
688
+ revision=revision,
689
+ subfolder=subfolder or "",
690
+ )
691
+
692
+ elif use_safetensors and not is_sharded:
693
+ try:
694
+ model_file = _get_model_file(
695
+ pretrained_model_name_or_path,
696
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
697
+ cache_dir=cache_dir,
698
+ force_download=force_download,
699
+ proxies=proxies,
700
+ local_files_only=local_files_only,
701
+ token=token,
702
+ revision=revision,
703
+ subfolder=subfolder,
704
+ user_agent=user_agent,
705
+ commit_hash=commit_hash,
706
+ )
707
+
708
+ except IOError as e:
709
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
710
+ if not allow_pickle:
711
+ raise
712
+ logger.warning(
713
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
714
+ )
715
+
716
+ if model_file is None and not is_sharded:
717
+ model_file = _get_model_file(
718
+ pretrained_model_name_or_path,
719
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
720
+ cache_dir=cache_dir,
721
+ force_download=force_download,
722
+ proxies=proxies,
723
+ local_files_only=local_files_only,
724
+ token=token,
725
+ revision=revision,
726
+ subfolder=subfolder,
727
+ user_agent=user_agent,
728
+ commit_hash=commit_hash,
729
+ )
730
+
731
+ if low_cpu_mem_usage:
732
+ # Instantiate model with empty weights
733
+ with accelerate.init_empty_weights():
734
+ model = cls.from_config(config, **unused_kwargs)
735
+
736
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
737
+ if device_map is None and not is_sharded:
738
+ param_device = "cpu"
739
+ state_dict = load_state_dict(model_file, variant=variant)
740
+ model._convert_deprecated_attention_blocks(state_dict)
741
+ # move the params from meta device to cpu
742
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
743
+ if len(missing_keys) > 0:
744
+ raise ValueError(
745
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
746
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
747
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
748
+ " those weights or else make sure your checkpoint file is correct."
749
+ )
750
+
751
+ unexpected_keys = load_model_dict_into_meta(
752
+ model,
753
+ state_dict,
754
+ device=param_device,
755
+ dtype=torch_dtype,
756
+ model_name_or_path=pretrained_model_name_or_path,
757
+ )
758
+
759
+ if cls._keys_to_ignore_on_load_unexpected is not None:
760
+ for pat in cls._keys_to_ignore_on_load_unexpected:
761
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
762
+
763
+ if len(unexpected_keys) > 0:
764
+ logger.warning(
765
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
766
+ )
767
+
768
+ else: # else let accelerate handle loading and dispatching.
769
+ # Load weights and dispatch according to the device_map
770
+ # by default the device_map is None and the weights are loaded on the CPU
771
+ force_hook = True
772
+ device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
773
+ if device_map is None and is_sharded:
774
+ # we load the parameters on the cpu
775
+ device_map = {"": "cpu"}
776
+ force_hook = False
777
+ try:
778
+ accelerate.load_checkpoint_and_dispatch(
779
+ model,
780
+ model_file if not is_sharded else index_file,
781
+ device_map,
782
+ max_memory=max_memory,
783
+ offload_folder=offload_folder,
784
+ offload_state_dict=offload_state_dict,
785
+ dtype=torch_dtype,
786
+ force_hooks=force_hook,
787
+ strict=True,
788
+ )
789
+ except AttributeError as e:
790
+ # When using accelerate loading, we do not have the ability to load the state
791
+ # dict and rename the weight names manually. Additionally, accelerate skips
792
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
793
+ # (which look like they should be private variables?), so we can't use the standard hooks
794
+ # to rename parameters on load. We need to mimic the original weight names so the correct
795
+ # attributes are available. After we have loaded the weights, we convert the deprecated
796
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
797
+ # the weights so we don't have to do this again.
798
+
799
+ if "'Attention' object has no attribute" in str(e):
800
+ logger.warning(
801
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
802
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
803
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
804
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
805
+ " please also re-upload it or open a PR on the original repository."
806
+ )
807
+ model._temp_convert_self_to_deprecated_attention_blocks()
808
+ accelerate.load_checkpoint_and_dispatch(
809
+ model,
810
+ model_file if not is_sharded else index_file,
811
+ device_map,
812
+ max_memory=max_memory,
813
+ offload_folder=offload_folder,
814
+ offload_state_dict=offload_state_dict,
815
+ dtype=torch_dtype,
816
+ force_hooks=force_hook,
817
+ strict=True,
818
+ )
819
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
820
+ else:
821
+ raise e
822
+
823
+ loading_info = {
824
+ "missing_keys": [],
825
+ "unexpected_keys": [],
826
+ "mismatched_keys": [],
827
+ "error_msgs": [],
828
+ }
829
+ else:
830
+ model = cls.from_config(config, **unused_kwargs)
831
+
832
+ state_dict = load_state_dict(model_file, variant=variant)
833
+ model._convert_deprecated_attention_blocks(state_dict)
834
+
835
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
836
+ model,
837
+ state_dict,
838
+ model_file,
839
+ pretrained_model_name_or_path,
840
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
841
+ )
842
+
843
+ loading_info = {
844
+ "missing_keys": missing_keys,
845
+ "unexpected_keys": unexpected_keys,
846
+ "mismatched_keys": mismatched_keys,
847
+ "error_msgs": error_msgs,
848
+ }
849
+
850
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
851
+ raise ValueError(
852
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
853
+ )
854
+ elif torch_dtype is not None:
855
+ model = model.to(torch_dtype)
856
+
857
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
858
+
859
+ # Set model in evaluation mode to deactivate DropOut modules by default
860
+ model.eval()
861
+ if output_loading_info:
862
+ return model, loading_info
863
+
864
+ return model
865
+
866
+ @classmethod
867
+ def _load_pretrained_model(
868
+ cls,
869
+ model,
870
+ state_dict: OrderedDict,
871
+ resolved_archive_file,
872
+ pretrained_model_name_or_path: Union[str, os.PathLike],
873
+ ignore_mismatched_sizes: bool = False,
874
+ ):
875
+ # Retrieve missing & unexpected_keys
876
+ model_state_dict = model.state_dict()
877
+ loaded_keys = list(state_dict.keys())
878
+
879
+ expected_keys = list(model_state_dict.keys())
880
+
881
+ original_loaded_keys = loaded_keys
882
+
883
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
884
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
885
+
886
+ # Make sure we are able to load base models as well as derived models (with heads)
887
+ model_to_load = model
888
+
889
+ def _find_mismatched_keys(
890
+ state_dict,
891
+ model_state_dict,
892
+ loaded_keys,
893
+ ignore_mismatched_sizes,
894
+ ):
895
+ mismatched_keys = []
896
+ if ignore_mismatched_sizes:
897
+ for checkpoint_key in loaded_keys:
898
+ model_key = checkpoint_key
899
+
900
+ if (
901
+ model_key in model_state_dict
902
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
903
+ ):
904
+ mismatched_keys.append(
905
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
906
+ )
907
+ del state_dict[checkpoint_key]
908
+ return mismatched_keys
909
+
910
+ if state_dict is not None:
911
+ # Whole checkpoint
912
+ mismatched_keys = _find_mismatched_keys(
913
+ state_dict,
914
+ model_state_dict,
915
+ original_loaded_keys,
916
+ ignore_mismatched_sizes,
917
+ )
918
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
919
+
920
+ if len(error_msgs) > 0:
921
+ error_msg = "\n\t".join(error_msgs)
922
+ if "size mismatch" in error_msg:
923
+ error_msg += (
924
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
925
+ )
926
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
927
+
928
+ if len(unexpected_keys) > 0:
929
+ logger.warning(
930
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
931
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
932
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
933
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
934
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
935
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
936
+ " identical (initializing a BertForSequenceClassification model from a"
937
+ " BertForSequenceClassification model)."
938
+ )
939
+ else:
940
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
941
+ if len(missing_keys) > 0:
942
+ logger.warning(
943
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
944
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
945
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
946
+ )
947
+ elif len(mismatched_keys) == 0:
948
+ logger.info(
949
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
950
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
951
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
952
+ " without further training."
953
+ )
954
+ if len(mismatched_keys) > 0:
955
+ mismatched_warning = "\n".join(
956
+ [
957
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
958
+ for key, shape1, shape2 in mismatched_keys
959
+ ]
960
+ )
961
+ logger.warning(
962
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
963
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
964
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
965
+ " able to use it for predictions and inference."
966
+ )
967
+
968
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
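A small standalone illustration (toy key names, not taken from the diffusers source) of how the missing and unexpected key lists above fall out of plain set differences between the model's state dict and the checkpoint:

```py
# Hypothetical key sets for a model and a checkpoint.
expected_keys = {"conv_in.weight", "conv_in.bias", "time_embedding.linear_1.weight"}
loaded_keys = {"conv_in.weight", "conv_in.bias", "old_name.weight"}

missing_keys = sorted(expected_keys - loaded_keys)     # ['time_embedding.linear_1.weight']
unexpected_keys = sorted(loaded_keys - expected_keys)  # ['old_name.weight']
print(missing_keys, unexpected_keys)
```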
969
+
970
+ @classmethod
971
+ def _get_signature_keys(cls, obj):
972
+ parameters = inspect.signature(obj.__init__).parameters
973
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
974
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
975
+ expected_modules = set(required_parameters.keys()) - {"self"}
976
+
977
+ return expected_modules, optional_parameters
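For reference, a toy example (the `Toy` class is hypothetical) of the `inspect.signature` split performed by `_get_signature_keys`: required `__init__` parameters versus parameters with defaults.

```py
import inspect

class Toy:
    def __init__(self, a, b, c=1):
        pass

params = inspect.signature(Toy.__init__).parameters
required = {k for k, v in params.items() if v.default is inspect.Parameter.empty} - {"self"}
optional = {k for k, v in params.items() if v.default is not inspect.Parameter.empty}
print(required, optional)  # e.g. {'a', 'b'} {'c'} (set order may vary)
```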
978
+
979
+ # Adapted from `transformers` modeling_utils.py
980
+ def _get_no_split_modules(self, device_map: str):
981
+ """
982
+ Get the modules of the model that should not be split when using device_map. We iterate through the modules to
983
+ get the underlying `_no_split_modules`.
984
+
985
+ Args:
986
+ device_map (`str`):
987
+ The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
988
+
989
+ Returns:
990
+ `List[str]`: List of modules that should not be split
991
+ """
992
+ _no_split_modules = set()
993
+ modules_to_check = [self]
994
+ while len(modules_to_check) > 0:
995
+ module = modules_to_check.pop(-1)
996
+ # if the module does not appear in _no_split_modules, we also check the children
997
+ if module.__class__.__name__ not in _no_split_modules:
998
+ if isinstance(module, ModelMixin):
999
+ if module._no_split_modules is None:
1000
+ raise ValueError(
1001
+ f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
1002
+ "class needs to implement the `_no_split_modules` attribute."
1003
+ )
1004
+ else:
1005
+ _no_split_modules = _no_split_modules | set(module._no_split_modules)
1006
+ modules_to_check += list(module.children())
1007
+ return list(_no_split_modules)
1008
+
1009
+ @property
1010
+ def device(self) -> torch.device:
1011
+ """
1012
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1013
+ device).
1014
+ """
1015
+ return get_parameter_device(self)
1016
+
1017
+ @property
1018
+ def dtype(self) -> torch.dtype:
1019
+ """
1020
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1021
+ """
1022
+ return get_parameter_dtype(self)
1023
+
1024
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
1025
+ """
1026
+ Get number of (trainable or non-embedding) parameters in the module.
1027
+
1028
+ Args:
1029
+ only_trainable (`bool`, *optional*, defaults to `False`):
1030
+ Whether or not to return only the number of trainable parameters.
1031
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
1032
+ Whether or not to return only the number of non-embedding parameters.
1033
+
1034
+ Returns:
1035
+ `int`: The number of parameters.
1036
+
1037
+ Example:
1038
+
1039
+ ```py
1040
+ from diffusers import UNet2DConditionModel
1041
+
1042
+ model_id = "runwayml/stable-diffusion-v1-5"
1043
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
1044
+ unet.num_parameters(only_trainable=True)
1045
+ 859520964
1046
+ ```
1047
+ """
1048
+
1049
+ if exclude_embeddings:
1050
+ embedding_param_names = [
1051
+ f"{name}.weight"
1052
+ for name, module_type in self.named_modules()
1053
+ if isinstance(module_type, torch.nn.Embedding)
1054
+ ]
1055
+ non_embedding_parameters = [
1056
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
1057
+ ]
1058
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
1059
+ else:
1060
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1061
+
1062
+ def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
1063
+ deprecated_attention_block_paths = []
1064
+
1065
+ def recursive_find_attn_block(name, module):
1066
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1067
+ deprecated_attention_block_paths.append(name)
1068
+
1069
+ for sub_name, sub_module in module.named_children():
1070
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
1071
+ recursive_find_attn_block(sub_name, sub_module)
1072
+
1073
+ recursive_find_attn_block("", self)
1074
+
1075
+ # NOTE: we have to check if the deprecated parameters are in the state dict
1076
+ # because it is possible we are loading from a state dict that was already
1077
+ # converted
1078
+
1079
+ for path in deprecated_attention_block_paths:
1080
+ # group_norm path stays the same
1081
+
1082
+ # query -> to_q
1083
+ if f"{path}.query.weight" in state_dict:
1084
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
1085
+ if f"{path}.query.bias" in state_dict:
1086
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
1087
+
1088
+ # key -> to_k
1089
+ if f"{path}.key.weight" in state_dict:
1090
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
1091
+ if f"{path}.key.bias" in state_dict:
1092
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
1093
+
1094
+ # value -> to_v
1095
+ if f"{path}.value.weight" in state_dict:
1096
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
1097
+ if f"{path}.value.bias" in state_dict:
1098
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
1099
+
1100
+ # proj_attn -> to_out.0
1101
+ if f"{path}.proj_attn.weight" in state_dict:
1102
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
1103
+ if f"{path}.proj_attn.bias" in state_dict:
1104
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
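A toy walk-through (shortened, assumed key names) of the renaming performed above: deprecated `query` / `proj_attn` entries are popped and re-inserted under the `to_q` / `to_out.0` names.

```py
# Minimal stand-in for a state dict that still uses the deprecated attention names.
state_dict = {"mid_block.attn.query.weight": 1, "mid_block.attn.proj_attn.bias": 2}
path = "mid_block.attn"

if f"{path}.query.weight" in state_dict:
    state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
if f"{path}.proj_attn.bias" in state_dict:
    state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")

print(sorted(state_dict))  # ['mid_block.attn.to_out.0.bias', 'mid_block.attn.to_q.weight']
```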
1105
+
1106
+ def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1107
+ deprecated_attention_block_modules = []
1108
+
1109
+ def recursive_find_attn_block(module):
1110
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1111
+ deprecated_attention_block_modules.append(module)
1112
+
1113
+ for sub_module in module.children():
1114
+ recursive_find_attn_block(sub_module)
1115
+
1116
+ recursive_find_attn_block(self)
1117
+
1118
+ for module in deprecated_attention_block_modules:
1119
+ module.query = module.to_q
1120
+ module.key = module.to_k
1121
+ module.value = module.to_v
1122
+ module.proj_attn = module.to_out[0]
1123
+
1124
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
1125
+ # that _all_ the weights are loaded into the new attributes and we're not
1126
+ # making an incorrect assumption that this model should be converted when
1127
+ # it really shouldn't be.
1128
+ del module.to_q
1129
+ del module.to_k
1130
+ del module.to_v
1131
+ del module.to_out
1132
+
1133
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1134
+ deprecated_attention_block_modules = []
1135
+
1136
+ def recursive_find_attn_block(module) -> None:
1137
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1138
+ deprecated_attention_block_modules.append(module)
1139
+
1140
+ for sub_module in module.children():
1141
+ recursive_find_attn_block(sub_module)
1142
+
1143
+ recursive_find_attn_block(self)
1144
+
1145
+ for module in deprecated_attention_block_modules:
1146
+ module.to_q = module.query
1147
+ module.to_k = module.key
1148
+ module.to_v = module.value
1149
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
1150
+
1151
+ del module.query
1152
+ del module.key
1153
+ del module.value
1154
+ del module.proj_attn
1155
+
1156
+
1157
+ class LegacyModelMixin(ModelMixin):
1158
+ r"""
1159
+ A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
1160
+ pipeline-specific classes (like `DiTTransformer2DModel`).
1161
+ """
1162
+
1163
+ @classmethod
1164
+ @validate_hf_hub_args
1165
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
1166
+ # To prevent dependency import problem.
1167
+ from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
1168
+
1169
+ # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
1170
+ kwargs_copy = kwargs.copy()
1171
+
1172
+ cache_dir = kwargs.pop("cache_dir", None)
1173
+ force_download = kwargs.pop("force_download", False)
1174
+ proxies = kwargs.pop("proxies", None)
1175
+ local_files_only = kwargs.pop("local_files_only", None)
1176
+ token = kwargs.pop("token", None)
1177
+ revision = kwargs.pop("revision", None)
1178
+ subfolder = kwargs.pop("subfolder", None)
1179
+
1180
+ # Load config if we don't provide a configuration
1181
+ config_path = pretrained_model_name_or_path
1182
+
1183
+ user_agent = {
1184
+ "diffusers": __version__,
1185
+ "file_type": "model",
1186
+ "framework": "pytorch",
1187
+ }
1188
+
1189
+ # load config
1190
+ config, _, _ = cls.load_config(
1191
+ config_path,
1192
+ cache_dir=cache_dir,
1193
+ return_unused_kwargs=True,
1194
+ return_commit_hash=True,
1195
+ force_download=force_download,
1196
+ proxies=proxies,
1197
+ local_files_only=local_files_only,
1198
+ token=token,
1199
+ revision=revision,
1200
+ subfolder=subfolder,
1201
+ user_agent=user_agent,
1202
+ **kwargs,
1203
+ )
1204
+ # resolve remapping
1205
+ remapped_class = _fetch_remapped_cls_from_config(config, cls)
1206
+
1207
+ return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
mcp_servers/fashion_vlm/models/phi.py ADDED
@@ -0,0 +1,1489 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from packaging import version
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.modeling_attn_mask_utils import (
31
+ _prepare_4d_causal_attention_mask,
32
+ _prepare_4d_causal_attention_mask_for_sdpa,
33
+ )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.utils import (
42
+ add_code_sample_docstrings,
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ get_torch_version,
46
+ is_flash_attn_2_available,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.models.phi.configuration_phi import PhiConfig
52
+
53
+
54
+ if is_flash_attn_2_available():
55
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
56
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+
58
+
59
+ logger = logging.get_logger(__name__)
60
+
61
+ _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
62
+ _CONFIG_FOR_DOC = "PhiConfig"
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
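A worked example (toy mask) of what `_get_unpad_data` returns for a left-padded batch: the flat indices of real tokens, the cumulative sequence lengths, and the longest sequence.

```py
import torch

attention_mask = torch.tensor([[0, 1, 1, 1],
                               [0, 0, 1, 1]])  # 1 = real token, 0 = padding
indices, cu_seqlens, max_len = _get_unpad_data(attention_mask)
print(indices)     # tensor([1, 2, 3, 6, 7]), flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32), sequence lengths 3 and 2
print(max_len)     # 3
```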
76
+
77
+
78
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
79
+ class PhiRotaryEmbedding(nn.Module):
80
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
81
+ super().__init__()
82
+
83
+ self.dim = dim
84
+ self.max_position_embeddings = max_position_embeddings
85
+ self.base = base
86
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
87
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
88
+
89
+ # Build here to make `torch.jit.trace` work.
90
+ self._set_cos_sin_cache(
91
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
92
+ )
93
+
94
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
95
+ self.max_seq_len_cached = seq_len
96
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
97
+
98
+ freqs = torch.outer(t, self.inv_freq)
99
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
100
+ emb = torch.cat((freqs, freqs), dim=-1)
101
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
102
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
103
+
104
+ def forward(self, x, seq_len=None):
105
+ # x: [bs, num_attention_heads, seq_len, head_size]
106
+ if seq_len > self.max_seq_len_cached:
107
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
108
+
109
+ return (
110
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
111
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
112
+ )
113
+
114
+
115
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
116
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
117
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
118
+
119
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
120
+ self.scaling_factor = scaling_factor
121
+ super().__init__(dim, max_position_embeddings, base, device)
122
+
123
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
124
+ self.max_seq_len_cached = seq_len
125
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
126
+ t = t / self.scaling_factor
127
+
128
+ freqs = torch.outer(t, self.inv_freq)
129
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
130
+ emb = torch.cat((freqs, freqs), dim=-1)
131
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
132
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
133
+
134
+
135
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
136
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
137
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
138
+
139
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
140
+ self.scaling_factor = scaling_factor
141
+ super().__init__(dim, max_position_embeddings, base, device)
142
+
143
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
144
+ self.max_seq_len_cached = seq_len
145
+
146
+ if seq_len > self.max_position_embeddings:
147
+ base = self.base * (
148
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
149
+ ) ** (self.dim / (self.dim - 2))
150
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
151
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
152
+
153
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
154
+
155
+ freqs = torch.outer(t, self.inv_freq)
156
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
157
+ emb = torch.cat((freqs, freqs), dim=-1)
158
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
159
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
160
+
161
+
162
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
163
+ def rotate_half(x):
164
+ """Rotates half the hidden dims of the input."""
165
+ x1 = x[..., : x.shape[-1] // 2]
166
+ x2 = x[..., x.shape[-1] // 2 :]
167
+ return torch.cat((-x2, x1), dim=-1)
168
+
169
+
170
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
171
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
172
+ """Applies Rotary Position Embedding to the query and key tensors.
173
+
174
+ Args:
175
+ q (`torch.Tensor`): The query tensor.
176
+ k (`torch.Tensor`): The key tensor.
177
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
178
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
179
+ position_ids (`torch.Tensor`):
180
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
181
+ used to pass offsetted position ids when working with a KV-cache.
182
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
183
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
184
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
185
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
186
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
187
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
188
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
189
+ Returns:
190
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
191
+ """
192
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
193
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
194
+ q_embed = (q * cos) + (rotate_half(q) * sin)
195
+ k_embed = (k * cos) + (rotate_half(k) * sin)
196
+ return q_embed, k_embed
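A shape-level sanity check (toy sizes; relies on the `PhiRotaryEmbedding` and `apply_rotary_pos_emb` definitions above) showing how the cached cos/sin tables broadcast over the query and key tensors:

```py
import torch

bsz, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)

rope = PhiRotaryEmbedding(head_dim, max_position_embeddings=seq)
cos, sin = rope(q, seq_len=seq)                # each of shape [seq, head_dim]
position_ids = torch.arange(seq).unsqueeze(0)  # [1, seq]

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_embed.shape == q.shape and k_embed.shape == k.shape
```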
197
+
198
+
199
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
200
+ class PhiMLP(nn.Module):
201
+ def __init__(self, config):
202
+ super().__init__()
203
+ self.config = config
204
+ self.activation_fn = ACT2FN[config.hidden_act]
205
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
206
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
207
+
208
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
209
+ hidden_states = self.fc1(hidden_states)
210
+ hidden_states = self.activation_fn(hidden_states)
211
+ hidden_states = self.fc2(hidden_states)
212
+ return hidden_states
213
+
214
+
215
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
216
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
217
+ """
218
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
219
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
220
+ """
221
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
222
+ if n_rep == 1:
223
+ return hidden_states
224
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
225
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
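A quick equivalence check (toy tensor) for `repeat_kv` as defined above, confirming the `torch.repeat_interleave` behaviour stated in its docstring:

```py
import torch

x = torch.randn(2, 3, 5, 4)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(x, 2), torch.repeat_interleave(x, repeats=2, dim=1))
```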
226
+
227
+
228
+ class PhiAttention(nn.Module):
229
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
230
+
231
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
232
+ super().__init__()
233
+ self.config = config
234
+ self.layer_idx = layer_idx
235
+ if layer_idx is None:
236
+ logger.warning_once(
237
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
238
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
239
+ "when creating this class."
240
+ )
241
+
242
+ self.attention_dropout = config.attention_dropout
243
+ self.hidden_size = config.hidden_size
244
+ self.num_heads = config.num_attention_heads
245
+ self.head_dim = self.hidden_size // self.num_heads
246
+ self.num_key_value_heads = config.num_key_value_heads
247
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
248
+ self.max_position_embeddings = config.max_position_embeddings
249
+ self.rope_theta = config.rope_theta
250
+ self.partial_rotary_factor = config.partial_rotary_factor
251
+ self.is_causal = True
252
+
253
+ if (self.head_dim * self.num_heads) != self.hidden_size:
254
+ raise ValueError(
255
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
256
+ f" and `num_heads`: {self.num_heads})."
257
+ )
258
+
259
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
260
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
261
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
262
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
263
+
264
+ self.qk_layernorm = config.qk_layernorm
265
+ if self.qk_layernorm:
266
+ self.q_layernorm = nn.LayerNorm(
267
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
268
+ )
269
+ self.k_layernorm = nn.LayerNorm(
270
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
271
+ )
272
+
273
+ self._init_rope()
274
+
275
+ def _init_rope(self):
276
+ if self.config.rope_scaling is None:
277
+ self.rotary_emb = PhiRotaryEmbedding(
278
+ int(self.partial_rotary_factor * self.head_dim),
279
+ max_position_embeddings=self.max_position_embeddings,
280
+ base=self.rope_theta,
281
+ )
282
+ else:
283
+ scaling_type = self.config.rope_scaling["type"]
284
+ scaling_factor = self.config.rope_scaling["factor"]
285
+ if scaling_type == "linear":
286
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
287
+ int(self.partial_rotary_factor * self.head_dim),
288
+ max_position_embeddings=self.max_position_embeddings,
289
+ scaling_factor=scaling_factor,
290
+ base=self.rope_theta,
291
+ )
292
+ elif scaling_type == "dynamic":
293
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
294
+ int(self.partial_rotary_factor * self.head_dim),
295
+ max_position_embeddings=self.max_position_embeddings,
296
+ scaling_factor=scaling_factor,
297
+ base=self.rope_theta,
298
+ )
299
+ else:
300
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ attention_mask: Optional[torch.Tensor] = None,
306
+ position_ids: Optional[torch.LongTensor] = None,
307
+ past_key_value: Optional[Cache] = None,
308
+ output_attentions: bool = False,
309
+ use_cache: bool = False,
310
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ bsz, q_len, _ = hidden_states.size()
312
+
313
+ query_states = self.q_proj(hidden_states)
314
+ key_states = self.k_proj(hidden_states)
315
+ value_states = self.v_proj(hidden_states)
316
+
317
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
318
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
319
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
320
+
321
+ if self.qk_layernorm:
322
+ query_states = self.q_layernorm(query_states)
323
+ key_states = self.k_layernorm(key_states)
324
+
325
+ kv_seq_len = key_states.shape[-2]
326
+ if past_key_value is not None:
327
+ if self.layer_idx is None:
328
+ raise ValueError(
329
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
330
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
331
+ "with a layer index."
332
+ )
333
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
334
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
335
+
336
+ # Partial rotary embedding
337
+ query_rot, query_pass = (
338
+ query_states[..., : self.rotary_emb.dim],
339
+ query_states[..., self.rotary_emb.dim :],
340
+ )
341
+ key_rot, key_pass = (
342
+ key_states[..., : self.rotary_emb.dim],
343
+ key_states[..., self.rotary_emb.dim :],
344
+ )
345
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
346
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
347
+
348
+ # [batch_size, seq_length, num_heads, head_dim]
349
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
350
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
351
+
352
+ if past_key_value is not None:
353
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
354
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
355
+
356
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
357
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
358
+
359
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
360
+ attn_weights = torch.matmul(
361
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
362
+ ) / math.sqrt(self.head_dim)
363
+
364
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
365
+ raise ValueError(
366
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
367
+ f" {attn_weights.size()}"
368
+ )
369
+
370
+ if attention_mask is not None:
371
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
372
+ raise ValueError(
373
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
374
+ )
375
+ attn_weights = attn_weights + attention_mask
376
+
377
+ # upcast attention to fp32
378
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
379
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
380
+
381
+ attn_output = torch.matmul(attn_weights, value_states)
382
+
383
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
384
+ raise ValueError(
385
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
386
+ f" {attn_output.size()}"
387
+ )
388
+
389
+ attn_output = attn_output.transpose(1, 2).contiguous()
390
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
391
+
392
+ attn_output = self.dense(attn_output)
393
+
394
+ if not output_attentions:
395
+ attn_weights = None
396
+
397
+ return attn_output, attn_weights, past_key_value
398
+
399
+
400
+ class PhiFlashAttention2(PhiAttention):
401
+ """
402
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
403
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
404
+ flash attention and deal with padding tokens in case the input contains any of them.
405
+ """
406
+
407
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
408
+ def __init__(self, *args, **kwargs):
409
+ super().__init__(*args, **kwargs)
410
+
411
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
412
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
413
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
414
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states: torch.Tensor,
419
+ attention_mask: Optional[torch.LongTensor] = None,
420
+ position_ids: Optional[torch.LongTensor] = None,
421
+ past_key_value: Optional[Cache] = None,
422
+ output_attentions: bool = False,
423
+ use_cache: bool = False,
424
+ **kwargs,
425
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
426
+ # PhiFlashAttention2 attention does not support output_attentions
427
+
428
+ output_attentions = False
429
+
430
+ bsz, q_len, _ = hidden_states.size()
431
+
432
+ query_states = self.q_proj(hidden_states)
433
+ key_states = self.k_proj(hidden_states)
434
+ value_states = self.v_proj(hidden_states)
435
+
436
+ # Flash attention requires the input to have the shape
437
+ # batch_size x seq_length x head_dim x hidden_dim
438
+ # therefore we just need to keep the original shape
439
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
440
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
441
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
442
+
443
+ if self.qk_layernorm:
444
+ query_states = self.q_layernorm(query_states)
445
+ key_states = self.k_layernorm(key_states)
446
+
447
+ kv_seq_len = key_states.shape[-2]
448
+ if past_key_value is not None:
449
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
450
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
451
+
452
+ # Partial rotary embedding
453
+ query_rot, query_pass = (
454
+ query_states[..., : self.rotary_emb.dim],
455
+ query_states[..., self.rotary_emb.dim :],
456
+ )
457
+ key_rot, key_pass = (
458
+ key_states[..., : self.rotary_emb.dim],
459
+ key_states[..., self.rotary_emb.dim :],
460
+ )
461
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
462
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
463
+
464
+ # [batch_size, seq_length, num_heads, head_dim]
465
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
466
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
467
+
468
+ if past_key_value is not None:
469
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
470
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
471
+
472
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
473
+ # to be able to avoid many of these transpose/reshape/view.
474
+ query_states = query_states.transpose(1, 2)
475
+ key_states = key_states.transpose(1, 2)
476
+ value_states = value_states.transpose(1, 2)
477
+
478
+ attn_dropout = self.attention_dropout if self.training else 0.0
479
+
480
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
481
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
482
+ # cast them back to the correct dtype just to be sure everything works as expected.
483
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
484
+ # in fp32.
485
+
486
+ if query_states.dtype == torch.float32:
487
+ if torch.is_autocast_enabled():
488
+ target_dtype = torch.get_autocast_gpu_dtype()
489
+ # Handle the case where the model is quantized
490
+ elif hasattr(self.config, "_pre_quantization_dtype"):
491
+ target_dtype = self.config._pre_quantization_dtype
492
+ else:
493
+ target_dtype = self.q_proj.weight.dtype
494
+
495
+ logger.warning_once(
496
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
497
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
498
+ f" {target_dtype}."
499
+ )
500
+
501
+ query_states = query_states.to(target_dtype)
502
+ key_states = key_states.to(target_dtype)
503
+ value_states = value_states.to(target_dtype)
504
+
505
+ attn_output = self._flash_attention_forward(
506
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
507
+ )
508
+
509
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
510
+ attn_output = self.dense(attn_output)
511
+
512
+ if not output_attentions:
513
+ attn_weights = None
514
+
515
+ return attn_output, attn_weights, past_key_value
516
+
517
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
518
+ def _flash_attention_forward(
519
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
520
+ ):
521
+ """
522
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
523
+ the input is first unpadded, the attention scores are computed, and the output is padded back to the original length.
524
+
525
+ Args:
526
+ query_states (`torch.Tensor`):
527
+ Input query states to be passed to Flash Attention API
528
+ key_states (`torch.Tensor`):
529
+ Input key states to be passed to Flash Attention API
530
+ value_states (`torch.Tensor`):
531
+ Input value states to be passed to Flash Attention API
532
+ attention_mask (`torch.Tensor`):
533
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
534
+ position of padding tokens and 1 for the position of non-padding tokens.
535
+ dropout (`float`):
536
+ Attention dropout
537
+ softmax_scale (`float`, *optional*):
538
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
539
+ """
540
+ if not self._flash_attn_uses_top_left_mask:
541
+ causal = self.is_causal
542
+ else:
543
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
544
+ causal = self.is_causal and query_length != 1
545
+
546
+ # Contains at least one padding token in the sequence
547
+ if attention_mask is not None:
548
+ batch_size = query_states.shape[0]
549
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
550
+ query_states, key_states, value_states, attention_mask, query_length
551
+ )
552
+
553
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
554
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
555
+
556
+ attn_output_unpad = flash_attn_varlen_func(
557
+ query_states,
558
+ key_states,
559
+ value_states,
560
+ cu_seqlens_q=cu_seqlens_q,
561
+ cu_seqlens_k=cu_seqlens_k,
562
+ max_seqlen_q=max_seqlen_in_batch_q,
563
+ max_seqlen_k=max_seqlen_in_batch_k,
564
+ dropout_p=dropout,
565
+ softmax_scale=softmax_scale,
566
+ causal=causal,
567
+ )
568
+
569
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
570
+ else:
571
+ attn_output = flash_attn_func(
572
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
573
+ )
574
+
575
+ return attn_output
576
+
577
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
578
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
579
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
580
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
581
+
582
+ key_layer = index_first_axis(
583
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
584
+ )
585
+ value_layer = index_first_axis(
586
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
587
+ )
588
+ if query_length == kv_seq_len:
589
+ query_layer = index_first_axis(
590
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
591
+ )
592
+ cu_seqlens_q = cu_seqlens_k
593
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
594
+ indices_q = indices_k
595
+ elif query_length == 1:
596
+ max_seqlen_in_batch_q = 1
597
+ cu_seqlens_q = torch.arange(
598
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
599
+ ) # There is a memcpy here, that is very bad.
600
+ indices_q = cu_seqlens_q[:-1]
601
+ query_layer = query_layer.squeeze(1)
602
+ else:
603
+ # The -q_len: slice assumes left padding.
604
+ attention_mask = attention_mask[:, -query_length:]
605
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q,
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
616
+
617
+ class PhiSdpaAttention(PhiAttention):
618
+ """
619
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
620
+ `PhiAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
621
+ it to the SDPA API.
622
+ """
623
+
624
+ def __init__(self, *args, **kwargs):
625
+ super().__init__(*args, **kwargs)
626
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
627
+
628
+ # Adapted from PhiAttention.forward
629
+ def forward(
630
+ self,
631
+ hidden_states: torch.Tensor,
632
+ attention_mask: Optional[torch.Tensor] = None,
633
+ position_ids: Optional[torch.LongTensor] = None,
634
+ past_key_value: Optional[Cache] = None,
635
+ output_attentions: bool = False,
636
+ use_cache: bool = False,
637
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
638
+ if output_attentions:
639
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
640
+ logger.warning_once(
641
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
642
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
643
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
644
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
645
+ )
646
+ return super().forward(
647
+ hidden_states=hidden_states,
648
+ attention_mask=attention_mask,
649
+ position_ids=position_ids,
650
+ past_key_value=past_key_value,
651
+ output_attentions=output_attentions,
652
+ use_cache=use_cache,
653
+ )
654
+
655
+ bsz, q_len, _ = hidden_states.size()
656
+
657
+ query_states = self.q_proj(hidden_states)
658
+ key_states = self.k_proj(hidden_states)
659
+ value_states = self.v_proj(hidden_states)
660
+
661
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
662
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
663
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
664
+
665
+ if self.qk_layernorm:
666
+ query_states = self.q_layernorm(query_states)
667
+ key_states = self.k_layernorm(key_states)
668
+
669
+ kv_seq_len = key_states.shape[-2]
670
+ if past_key_value is not None:
671
+ if self.layer_idx is None:
672
+ raise ValueError(
673
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
674
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
675
+ "with a layer index."
676
+ )
677
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
678
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
679
+
680
+ # Partial rotary embedding
681
+ query_rot, query_pass = (
682
+ query_states[..., : self.rotary_emb.dim],
683
+ query_states[..., self.rotary_emb.dim :],
684
+ )
685
+ key_rot, key_pass = (
686
+ key_states[..., : self.rotary_emb.dim],
687
+ key_states[..., self.rotary_emb.dim :],
688
+ )
689
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
690
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
691
+
692
+ # [batch_size, seq_length, num_heads, head_dim]
693
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
694
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
695
+
696
+ if past_key_value is not None:
697
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
698
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
699
+
700
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
701
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
702
+
703
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
704
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
705
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
706
+ if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
707
+ query_states = query_states.contiguous()
708
+ key_states = key_states.contiguous()
709
+ value_states = value_states.contiguous()
710
+
711
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
712
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
713
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
714
+
715
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
716
+ query_states,
717
+ key_states,
718
+ value_states,
719
+ attn_mask=attention_mask,
720
+ dropout_p=self.attention_dropout if self.training else 0.0,
721
+ is_causal=is_causal,
722
+ )
723
+
724
+ attn_output = attn_output.transpose(1, 2).contiguous()
725
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
726
+
727
+ attn_output = self.dense(attn_output)
728
+
729
+ return attn_output, None, past_key_value
730
+
731
+
732
+ PHI_ATTENTION_CLASSES = {
733
+ "eager": PhiAttention,
734
+ "flash_attention_2": PhiFlashAttention2,
735
+ "sdpa": PhiSdpaAttention,
736
+ }
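For reference, this dispatch table is what `PhiDecoderLayer` below indexes with `config._attn_implementation`. A minimal, illustrative sketch of that selection; the `PhiConfig` import path and the config values are assumptions (the repo may vendor its own configuration class):

import torch  # noqa: F401  (torch is already imported in this module)
from transformers.models.phi.configuration_phi import PhiConfig  # assumed import path

config = PhiConfig()                      # default config, for illustration only
config._attn_implementation = "sdpa"      # one of "eager", "flash_attention_2", "sdpa"

attn_cls = PHI_ATTENTION_CLASSES[config._attn_implementation]
self_attn = attn_cls(config, layer_idx=0) # all three classes share this constructor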
737
+
738
+
739
+ class PhiDecoderLayer(nn.Module):
740
+ def __init__(self, config: PhiConfig, layer_idx: int):
741
+ super().__init__()
742
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
743
+ self.mlp = PhiMLP(config)
744
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
745
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
746
+
747
+ def forward(
748
+ self,
749
+ hidden_states: torch.Tensor,
750
+ attention_mask: Optional[torch.Tensor] = None,
751
+ position_ids: Optional[torch.LongTensor] = None,
752
+ output_attentions: Optional[bool] = False,
753
+ use_cache: Optional[bool] = False,
754
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
755
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
756
+ """
757
+ Args:
758
+ hidden_states (`torch.FloatTensor`):
759
+ input to the layer of shape `(batch, seq_len, embed_dim)`
760
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
761
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
762
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
763
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
764
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
765
+ output_attentions (`bool`, *optional*):
766
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
767
+ returned tensors for more detail.
768
+ use_cache (`bool`, *optional*):
769
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
770
+ (see `past_key_values`).
771
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
772
+ """
773
+
774
+ residual = hidden_states
775
+
776
+ hidden_states = self.input_layernorm(hidden_states)
777
+
778
+ # Self Attention
779
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
780
+ hidden_states=hidden_states,
781
+ attention_mask=attention_mask,
782
+ position_ids=position_ids,
783
+ past_key_value=past_key_value,
784
+ output_attentions=output_attentions,
785
+ use_cache=use_cache,
786
+ )
787
+ attn_outputs = self.resid_dropout(attn_outputs)
788
+
789
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
790
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
791
+ outputs = (hidden_states,)
792
+
793
+ if output_attentions:
794
+ outputs += (self_attn_weights,)
795
+
796
+ if use_cache:
797
+ outputs += (present_key_value,)
798
+
799
+ return outputs
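Note that this block is wired as a parallel residual: the attention branch and the MLP branch both read the same layer-normed input, and their outputs are added to the residual in a single step, rather than the sequential pre-norm pattern used in e.g. Llama. A stripped-down sketch of just that data flow, with toy modules standing in for `self_attn`, `mlp`, `input_layernorm` and `resid_dropout`:

import torch
import torch.nn as nn

def parallel_block(x, attn, mlp, norm, drop):
    residual = x
    h = norm(x)                            # one shared layer norm
    attn_out = drop(attn(h))               # attention branch
    mlp_out = drop(mlp(h))                 # feed-forward branch, same input as attention
    return attn_out + mlp_out + residual   # both branches join the residual at once

x = torch.randn(1, 4, 8)
y = parallel_block(x, attn=nn.Identity(), mlp=nn.Identity(),
                   norm=nn.LayerNorm(8), drop=nn.Dropout(0.0))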
800
+
801
+
802
+ PHI_START_DOCSTRING = r"""
803
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
804
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
805
+ etc.)
806
+
807
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
808
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
809
+ and behavior.
810
+
811
+ Parameters:
812
+ config ([`PhiConfig`]):
813
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
814
+ load the weights associated with the model, only the configuration. Check out the
815
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
816
+ """
817
+
818
+
819
+ @add_start_docstrings(
820
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
821
+ PHI_START_DOCSTRING,
822
+ )
823
+ class PhiPreTrainedModel(PreTrainedModel):
824
+ config_class = PhiConfig
825
+ base_model_prefix = "model"
826
+ supports_gradient_checkpointing = True
827
+ _no_split_modules = ["PhiDecoderLayer"]
828
+ _skip_keys_device_placement = "past_key_values"
829
+ _supports_flash_attn_2 = True
830
+ _supports_sdpa = True
831
+ _supports_cache_class = True
832
+
833
+ def _init_weights(self, module):
834
+ std = self.config.initializer_range
835
+ if isinstance(module, nn.Linear):
836
+ module.weight.data.normal_(mean=0.0, std=std)
837
+ if module.bias is not None:
838
+ module.bias.data.zero_()
839
+ elif isinstance(module, nn.Embedding):
840
+ module.weight.data.normal_(mean=0.0, std=std)
841
+ if module.padding_idx is not None:
842
+ module.weight.data[module.padding_idx].zero_()
843
+
844
+
845
+ PHI_INPUTS_DOCSTRING = r"""
846
+ Args:
847
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
848
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
849
+ it.
850
+
851
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
852
+ [`PreTrainedTokenizer.__call__`] for details.
853
+
854
+ [What are input IDs?](../glossary#input-ids)
855
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
856
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
857
+
858
+ - 1 for tokens that are **not masked**,
859
+ - 0 for tokens that are **masked**.
860
+
861
+ [What are attention masks?](../glossary#attention-mask)
862
+
863
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
864
+ [`PreTrainedTokenizer.__call__`] for details.
865
+
866
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
867
+ `past_key_values`).
868
+
869
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
870
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
871
+ information on the default strategy.
872
+
873
+ - 1 indicates the head is **not masked**,
874
+ - 0 indicates the head is **masked**.
875
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
876
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
877
+ config.n_positions - 1]`.
878
+
879
+ [What are position IDs?](../glossary#position-ids)
880
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
881
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
882
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
883
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
884
+
885
+ Two formats are allowed:
886
+ - a [`~cache_utils.Cache`] instance;
887
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
888
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
889
+ cache format.
890
+
891
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
892
+ legacy cache format will be returned.
893
+
894
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
895
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
896
+ of shape `(batch_size, sequence_length)`.
897
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
898
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
899
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
900
+ model's internal embedding lookup matrix.
901
+ use_cache (`bool`, *optional*):
902
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
903
+ `past_key_values`).
904
+ output_attentions (`bool`, *optional*):
905
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
906
+ tensors for more detail.
907
+ output_hidden_states (`bool`, *optional*):
908
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
909
+ more detail.
910
+ return_dict (`bool`, *optional*):
911
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
912
+ """
913
+
914
+
915
+ @add_start_docstrings(
916
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
917
+ PHI_START_DOCSTRING,
918
+ )
919
+ class PhiModel(PhiPreTrainedModel):
920
+ """
921
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
922
+
923
+ Args:
924
+ config: PhiConfig
925
+ """
926
+
927
+ def __init__(self, config: PhiConfig):
928
+ super().__init__(config)
929
+ self.padding_idx = config.pad_token_id
930
+ self.vocab_size = config.vocab_size
931
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
932
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
933
+ print("attention implementation: ", config._attn_implementation)
934
+ self.layers = nn.ModuleList(
935
+ [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
936
+ )
937
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
938
+
939
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
940
+ self._use_sdpa = config._attn_implementation == "sdpa"
941
+
942
+ self.gradient_checkpointing = False
943
+ # Initialize weights and apply final processing
944
+ self.post_init()
945
+
946
+ def get_input_embeddings(self):
947
+ return self.embed_tokens
948
+
949
+ def set_input_embeddings(self, value):
950
+ self.embed_tokens = value
951
+
952
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
953
+ def forward(
954
+ self,
955
+ input_ids: torch.LongTensor = None,
956
+ attention_mask: Optional[torch.Tensor] = None,
957
+ position_ids: Optional[torch.LongTensor] = None,
958
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
959
+ inputs_embeds: Optional[torch.FloatTensor] = None,
960
+ use_cache: Optional[bool] = None,
961
+ output_attentions: Optional[bool] = None,
962
+ output_hidden_states: Optional[bool] = None,
963
+ return_dict: Optional[bool] = None,
964
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
965
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
966
+ output_hidden_states = (
967
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
968
+ )
969
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
970
+
971
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
972
+
973
+ # retrieve input_ids and inputs_embeds
974
+ if input_ids is not None and inputs_embeds is not None:
975
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
976
+ elif input_ids is not None:
977
+ batch_size, seq_length = input_ids.shape[:2]
978
+ elif inputs_embeds is not None:
979
+ batch_size, seq_length = inputs_embeds.shape[:2]
980
+ else:
981
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
982
+
983
+ past_key_values_length = 0
984
+
985
+ if self.gradient_checkpointing and self.training:
986
+ if use_cache:
987
+ logger.warning_once(
988
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
989
+ )
990
+ use_cache = False
991
+
992
+ if use_cache:
993
+ use_legacy_cache = not isinstance(past_key_values, Cache)
994
+ if use_legacy_cache:
995
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
996
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
997
+
998
+ if position_ids is None:
999
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1000
+ position_ids = torch.arange(
1001
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1002
+ )
1003
+ position_ids = position_ids.unsqueeze(0)
1004
+
1005
+ if inputs_embeds is None:
1006
+ inputs_embeds = self.embed_tokens(input_ids)
1007
+
1008
+ inputs_embeds = self.embed_dropout(inputs_embeds)
1009
+ # commented by Xavier
1010
+ # Attention mask.
1011
+ # if self._use_flash_attention_2:
1012
+ # # 2d mask is passed through the layers
1013
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1014
+ # elif self._use_sdpa and not output_attentions:
1015
+ # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1016
+ # attention_mask,
1017
+ # (batch_size, seq_length),
1018
+ # inputs_embeds,
1019
+ # past_key_values_length,
1020
+ # )
1021
+ # else:
1022
+ # # 4d mask is passed through the layers
1023
+ # attention_mask = _prepare_4d_causal_attention_mask(
1024
+ # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1025
+ # )
1026
+ # commented by Xavier
1027
+
1028
+ hidden_states = inputs_embeds
1029
+
1030
+ # decoder layers
1031
+ all_hidden_states = () if output_hidden_states else None
1032
+ all_self_attns = () if output_attentions else None
1033
+ next_decoder_cache = None
1034
+ for decoder_layer in self.layers:
1035
+ if output_hidden_states:
1036
+ all_hidden_states += (hidden_states,)
1037
+
1038
+ if self.gradient_checkpointing and self.training:
1039
+ layer_outputs = self._gradient_checkpointing_func(
1040
+ decoder_layer.__call__,
1041
+ hidden_states,
1042
+ attention_mask,
1043
+ position_ids,
1044
+ past_key_values,
1045
+ output_attentions,
1046
+ )
1047
+ else:
1048
+ layer_outputs = decoder_layer(
1049
+ hidden_states,
1050
+ attention_mask=attention_mask,
1051
+ position_ids=position_ids,
1052
+ past_key_value=past_key_values,
1053
+ output_attentions=output_attentions,
1054
+ use_cache=use_cache,
1055
+ )
1056
+
1057
+ hidden_states = layer_outputs[0]
1058
+
1059
+ if use_cache:
1060
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1061
+
1062
+ if output_attentions:
1063
+ all_self_attns += (layer_outputs[1],)
1064
+
1065
+ hidden_states = self.final_layernorm(hidden_states)
1066
+
1067
+ # add hidden states from the last decoder layer
1068
+ if output_hidden_states:
1069
+ all_hidden_states += (hidden_states,)
1070
+
1071
+ next_cache = None
1072
+ if use_cache:
1073
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1074
+ if not return_dict:
1075
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1076
+ return BaseModelOutputWithPast(
1077
+ last_hidden_state=hidden_states,
1078
+ past_key_values=next_cache,
1079
+ hidden_states=all_hidden_states,
1080
+ attentions=all_self_attns,
1081
+ )
1082
+
1083
+
1084
+ class PhiForCausalLM(PhiPreTrainedModel):
1085
+ _tied_weights_keys = ["lm_head.weight"]
1086
+ def __init__(self, config):
1087
+ super().__init__(config)
1088
+ config.qk_layernorm = True
1089
+ config.use_cache = False
1090
+ self.model = PhiModel(config)
1091
+ self.vocab_size = config.vocab_size
1092
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
1093
+
1094
+ # Initialize weights and apply final processing
1095
+ self.post_init()
1096
+
1097
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1098
+ def get_input_embeddings(self):
1099
+ return self.model.embed_tokens
1100
+
1101
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1102
+ def set_input_embeddings(self, value):
1103
+ self.model.embed_tokens = value
1104
+
1105
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1106
+ def get_output_embeddings(self):
1107
+ return self.lm_head
1108
+
1109
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1110
+ def set_output_embeddings(self, new_embeddings):
1111
+ self.lm_head = new_embeddings
1112
+
1113
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1114
+ def set_decoder(self, decoder):
1115
+ self.model = decoder
1116
+
1117
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1118
+ def get_decoder(self):
1119
+ return self.model
1120
+
1121
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1122
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1123
+ def forward(
1124
+ self,
1125
+ input_ids: torch.LongTensor = None,
1126
+ attention_mask: Optional[torch.Tensor] = None,
1127
+ position_ids: Optional[torch.LongTensor] = None,
1128
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1129
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1130
+ labels: Optional[torch.LongTensor] = None,
1131
+ use_cache: Optional[bool] = None,
1132
+ output_attentions: Optional[bool] = None,
1133
+ output_hidden_states: Optional[bool] = None,
1134
+ return_dict: Optional[bool] = None,
1135
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1136
+ r"""
1137
+ Args:
1138
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1139
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1140
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1141
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1142
+
1143
+ Returns:
1144
+
1145
+ Example:
1146
+
1147
+ ```python
1148
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1149
+
1150
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1151
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1152
+
1153
+ >>> prompt = "This is an example script ."
1154
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1155
+
1156
+ >>> # Generate
1157
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1158
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1159
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1160
+ ```"""
1161
+
1162
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1163
+ output_hidden_states = (
1164
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1165
+ )
1166
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1167
+
1168
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1169
+ outputs = self.model(
1170
+ input_ids=input_ids,
1171
+ attention_mask=attention_mask,
1172
+ position_ids=position_ids,
1173
+ past_key_values=past_key_values,
1174
+ inputs_embeds=inputs_embeds,
1175
+ use_cache=use_cache,
1176
+ output_attentions=output_attentions,
1177
+ output_hidden_states=output_hidden_states,
1178
+ return_dict=return_dict,
1179
+ )
1180
+
1181
+ hidden_states = outputs[0]
1182
+ logits = self.lm_head(hidden_states)
1183
+ logits = logits.float()
1184
+
1185
+ loss = None
1186
+ if labels is not None:
1187
+ # Shift so that tokens < n predict n
1188
+ shift_logits = logits[..., :-1, :].contiguous()
1189
+ shift_labels = labels[..., 1:].contiguous()
1190
+ # Flatten the tokens
1191
+ loss_fct = CrossEntropyLoss()
1192
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1193
+ shift_labels = shift_labels.view(-1)
1194
+ # Enable model parallelism
1195
+ shift_labels = shift_labels.to(shift_logits.device)
1196
+ loss = loss_fct(shift_logits, shift_labels)
1197
+
1198
+ if not return_dict:
1199
+ output = (logits,) + outputs[1:]
1200
+ return (loss,) + output if loss is not None else output
1201
+
1202
+ return CausalLMOutputWithPast(
1203
+ loss=loss,
1204
+ logits=logits,
1205
+ past_key_values=outputs.past_key_values,
1206
+ hidden_states=outputs.hidden_states,
1207
+ attentions=outputs.attentions,
1208
+ )
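The loss computed above uses the standard causal-LM shift: logits at position i are scored against the label at position i+1, so the last position has no target and the first token is never predicted. A tiny self-contained illustration with toy tensors (shapes and values are made up):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)           # (batch, seq_len, vocab), toy values
labels = torch.randint(0, vocab_size, (2, 5))    # labels are the input ids themselves

shift_logits = logits[..., :-1, :].contiguous()  # positions 0..3 ...
shift_labels = labels[..., 1:].contiguous()      # ... predict tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))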
1209
+
1210
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1211
+ def prepare_inputs_for_generation(
1212
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1213
+ ):
1214
+ if past_key_values is not None:
1215
+ if isinstance(past_key_values, Cache):
1216
+ cache_length = past_key_values.get_seq_length()
1217
+ past_length = past_key_values.seen_tokens
1218
+ max_cache_length = past_key_values.get_max_length()
1219
+ else:
1220
+ cache_length = past_length = past_key_values[0][0].shape[2]
1221
+ max_cache_length = None
1222
+
1223
+ # Keep only the unprocessed tokens:
1224
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1225
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1226
+ # input)
1227
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1228
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1229
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1230
+ # input_ids based on the past_length.
1231
+ elif past_length < input_ids.shape[1]:
1232
+ input_ids = input_ids[:, past_length:]
1233
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1234
+
1235
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1236
+ if (
1237
+ max_cache_length is not None
1238
+ and attention_mask is not None
1239
+ and cache_length + input_ids.shape[1] > max_cache_length
1240
+ ):
1241
+ attention_mask = attention_mask[:, -max_cache_length:]
1242
+
1243
+ position_ids = kwargs.get("position_ids", None)
1244
+ if attention_mask is not None and position_ids is None:
1245
+ # create position_ids on the fly for batch generation
1246
+ position_ids = attention_mask.long().cumsum(-1) - 1
1247
+ position_ids.masked_fill_(attention_mask == 0, 1)
1248
+ if past_key_values:
1249
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1250
+
1251
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1252
+ if inputs_embeds is not None and past_key_values is None:
1253
+ model_inputs = {"inputs_embeds": inputs_embeds}
1254
+ else:
1255
+ model_inputs = {"input_ids": input_ids}
1256
+
1257
+ model_inputs.update(
1258
+ {
1259
+ "position_ids": position_ids,
1260
+ "past_key_values": past_key_values,
1261
+ "use_cache": kwargs.get("use_cache"),
1262
+ "attention_mask": attention_mask,
1263
+ }
1264
+ )
1265
+ return model_inputs
1266
+
1267
+ @staticmethod
1268
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1269
+ def _reorder_cache(past_key_values, beam_idx):
1270
+ reordered_past = ()
1271
+ for layer_past in past_key_values:
1272
+ reordered_past += (
1273
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1274
+ )
1275
+ return reordered_past
1276
+
1277
+
1278
+ @add_start_docstrings(
1279
+ """
1280
+ The PhiModel with a sequence classification head on top (linear layer).
1281
+
1282
+ [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1283
+ (e.g. GPT-2) do.
1284
+
1285
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1286
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1287
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1288
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1289
+ each row of the batch).
1290
+ """,
1291
+ PHI_START_DOCSTRING,
1292
+ )
1293
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
1294
+ class PhiForSequenceClassification(PhiPreTrainedModel):
1295
+ def __init__(self, config):
1296
+ super().__init__(config)
1297
+ self.num_labels = config.num_labels
1298
+ self.model = PhiModel(config)
1299
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1300
+
1301
+ # Initialize weights and apply final processing
1302
+ self.post_init()
1303
+
1304
+ def get_input_embeddings(self):
1305
+ return self.model.embed_tokens
1306
+
1307
+ def set_input_embeddings(self, value):
1308
+ self.model.embed_tokens = value
1309
+
1310
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1311
+ def forward(
1312
+ self,
1313
+ input_ids: torch.LongTensor = None,
1314
+ attention_mask: Optional[torch.Tensor] = None,
1315
+ position_ids: Optional[torch.LongTensor] = None,
1316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1317
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1318
+ labels: Optional[torch.LongTensor] = None,
1319
+ use_cache: Optional[bool] = None,
1320
+ output_attentions: Optional[bool] = None,
1321
+ output_hidden_states: Optional[bool] = None,
1322
+ return_dict: Optional[bool] = None,
1323
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1324
+ r"""
1325
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1326
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1327
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1328
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1329
+ """
1330
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1331
+
1332
+ model_outputs = self.model(
1333
+ input_ids,
1334
+ attention_mask=attention_mask,
1335
+ position_ids=position_ids,
1336
+ past_key_values=past_key_values,
1337
+ inputs_embeds=inputs_embeds,
1338
+ use_cache=use_cache,
1339
+ output_attentions=output_attentions,
1340
+ output_hidden_states=output_hidden_states,
1341
+ return_dict=return_dict,
1342
+ )
1343
+ hidden_states = model_outputs[0]
1344
+ logits = self.score(hidden_states)
1345
+
1346
+ if input_ids is not None:
1347
+ batch_size = input_ids.shape[0]
1348
+ else:
1349
+ batch_size = inputs_embeds.shape[0]
1350
+
1351
+ if self.config.pad_token_id is None and batch_size != 1:
1352
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1353
+ if self.config.pad_token_id is None:
1354
+ sequence_lengths = -1
1355
+ else:
1356
+ if input_ids is not None:
1357
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1358
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1359
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1360
+ sequence_lengths = sequence_lengths.to(logits.device)
1361
+ else:
1362
+ sequence_lengths = -1
1363
+
1364
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1365
+
1366
+ loss = None
1367
+ if labels is not None:
1368
+ labels = labels.to(logits.device)
1369
+ if self.config.problem_type is None:
1370
+ if self.num_labels == 1:
1371
+ self.config.problem_type = "regression"
1372
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1373
+ self.config.problem_type = "single_label_classification"
1374
+ else:
1375
+ self.config.problem_type = "multi_label_classification"
1376
+
1377
+ if self.config.problem_type == "regression":
1378
+ loss_fct = MSELoss()
1379
+ if self.num_labels == 1:
1380
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1381
+ else:
1382
+ loss = loss_fct(pooled_logits, labels)
1383
+ elif self.config.problem_type == "single_label_classification":
1384
+ loss_fct = CrossEntropyLoss()
1385
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1386
+ elif self.config.problem_type == "multi_label_classification":
1387
+ loss_fct = BCEWithLogitsLoss()
1388
+ loss = loss_fct(pooled_logits, labels)
1389
+ if not return_dict:
1390
+ output = (pooled_logits,) + model_outputs[1:]
1391
+ return ((loss,) + output) if loss is not None else output
1392
+
1393
+ return SequenceClassifierOutputWithPast(
1394
+ loss=loss,
1395
+ logits=pooled_logits,
1396
+ past_key_values=model_outputs.past_key_values,
1397
+ hidden_states=model_outputs.hidden_states,
1398
+ attentions=model_outputs.attentions,
1399
+ )
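The pooling above relies on an ONNX-friendly trick to find the last non-padding token of each row: `argmax` over `input_ids == pad_token_id` returns the first pad position, subtracting 1 gives the last real token, and the modulo wraps rows that contain no padding around to the final position. A toy example (pad_token_id is assumed to be 0 here):

import torch

pad_token_id = 0                                # assumption for this toy example
input_ids = torch.tensor([[5, 6, 7, 0, 0],      # three real tokens, then right padding
                          [1, 2, 3, 4, 5]])     # no padding at all

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)                         # tensor([2, 4]): last real token per row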
1400
+
1401
+
1402
+ @add_start_docstrings(
1403
+ """
1404
+ PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1405
+ Named-Entity-Recognition (NER) tasks.
1406
+ """,
1407
+ PHI_START_DOCSTRING,
1408
+ )
1409
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
1410
+ class PhiForTokenClassification(PhiPreTrainedModel):
1411
+ def __init__(self, config: PhiConfig):
1412
+ super().__init__(config)
1413
+ self.num_labels = config.num_labels
1414
+
1415
+ self.model = PhiModel(config)
1416
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1417
+ classifier_dropout = config.classifier_dropout
1418
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1419
+ classifier_dropout = config.hidden_dropout
1420
+ else:
1421
+ classifier_dropout = 0.1
1422
+ self.dropout = nn.Dropout(classifier_dropout)
1423
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1424
+
1425
+ # Initialize weights and apply final processing
1426
+ self.post_init()
1427
+
1428
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1429
+ @add_code_sample_docstrings(
1430
+ checkpoint=_CHECKPOINT_FOR_DOC,
1431
+ output_type=TokenClassifierOutput,
1432
+ config_class=_CONFIG_FOR_DOC,
1433
+ )
1434
+ def forward(
1435
+ self,
1436
+ input_ids: Optional[torch.LongTensor] = None,
1437
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1438
+ attention_mask: Optional[torch.Tensor] = None,
1439
+ inputs_embeds: Optional[torch.Tensor] = None,
1440
+ labels: Optional[torch.Tensor] = None,
1441
+ use_cache: Optional[bool] = None,
1442
+ output_attentions: Optional[bool] = None,
1443
+ output_hidden_states: Optional[bool] = None,
1444
+ return_dict: Optional[bool] = None,
1445
+ **deprecated_arguments,
1446
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1447
+ r"""
1448
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1449
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1450
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1451
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1452
+ """
1453
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1454
+
1455
+ model_outputs = self.model(
1456
+ input_ids,
1457
+ past_key_values=past_key_values,
1458
+ attention_mask=attention_mask,
1459
+ inputs_embeds=inputs_embeds,
1460
+ use_cache=use_cache,
1461
+ output_attentions=output_attentions,
1462
+ output_hidden_states=output_hidden_states,
1463
+ return_dict=return_dict,
1464
+ )
1465
+
1466
+ hidden_states = model_outputs[0]
1467
+ hidden_states = self.dropout(hidden_states)
1468
+ logits = self.classifier(hidden_states)
1469
+
1470
+ loss = None
1471
+ if labels is not None:
1472
+ # move labels to correct device to enable model parallelism
1473
+ labels = labels.to(logits.device)
1474
+ batch_size, seq_length = labels.shape
1475
+ loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(
1477
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1478
+ )
1479
+
1480
+ if not return_dict:
1481
+ output = (logits,) + model_outputs[2:]
1482
+ return ((loss,) + output) if loss is not None else output
1483
+
1484
+ return TokenClassifierOutput(
1485
+ loss=loss,
1486
+ logits=logits,
1487
+ hidden_states=model_outputs.hidden_states,
1488
+ attentions=model_outputs.attentions,
1489
+ )
mcp_servers/fashion_vlm/models/sampling.py ADDED
@@ -0,0 +1,118 @@
1
+ # Adapted from https://github.com/lucidrains/muse-maskgit-pytorch
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def log(t, eps=1e-20):
11
+ return torch.log(t.clamp(min=eps))
12
+
13
+
14
+ def gumbel_noise(t, generator=None):
15
+ noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
16
+ return -log(-log(noise))
17
+
18
+
19
+ def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
20
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)
21
+
22
+
23
+ def top_k(logits, thres=0.9):
24
+ k = math.ceil((1 - thres) * logits.shape[-1])
25
+ val, ind = logits.topk(k, dim=-1)
26
+ probs = torch.full_like(logits, float("-inf"))
27
+ probs.scatter_(2, ind, val)
28
+ return probs
29
+
30
+
31
+ def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
32
+ confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
33
+ sorted_confidence = torch.sort(confidence, dim=-1).values
34
+ cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
35
+ masking = confidence < cut_off
36
+ return masking
37
+
38
+
39
+ def cosine_schedule(t):
40
+ return torch.cos(t * math.pi * 0.5)
41
+
42
+
43
+ def linear_schedule(t):
44
+ mask_ratio = 1 - t
45
+ mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
46
+ return mask_ratio
47
+
48
+
49
+ def pow(t, method):
50
+ exponent = float(method.replace("pow", ""))
51
+ mask_ratio = 1.0 - t**exponent
52
+ mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
53
+ return mask_ratio
54
+
55
+
56
+ def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
57
+ # make sure `t` is a tensor (start, end and tau are wrapped with torch.tensor below)
58
+ t = torch.tensor(t) if not torch.is_tensor(t) else t
59
+
60
+ # A gamma function based on sigmoid function.
61
+ v_start = torch.sigmoid(torch.tensor(start / tau))
62
+ v_end = torch.sigmoid(torch.tensor(end / tau))
63
+ output = torch.sigmoid((t * (end - start) + start) / tau)
64
+ output = (v_end - output) / (v_end - v_start)
65
+ return torch.clip(output, clip_min, 1.0)
66
+
67
+
68
+ def get_mask_chedule(method, **schedule_kwargs):
69
+ if method == "cosine":
70
+ return cosine_schedule
71
+ elif method == "linear":
72
+ return linear_schedule
73
+ elif "pow" in method:
74
+ return partial(pow, method=method)
75
+ elif method == "sigmoid":
76
+ return partial(sigmoid_schedule, **schedule_kwargs)
77
+ else:
78
+ raise ValueError("Unknown schedule method: {}".format(method))
79
+
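These schedules map a progress value t in [0, 1] to the fraction of image tokens that should remain masked, and a MaskGIT-style sampler (as in the muse-maskgit-pytorch code this file is adapted from) queries them once per decoding step. A hedged sketch of that usage; the step count and token count are made-up numbers, and the real generation loop lives elsewhere in this repo:

import torch

schedule = get_mask_chedule("cosine")           # note the "chedule" spelling used in this file

num_image_tokens = 1024                          # illustrative
num_steps = 12                                   # illustrative
for step in range(num_steps):
    t = torch.tensor((step + 1) / num_steps)
    mask_ratio = schedule(t)                     # fraction of tokens still masked
    num_masked = int(mask_ratio.item() * num_image_tokens)
    # ...predict all masked tokens, keep the most confident ones, re-mask `num_masked` of them...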
80
+ def top_k_top_p_filtering(
81
+ logits: torch.Tensor,
82
+ top_k: int = 0,
83
+ top_p: float = 1.0,
84
+ filter_value: float = -float("Inf"),
85
+ min_tokens_to_keep: int = 1,
86
+ ) -> torch.Tensor:
87
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
88
+ Args:
89
+ logits: logits distribution shape (batch size, vocabulary size)
90
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
91
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
92
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
93
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
94
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
95
+ """
96
+ if top_k > 0:
97
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
98
+ # Remove all tokens with a probability less than the last token of the top-k
99
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
100
+ logits[indices_to_remove] = filter_value
101
+
102
+ if top_p < 1.0:
103
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
104
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
105
+
106
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
107
+ sorted_indices_to_remove = cumulative_probs > top_p
108
+ if min_tokens_to_keep > 1:
109
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
110
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
111
+ # Shift the indices to the right to keep also the first token above the threshold
112
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
113
+ sorted_indices_to_remove[..., 0] = 0
114
+
115
+ # scatter sorted tensors to original indexing
116
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
117
+ logits[indices_to_remove] = filter_value
118
+ return logits
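`top_k_top_p_filtering` writes `-inf` into every logit outside the kept set (in place) and returns the tensor, so the caller can sample directly from a softmax over the result. A short toy usage:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])                # (batch=1, vocab=4), toy values
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
print(filtered)                                               # [[2.0, 1.0, -inf, -inf]]
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)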
mcp_servers/fashion_vlm/prompting_utils.py ADDED
@@ -0,0 +1,628 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 NUS Show Lab.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ # TODO - SHOULD BE FURTHER IMPROVED.
18
+ class UniversalPrompting():
19
+ def __init__(self, text_tokenizer,
20
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
21
+ max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1):
22
+ """
23
+ :param text_tokenizer: original text tokenizer
24
+ """
25
+ self.text_tokenizer = text_tokenizer
26
+ self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
27
+ self.text_tokenizer.add_tokens(list(special_tokens))
28
+ self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
29
+ special_tokens}
30
+ self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
31
+ self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
32
+ self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
33
+ # plus 1 because a task token is prepended to the text later on
34
+ self.max_text_len = max_text_len + 1
35
+ self.pad_id = self.text_tokenizer.convert_tokens_to_ids('[PAD]')
36
+ self.ignore_id = ignore_id
37
+ self.cond_dropout_prob = cond_dropout_prob
38
+
39
+ def t2i_prompt(self, text_ids, image_ids, labels):
40
+
41
+ device = image_ids.device
42
+ sequence_ids = []
43
+ attention_masks = []
44
+ label_ids = []
45
+ probs = torch.rand(len(text_ids))
46
+ for i in range(len(text_ids)):
47
+
48
+ if len(text_ids[i]) == 0:
49
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
50
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
51
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
52
+
53
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
54
+
55
+ # randomly drop out the text condition
56
+ if probs[i] < self.cond_dropout_prob:
57
+ temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]
58
+
59
+ if self.max_text_len >= len(temp_ids):
60
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
61
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
62
+ else:
63
+ # should add the eos token
64
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
65
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
66
+
67
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
68
+ temp_label_ids = torch.cat([
69
+ # should we predict text tokens when doing image reconstruction?
70
+ torch.tensor(temp_ids).to(device),
71
+ self.sptids_dict['<|soi|>'].to(device),
72
+ labels[i],
73
+ self.sptids_dict['<|eoi|>'].to(device)
74
+ ], dim=0)
75
+
76
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
77
+
78
+ temp_ids = torch.cat([
79
+ torch.tensor(temp_ids).to(device),
80
+ self.sptids_dict['<|soi|>'].to(device),
81
+ image_ids[i],
82
+ self.sptids_dict['<|eoi|>'].to(device)
83
+ ], dim=0)
84
+
85
+ temp_masks = torch.tensor(temp_masks).to(device)
86
+ sequence_ids.append(temp_ids.unsqueeze(0))
87
+ attention_masks.append(temp_masks.unsqueeze(0))
88
+ label_ids.append(temp_label_ids.unsqueeze(0))
89
+
90
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
91
+
92
+ def t2i_gen_prompt(self, text_ids, image_ids):
93
+
94
+ device = image_ids.device
95
+ sequence_ids = []
96
+ attention_masks = []
97
+ for i in range(len(text_ids)):
98
+ if len(text_ids[i]) == 0:
99
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
100
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
101
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
102
+ # note that the llama3 tokenizer automatically adds the bos token at the start but not the eos token
103
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
104
+ if self.max_text_len >= len(temp_ids):
105
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
106
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
107
+ else:
108
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
109
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
110
+
111
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
112
+ temp_ids = torch.cat([
113
+ torch.tensor(temp_ids).to(device),
114
+ self.sptids_dict['<|soi|>'].to(device),
115
+ image_ids[i],
116
+ self.sptids_dict['<|eoi|>'].to(device)
117
+ ], dim=0)
118
+
119
+ temp_masks = torch.tensor(temp_masks).to(device)
120
+ sequence_ids.append(temp_ids.unsqueeze(0))
121
+ attention_masks.append(temp_masks.unsqueeze(0))
122
+
123
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
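Both t2i paths assemble the layout described in the inline comments, with the text block left-padded up to max_text_len. A purely illustrative schematic of the resulting sequence (token names only; the padding length, the text, and the number of image tokens are made up):

t2i_sequence = (
    ["[PAD]"] * 3                        # left padding up to max_text_len
    + ["<|t2i|>"]                        # task token
    + ["<|sot|>"]                        # bos of the text tokenizer
    + ["a", "red", "dress"]              # text condition tokens (toy example)
    + ["<|eot|>"]                        # eos of the text tokenizer
    + ["<|soi|>"]                        # start-of-image token
    + [f"img_{i}" for i in range(1024)]  # discrete image token ids (count is illustrative)
    + ["<|eoi|>"]                        # end-of-image token
)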
124
+
125
+ # language modeling
126
+ def lm_prompt(self, text_ids, max_seq_len):
127
+
128
+ sequence_ids = []
129
+ attention_masks = []
130
+ label_ids = []
131
+ for i in range(len(text_ids)):
132
+ if len(text_ids[i]) == 0:
133
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
134
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
135
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
136
+
137
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
138
+
139
+ if max_seq_len >= len(temp_ids):
140
+ temp_labels_ids = temp_ids + [self.ignore_id] * (max_seq_len - len(temp_ids))
141
+ temp_ids = temp_ids + [self.pad_id] * (max_seq_len - len(temp_ids))
142
+ temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
143
+ else:
144
+ # In language modeling, we only process text tokens. We do not add the eos token if the text length
145
+ # exceeds the max sequence length
146
+ temp_labels_ids = temp_ids[:max_seq_len]
147
+ temp_ids = temp_ids[:max_seq_len]
148
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
149
+
150
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
151
+ temp_ids = torch.tensor(temp_ids)
152
+ temp_masks = torch.tensor(temp_masks)
153
+ temp_labels_ids = torch.tensor(temp_labels_ids)
154
+
155
+ sequence_ids.append(temp_ids.unsqueeze(0))
156
+ attention_masks.append(temp_masks.unsqueeze(0))
157
+ label_ids.append(temp_labels_ids.unsqueeze(0))
158
+
159
+ # input_ids, masks, labels
160
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
161
+
162
+ def mmu_prompt(self, image_ids, text_ids):
163
+ device = image_ids.device
164
+ sequence_ids = []
165
+ attention_masks = []
166
+ label_ids = []
167
+ max_text_len = self.max_text_len - 1
168
+ for i in range(len(text_ids)):
169
+ # note that the llama3 tokenizer automatically adds the bos token at the start but not the eos token
170
+ # for empty list []
171
+
172
+ if len(text_ids[i]) == 0:
173
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
174
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
175
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
176
+
177
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
178
+
179
+ if max_text_len >= len(temp_ids):
180
+ # minus 1 because task token was prepended to the former image tokens
181
+ temp_ids = temp_ids + [self.pad_id] * (max_text_len - len(temp_ids))
182
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
183
+ else:
184
+ # should add the eos token
185
+ temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
186
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
187
+
188
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
189
+ temp_label_ids = torch.cat([
190
+ torch.tensor([self.ignore_id]).to(device),
191
+ torch.tensor([self.ignore_id]).to(device),
192
+ torch.ones_like(image_ids[i]) * self.ignore_id,
193
+ torch.tensor([self.ignore_id]).to(device),
194
+ torch.tensor(temp_ids).to(device),
195
+ ], dim=0)
196
+
197
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
198
+
199
+ temp_ids = torch.cat([
200
+ self.sptids_dict['<|mmu|>'].to(device), # task token
201
+ self.sptids_dict['<|soi|>'].to(device),
202
+ image_ids[i],
203
+ self.sptids_dict['<|eoi|>'].to(device),
204
+ torch.tensor(temp_ids).to(device),
205
+ ], dim=0)
206
+
207
+ temp_masks = torch.tensor(temp_masks).to(device)
208
+ sequence_ids.append(temp_ids.unsqueeze(0))
209
+ attention_masks.append(temp_masks.unsqueeze(0))
210
+ label_ids.append(temp_label_ids.unsqueeze(0))
211
+
212
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
213
+
214
+ def t2v_prompt(self, text_ids, image_ids, labels):
215
+
216
+ device = image_ids.device
217
+ sequence_ids = []
218
+ attention_masks = []
219
+ label_ids = []
220
+ probs = torch.rand(len(text_ids))
221
+ for i in range(len(text_ids)):
222
+
223
+ if len(text_ids[i]) == 0:
224
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
225
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
226
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
227
+
228
+ temp_ids = [int(self.sptids_dict['<|t2v|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
229
+
230
+ # randomly dropout text condition
231
+ if probs[i] < self.cond_dropout_prob:
232
+ temp_ids = [int(self.sptids_dict['<|t2v|>']), self.text_tokenizer.bos_token_id,
233
+ self.text_tokenizer.eos_token_id]
234
+
235
+ if self.max_text_len >= len(temp_ids):
236
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
237
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
238
+ else:
239
+ # should add the eos token
240
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
241
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
242
+
243
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
244
+ temp_label_ids = torch.cat([
245
+ # should we predict text tokens when doing image reconstruction?
246
+ torch.tensor(temp_ids).to(device),
247
+ self.sptids_dict['<|sov|>'].to(device),
248
+ labels[i],
249
+ self.sptids_dict['<|eov|>'].to(device)
250
+ ], dim=0)
251
+
252
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
253
+
254
+ temp_ids = torch.cat([
255
+ torch.tensor(temp_ids).to(device),
256
+ self.sptids_dict['<|sov|>'].to(device),
257
+ image_ids[i],
258
+ self.sptids_dict['<|eov|>'].to(device)
259
+ ], dim=0)
260
+
261
+ temp_masks = torch.tensor(temp_masks).to(device)
262
+ sequence_ids.append(temp_ids.unsqueeze(0))
263
+ attention_masks.append(temp_masks.unsqueeze(0))
264
+ label_ids.append(temp_label_ids.unsqueeze(0))
265
+
266
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
267
+
268
+ def t2v_gen_prompt(self, text_ids, image_ids):
269
+
270
+ device = image_ids.device
271
+ sequence_ids = []
272
+ attention_masks = []
273
+ for i in range(len(text_ids)):
274
+ if len(text_ids[i]) == 0:
275
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
276
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
277
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
278
+ # note: the llama3 tokenizer automatically adds the begin-of-text (bos) token but does not append the end-of-text (eos) token
279
+ temp_ids = [int(self.sptids_dict['<|t2v|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
280
+ if self.max_text_len >= len(temp_ids):
281
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
282
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
283
+ else:
284
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
285
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
286
+
287
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
288
+ temp_ids = torch.cat([
289
+ torch.tensor(temp_ids).to(device),
290
+ self.sptids_dict['<|sov|>'].to(device),
291
+ image_ids[i],
292
+ self.sptids_dict['<|eov|>'].to(device)
293
+ ], dim=0)
294
+
295
+ temp_masks = torch.tensor(temp_masks).to(device)
296
+ sequence_ids.append(temp_ids.unsqueeze(0))
297
+ attention_masks.append(temp_masks.unsqueeze(0))
298
+
299
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
300
+
301
+ def i2v_prompt(self, image_ids, video_ids):
302
+ """
303
+ :param image_ids:
304
+ :param video_ids:
305
+ :return:
306
+ """
307
+ pass
308
+
309
+ def lvg_prompt(self, text_ids, image_ids, labels):
310
+
311
+ device = image_ids.device
312
+ sequence_ids = []
313
+ attention_masks = []
314
+ label_ids = []
315
+ probs = torch.rand(len(text_ids))
316
+ probs2 = torch.rand(len(text_ids))
317
+ for i in range(len(text_ids)):
318
+
319
+ if len(text_ids[i]) == 0:
320
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
321
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
322
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
323
+
324
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
325
+
326
+ # randomly dropout text condition
327
+ if probs[i] < self.cond_dropout_prob:
328
+ temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id,
329
+ self.text_tokenizer.eos_token_id]
330
+
331
+ if self.max_text_len >= len(temp_ids):
332
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
333
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
334
+ else:
335
+ # should add the eos token
336
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
337
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
338
+
339
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
340
+ temp_label_ids = torch.cat([
341
+ # should we predict text tokens when doing image reconstruction?
342
+ torch.tensor(temp_ids).to(device),
343
+ self.sptids_dict['<|soi|>'].to(device),
344
+ labels[i],
345
+ self.sptids_dict['<|eoi|>'].to(device)
346
+ ], dim=0)
347
+
348
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
349
+
350
+ temp_ids = torch.cat([
351
+ torch.tensor(temp_ids).to(device),
352
+ self.sptids_dict['<|soi|>'].to(device),
353
+ image_ids[i],
354
+ self.sptids_dict['<|eoi|>'].to(device)
355
+ ], dim=0)
356
+
357
+ temp_masks = torch.tensor(temp_masks).to(device)
358
+ sequence_ids.append(temp_ids.unsqueeze(0))
359
+ attention_masks.append(temp_masks.unsqueeze(0))
360
+ label_ids.append(temp_label_ids.unsqueeze(0))
361
+
362
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
363
+
364
+ def lvg_gen_prompt(self, text_ids, image_ids):
365
+
366
+ device = image_ids.device
367
+ sequence_ids = []
368
+ attention_masks = []
369
+ for i in range(len(text_ids)):
370
+ if len(text_ids[i]) == 0:
371
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
372
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
373
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
374
+ # note: the llama3 tokenizer automatically adds the begin-of-text (bos) token but does not append the end-of-text (eos) token
375
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
376
+ if self.max_text_len >= len(temp_ids):
377
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
378
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
379
+ else:
380
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
381
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
382
+
383
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
384
+ temp_ids = torch.cat([
385
+ torch.tensor(temp_ids).to(device),
386
+ self.sptids_dict['<|soi|>'].to(device),
387
+ image_ids[i],
388
+ self.sptids_dict['<|eoi|>'].to(device)
389
+ ], dim=0)
390
+
391
+ temp_masks = torch.tensor(temp_masks).to(device)
392
+ sequence_ids.append(temp_ids.unsqueeze(0))
393
+ attention_masks.append(temp_masks.unsqueeze(0))
394
+
395
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
396
+
397
+ def mask_prompt(self):
398
+ pass
399
+
400
+ def __call__(self, input, task, padding=True, config=None):
401
+ """
402
+ input (tuple) : data pairs contain text(str), image(tensor), or videos(tensor).
403
+ task (str) : a flag indicates the current task.
404
+ """
405
+ if task == "t2i":
406
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
407
+ image_ids = input[1] # (B, #tokens)
408
+ sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])
409
+
410
+ elif task == "t2v":
411
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
412
+ image_ids = input[1] # (B, #tokens)
413
+ sequence_ids_with_masks = self.t2v_prompt(text_ids, image_ids, input[2])
414
+
415
+ elif task == "t2i_plus_lm":
416
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
417
+ image_ids = input[1] # (B, #tokens)
418
+ sequence_ids_with_masks = self.t2i_prompt(text_ids[:config.training.batch_size], image_ids,
419
+ input[2])
420
+ sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
421
+ return sequence_ids_with_masks, sequence_ids_with_masks_lm
422
+
423
+ elif task == "t2i_gen":
424
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
425
+ image_ids = input[1] # (B, #tokens)
426
+ sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)
427
+
428
+ elif task == "t2v_gen":
429
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
430
+ image_ids = input[1] # (B, #tokens)
431
+ sequence_ids_with_masks = self.t2v_gen_prompt(text_ids, image_ids)
432
+
433
+ elif task == "lm":
434
+ text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
435
+ sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])
436
+
437
+ elif task == "mmu":
438
+ image_ids = input[0]
439
+ text_ids = self.text_tokenizer(input[1])['input_ids']
440
+ sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)
441
+
442
+ elif task == "t2v":
443
+ text_ids = self.text_tokenizer(input[0])['input_ids']  # note: this branch is shadowed by the earlier "t2v" case above
444
+ video_ids = self.vision_tokenizer(input[1])
445
+ sequence_ids_with_masks = self.t2v_prompt(text_ids, video_ids)
446
+
447
+ elif task == "i2v":
448
+ image_ids = self.text_tokenizer(input[0])
449
+ video_ids = self.vision_tokenizer(input[1])
450
+ sequence_ids_with_masks = self.i2v_prompt(image_ids, video_ids)
451
+
452
+ elif task == "lvg":
453
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
454
+ image_ids = input[1] # (B, #tokens)
455
+ sequence_ids_with_masks = self.lvg_prompt(text_ids, image_ids, input[2])
456
+
457
+ elif task == "lvg_gen":
458
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
459
+ image_ids = input[1] # (B, #tokens)
460
+ sequence_ids_with_masks = self.lvg_gen_prompt(text_ids, image_ids)
461
+ else:
462
+ raise NotImplementedError
463
+
464
+ return sequence_ids_with_masks
465
+
466
+ def create_attention_mask_predict_next(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, rm_pad_in_image=False,
467
+ return_inverse_mask=True):
468
+ # sequence is expected to be of shape [N, L]
469
+ N, L = sequence.shape
470
+
471
+ # Masks to identify different types of tokens
472
+ is_padding = sequence == pad_id
473
+
474
+ is_start_image = sequence == soi_id
475
+
476
+ is_end_image = sequence == eoi_id
477
+
478
+ # Create cumulative sum masks to identify regions of image tokens
479
+ cumulative_start = torch.cumsum(is_start_image, dim=1)
480
+ cumulative_end = torch.cumsum(is_end_image, dim=1)
481
+ in_image_segment = (cumulative_start > cumulative_end) | is_start_image | is_end_image
482
+
483
+ is_text = ~(in_image_segment)
484
+
485
+ causal_mask = torch.tril(torch.ones((L, L), dtype=torch.bool)).to(sequence.device)
486
+
487
+ mask_text = is_text[:, :, None] * causal_mask[None, :, :]
488
+
489
+ is_text_image = is_text | in_image_segment
490
+
491
+ mask_text_image_bi = is_text_image[:, :, None] * is_text_image[:, None, :]
492
+ if rm_pad_in_image: # remove padding token in image
493
+ sid_img = torch.where(sequence == soi_id)[1]
494
+ for i in range(mask_text_image_bi.shape[0]):
495
+ pad_end_idx = torch.where(sequence[i] == pad_id)
496
+ if len(pad_end_idx[0]) != 0:
497
+ pad_end_idx = pad_end_idx[0][-1]
498
+ mask_text[i][pad_end_idx + 1:, :pad_end_idx + 1] = 0
499
+ id_padding = torch.where(is_padding[i] == True)
500
+ mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
501
+
502
+ mask_text[in_image_segment] = mask_text_image_bi[in_image_segment]
503
+ # No token attends to padding tokens and padding tokens do not attend to any token
504
+ if return_inverse_mask:
505
+ inverted_mask = 1.0 - mask_text.type(sequence.dtype)
506
+ inverted_mask = inverted_mask.masked_fill(
507
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
508
+ )
509
+ return inverted_mask.unsqueeze(1)
510
+ else:
511
+ return mask_text.unsqueeze(1)
512
+
513
+ def create_attention_mask_lvg(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, return_inverse_mask=True):
514
+ # sequence is expected to be of shape [N, L]
515
+ N, L = sequence.shape
516
+ # Masks to identify different types of tokens
517
+ is_padding = sequence == pad_id
518
+ mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device)
519
+
520
+ sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
521
+ sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
522
+ eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
523
+ for i in range(N):
524
+ id_padding = torch.where(is_padding[i] == True)
525
+ mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
526
+ for j in range(sid_img_for_bi.shape[-1]):
527
+ mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1,
528
+ sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1] = 1
529
+
530
+ # No token attends to padding tokens and padding tokens do not attend to any token
531
+ if return_inverse_mask:
532
+ inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
533
+ inverted_mask = inverted_mask.masked_fill(
534
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
535
+ )
536
+ return inverted_mask.unsqueeze(1)
537
+ else:
538
+ return mask_text_image_bi.unsqueeze(1)
539
+
540
+ # texts without attending image regions
541
+ def create_attention_mask_lvg_v2(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, sot_id=1000, eot_id=1001, return_inverse_mask=True):
542
+ # sequence is expected to be of shape [N, L]
543
+ N, L = sequence.shape
544
+ # Masks to identify different types of tokens
545
+ is_padding = sequence == pad_id
546
+ # is_text = torch.where(sequence < 2000, True, False)
547
+ is_text = torch.where(sequence < pad_id, True, False)
548
+ mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device).int()
549
+ sid_text_for_bi = torch.where(sequence == sot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
550
+ eid_text_for_bi = torch.where(sequence == eot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
551
+ # import ipdb
552
+ # ipdb.set_trace()
553
+ if sot_id == eot_id:
554
+ if sid_text_for_bi.shape[-1] % 2 != 0:
555
+ sid_text_for_bi = sid_text_for_bi[:, :-1]
556
+ eid_text_for_bi = eid_text_for_bi[:, :-1]
557
+ select_idx = [i for i in range(0, sid_text_for_bi.shape[1], 2)]
558
+ sid_text_for_bi = sid_text_for_bi[:, select_idx]
559
+ select_idx = [i+1 for i in range(0, eid_text_for_bi.shape[1], 2)]
560
+ eid_text_for_bi = eid_text_for_bi[:, select_idx]
561
+ sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
562
+ eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
563
+ all_zeros = torch.zeros_like(mask_text_image_bi).int()
564
+ for i in range(N):
565
+ all_zeros[i, :, is_text[i]] = 1
566
+ for j in range(sid_text_for_bi.shape[-1]):
567
+ all_zeros[i][is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
568
+ all_zeros[i][~is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
569
+ for j in range(sid_img_for_bi.shape[-1]):
570
+ all_zeros[i][~is_text[i], sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
571
+ mask_text_image_bi = mask_text_image_bi * all_zeros
572
+ sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
573
+
574
+ for i in range(N):
575
+ id_padding = torch.where(is_padding[i] == True)
576
+ mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
577
+ for j in range(sid_img_for_bi.shape[-1]):
578
+ mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1, sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
579
+
580
+ mask_text_image_bi[:, :, 0] = 1
581
+ # No token attends to padding tokens and padding tokens do not attend to any token
582
+ if return_inverse_mask:
583
+ inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
584
+ inverted_mask = inverted_mask.masked_fill(
585
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
586
+ )
587
+ return inverted_mask.unsqueeze(1)
588
+ else:
589
+ return mask_text_image_bi.unsqueeze(1)
590
+
591
+ def create_attention_mask_for_mmu(sequence, eoi_id=128258, return_inverse_mask=True):
592
+ N, L = sequence.shape
593
+ causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
594
+ eoi_image = torch.where(sequence == eoi_id)[1]
595
+ causal_mask[:, :, :, :eoi_image[0] + 1] = 1
596
+
597
+ if return_inverse_mask:
598
+ inverted_mask = 1.0 - causal_mask.type(sequence.dtype)
599
+ inverted_mask = inverted_mask.masked_fill(
600
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
601
+ )
602
+ return inverted_mask
603
+ else:
604
+ return causal_mask
605
+
606
+ def create_attention_mask_for_mmu_vit(
607
+ sequence,
608
+ return_inverse_mask=True,
609
+ system_prompt_len=0
610
+ ):
611
+ N, L, H = sequence.shape
612
+ causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
613
+ index = 1 + system_prompt_len + 1 + 576
614
+ # PART OF SYSTEM PROMPT SHOULD BE CAUSAL ALSO
615
+ # causal_mask[:, :, :, :index] = 1
616
+ causal_mask[:, :, :, 1+system_prompt_len+1:index] = 1  # allow full (bidirectional) attention over the image-token columns
617
+ if return_inverse_mask:
618
+ inverted_mask = 1.0 - causal_mask.type(torch.int64)
619
+ inverted_mask = inverted_mask.masked_fill(
620
+ inverted_mask.to(torch.bool), torch.iinfo(torch.int64).min
621
+ )
622
+ return inverted_mask
623
+ else:
624
+ return causal_mask
625
+
626
+
627
+ if __name__ == '__main__':
628
+ pass
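For a quick sanity check of the attention-mask helpers above, the following sketch (illustrative only) builds the additive mask for a toy sequence. It assumes the functions are importable at module level and uses the default special-token ids (pad 128256, soi 128257, eoi 128258); in practice these come from the tokenizer configuration.

import torch

# one toy sample: [pad, pad, text, text, soi, image, image, eoi]
seq = torch.tensor([[128256, 128256, 5, 6, 128257, 300, 301, 128258]])
attn = create_attention_mask_predict_next(seq, rm_pad_in_image=True)
print(attn.shape)  # (1, 1, 8, 8): 0 where attention is allowed, a large negative value where it is blocked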
mcp_servers/product_user_database.py ADDED
@@ -0,0 +1,463 @@
1
+ # mcp_servers/product_user_database.py
2
+ import os
3
+ import pickle
4
+ import numpy as np
5
+ from typing import Dict, Any, List
6
+ from dotenv import load_dotenv
7
+ from tqdm import tqdm
8
+ from itertools import combinations
9
+
10
+ from scipy import sparse
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+ from mcp.server.fastmcp import FastMCP
13
+ from mcp.server.sse import SseServerTransport
14
+ from starlette.applications import Starlette
15
+ from starlette.routing import Route, Mount
16
+ import uvicorn
17
+ import pandas as pd
18
+ import torch
19
+ from transformers import CLIPProcessor, CLIPModel
20
+ from openai import AsyncOpenAI
21
+
22
+ # Load environment variables
23
+ load_dotenv()
24
+ FASHION_DATA_ROOT = os.getenv("FASHION_DATA_ROOT", "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec")
25
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
26
+ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
27
+ openai = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
28
+
29
+ ###################################
30
+ #########Loading Model#############
31
+ ###################################
32
+ # Load CLIP model and processor
33
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
34
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", local_files_only=True)
35
+ clip_model.eval()
36
+
37
+
38
+ # Load item metadata
39
+ items_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/items_lite.parquet").set_index("item_id")
40
+ outfits_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/outfits_lite.parquet").set_index("outfit_id")
41
+ users_df = pd.read_parquet(f"{FASHION_DATA_ROOT}/meta/users_lite.parquet").set_index("user_id")
42
+ image_paths = items_df["path"].to_dict()
43
+
44
+
45
+ class InteractionDataManager:
46
+ def __init__(self, users_df, outfits_df, items_df):
47
+ """
48
+ Initialize the manager: store the data frames and build the basic lookup tables.
49
+
50
+ Args:
51
+ - users_df: user metadata DataFrame (loaded from parquet)
52
+ - outfits_df: outfit metadata DataFrame (loaded from parquet)
53
+ - items_df: item metadata DataFrame (loaded from parquet)
54
+ """
55
+ self.users_df = users_df
56
+ self.outfits_df = outfits_df
57
+ self.items_df = items_df
58
+
59
+ # Build id <-> index mappings
60
+ self.item_id_to_index = {item_id: index for index, item_id in enumerate(self.items_df.index)}
61
+ self.index_to_item_id = {index: item_id for index, item_id in enumerate(self.items_df.index)}
62
+ self.user_id_to_index = {user_id: index for index, user_id in enumerate(self.users_df.index)}
63
+ self.index_to_user_id = {index: user_id for index, user_id in enumerate(self.users_df.index)}
64
+ self.outfit_ids_dict = self.outfits_df['item_ids'].to_dict() # get outfit's item ids from outfit id
65
+ self.item_category_dict = self.items_df['category'].to_dict() # get item's category from item id
66
+ self.item_subcategory_dict = self.items_df['subcategory'].to_dict() # get item's subcategory from item id
67
+ self.n_items = len(self.items_df)
68
+ self.n_users = len(self.users_df)
69
+
70
+ self.user_outfit_pairs = []
71
+ outfit_set = set(self.outfits_df.index)
72
+ for uid, user in self.users_df.iterrows():
73
+ oids = user.outfit_ids.split(",")
74
+ self.user_outfit_pairs.extend([(uid, oid) for oid in oids if oid in outfit_set])
75
+
76
+ # Precompute the subcategory -> item-id mapping (via groupby)
77
+ self.subcategory_to_items = self.items_df.groupby('subcategory').apply(lambda x: set(x.index)).to_dict()
78
+
79
+ # Precompute the subcategory -> item-index mapping for faster lookups
80
+ self.subcategory_to_indices = {}
81
+ for subcategory, item_ids in self.subcategory_to_items.items():
82
+ self.subcategory_to_indices[subcategory] = set([self.item_id_to_index[item_id]
83
+ for item_id in item_ids
84
+ if item_id in self.item_id_to_index])
85
+
86
+ item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/item_matrix.npz'
87
+ try:
88
+ self.load_matrix('item', item_interaction_matrix_path)
89
+ except FileNotFoundError:
90
+ self.build_item_interaction_matrix()
91
+ self.save_matrix('item', item_interaction_matrix_path)
92
+
93
+ user_item_interaction_matrix_path = f'{FASHION_DATA_ROOT}/data/personalized_recommendation/temp_matrix/user_item_matrix.npz'
94
+ try:
95
+ self.load_matrix('user_item', user_item_interaction_matrix_path)
96
+ except FileNotFoundError:
97
+ self.build_user_item_interaction_matrix()
98
+ self.save_matrix('user_item', user_item_interaction_matrix_path)
99
+
100
+ # Load precomputed item CLIP features
101
+ with open(f"{FASHION_DATA_ROOT}/meta/clip_features.pkl", "rb") as f:
102
+ print("Loading Fashion Features...")
103
+ self.clip_features = pickle.load(f)
104
+ print("Loading Fashion Features Successfully")
105
+
106
+ # Prepare embeddings and item IDs
107
+ self.item_ids = list(self.clip_features.keys())
108
+ self.image_embeddings = np.array([self.clip_features[item_id]["image_embeds"] for item_id in self.item_ids])  # use self.item_ids; a bare item_ids is undefined here
109
+
110
+ def save_matrix(self, matrix_type, filepath):
111
+ """
112
+ Save an interaction matrix to disk.
114
+
115
+ Args:
116
+ - matrix_type: 'item' or 'user_item', which matrix to save
117
+ - filepath: destination path (e.g. 'temp/item_matrix.npz')
117
+ """
118
+ if matrix_type == 'item':
119
+ matrix = self.item_interaction_matrix
120
+ elif matrix_type == 'user_item':
121
+ matrix = self.user_item_interaction_matrix
122
+ else:
123
+ raise ValueError("matrix_type must be 'item' or 'user_item'")
124
+
125
+ if matrix is None:
126
+ raise ValueError(f"{matrix_type} matrix has not been built yet.")
127
+
128
+ sparse.save_npz(filepath, matrix)
129
+ print(f"Saved {matrix_type} matrix to {filepath}")
130
+
131
+ def load_matrix(self, matrix_type, filepath):
132
+ """
133
+ Load an interaction matrix from disk.
135
+
136
+ Args:
137
+ - matrix_type: 'item' or 'user_item', which matrix to load
138
+ - filepath: source path (e.g. 'temp/item_matrix.npz')
138
+ """
139
+ if not os.path.exists(filepath):
140
+ raise FileNotFoundError(f"File {filepath} does not exist.")
141
+
142
+ matrix = sparse.load_npz(filepath)
143
+ if matrix_type == 'item':
144
+ self.item_interaction_matrix = matrix
145
+ elif matrix_type == 'user_item':
146
+ self.user_item_interaction_matrix = matrix
147
+ else:
148
+ raise ValueError("matrix_type must be 'item' or 'user_item'")
149
+
150
+ print(f"Loaded {matrix_type} matrix from {filepath}")
151
+ return matrix
152
+
153
+ def build_item_interaction_matrix(self):
154
+ """Build the item-item co-occurrence matrix from outfit compositions."""
156
+ # Initialize the item-item interaction matrix
156
+ self.item_interaction_matrix = sparse.lil_matrix((self.n_items, self.n_items), dtype=int)
157
+
158
+ for index, outfit in tqdm(self.outfits_df.iterrows(), total=len(self.outfits_df)):
159
+ item_ids = outfit['item_ids'].split(',')
160
+ # Record co-occurrence counts for each item pair
161
+ for item_id1, item_id2 in combinations(item_ids, r=2):
162
+ if item_id1 in self.item_id_to_index and item_id2 in self.item_id_to_index:
163
+ idx1 = self.item_id_to_index[item_id1]
164
+ idx2 = self.item_id_to_index[item_id2]
165
+ self.item_interaction_matrix[idx1, idx2] += 1
166
+ self.item_interaction_matrix[idx2, idx1] += 1  # keep the matrix symmetric (unordered pairs)
167
+
168
+ # Convert to CSR format for fast row access
169
+ self.item_interaction_matrix = self.item_interaction_matrix.tocsr()
170
+ return self.item_interaction_matrix
171
+
172
+ def build_user_item_interaction_matrix(self):
173
+ """Build the user-item interaction matrix from users' outfit histories."""
175
+ # Initialize the user-item interaction matrix
175
+ self.user_item_interaction_matrix = sparse.lil_matrix((self.n_users, self.n_items), dtype=int)
176
+
177
+ for uid, user in tqdm(self.users_df.iterrows(), total=len(self.users_df)):
178
+ oids = user["outfit_ids"].split(",")
179
+ outfits = self.outfits_df.loc[self.outfits_df.index.isin(oids)]
180
+ for oid, outfit in outfits.iterrows():
181
+ item_ids = outfit['item_ids'].split(',')
182
+ # Record each user-item occurrence
183
+ for iid in item_ids:
184
+ if iid in self.item_id_to_index:
185
+ uidx = self.user_id_to_index[uid]
186
+ iidx = self.item_id_to_index[iid]
187
+ self.user_item_interaction_matrix[uidx, iidx] += 1
188
+
189
+ # Convert to CSR format for fast row access
190
+ self.user_item_interaction_matrix = self.user_item_interaction_matrix.tocsr()
191
+ return self.user_item_interaction_matrix
192
+
193
+ def _process_interactions_for_category(
194
+ self,
195
+ matrix,
196
+ given_id,
197
+ category_indices,
198
+ id_to_index
199
+ ):
200
+ """
201
+ Collect the interactions between a single entity and the target category.
202
+
203
+ Args:
204
+ - matrix: interaction matrix to query
205
+ - given_id: id of the given entity (user or item)
206
+ - category_indices: set of item indices that belong to the target category
207
+
208
+ Returns:
209
+ - a list of interactions, each a dict with item_id, interaction_count and score
210
+ """
211
+ interactions = []
212
+
213
+ given_index = id_to_index[given_id]
214
+ row = matrix[given_index]
215
+
216
+ # Extract the non-zero entries of this row
217
+ row_start = row.indptr[0]
218
+ row_end = row.indptr[1]
219
+ col_indices = row.indices[row_start:row_end]
220
+ data_values = row.data[row_start:row_end]
221
+
222
+ # Keep only items that belong to the target category
223
+ for col_idx, value in zip(col_indices, data_values):
224
+ # Check whether this item is in the target category
225
+ if col_idx in category_indices:
226
+ # Look up the item id
227
+ output_id = self.index_to_item_id[col_idx]
228
+ interactions.append({
229
+ 'item_id': output_id,
230
+ 'interaction_count': int(value),
231
+ 'score': 0.0
232
+ })
233
+
234
+ return interactions
235
+
236
+ def get_item_category_interactions(
237
+ self,
238
+ target_category: str,
239
+ given_ids: List[str],
240
+ query_type='item', # item or user
241
+ top_k=None,
242
+ ):
243
+ """
244
+ Get all interactions between the given entities (users or items) and the target category.
246
+
247
+ Args:
248
+ - target_category: the subcategory to query
249
+ - given_ids: list of entity ids (item ids or user ids, depending on query_type)
250
+ - query_type: type of the given entities, 'item' or 'user'
251
+ - top_k: return only the k items with the most interactions; return everything if None
252
+
253
+ Returns:
254
+ - a list of interaction statistics with the target category, sorted by interaction count
254
+ """
255
+ if query_type == 'item':
256
+ matrix = self.item_interaction_matrix
257
+ id_to_index = self.item_id_to_index
258
+ elif query_type == 'user':
259
+ matrix = self.user_item_interaction_matrix
260
+ id_to_index = self.user_id_to_index
261
+ else:
262
+ print(f'query_type must be either item or user but got {query_type}')
263
+ return []
264
+
265
+ # Collect all interaction records
266
+ all_interactions = []
267
+ category = target_category
268
+ category_indices = self.subcategory_to_indices.get(category, set())  # all item indices in this category
269
+
270
+ # Gather this entity's interactions with the category
271
+ for given_id in given_ids:
272
+ interactions = self._process_interactions_for_category(
273
+ matrix, given_id, category_indices, id_to_index
274
+ )
275
+ # Append the interactions to the result list
276
+ all_interactions.extend(interactions)
277
+
278
+ # Merge interaction counts for duplicate items
279
+ item_interactions = {}
280
+ for interaction in all_interactions:
281
+ item_id = interaction['item_id']
282
+ count = interaction['interaction_count']
283
+
284
+ if item_id in item_interactions:
285
+ item_interactions[item_id] += count
286
+ else:
287
+ item_interactions[item_id] = count
288
+
289
+ # Convert to the output format
290
+ merged_interactions = [
291
+ {'item_id': item_id, 'interaction_count': count, 'score': 0.0}
292
+ for item_id, count in item_interactions.items()
293
+ ]
294
+
295
+ # Sort by interaction count (descending)
296
+ if merged_interactions:
297
+ merged_interactions.sort(key=lambda x: x['interaction_count'], reverse=True)
298
+
299
+ # Truncate to the top-k items
300
+ if top_k and merged_interactions:
301
+ merged_interactions = merged_interactions[:top_k]
302
+
303
+ # Return the merged results
304
+ return merged_interactions
305
+
306
+ def rank_by_similarity(self, item_interactions, user_interactions, beta=2.0):
307
+ """
308
+ Score the user-interacted items by similarity to the item-interaction profile and sort them.
309
+ """
310
+
311
+ def get_combined_features(feature_dict):
312
+ return (feature_dict['image_embeds'] + feature_dict['text_embeds']) / 2
313
+
314
+ item_feature_list = []
315
+ for item in item_interactions:
316
+ item_id = item['item_id']
317
+ if item_id not in self.clip_features:
318
+ raise ValueError(f"Didn't find clip feature of item with id: {item_id}")
319
+
320
+ item_features = get_combined_features(self.clip_features[item_id])
321
+ item_feature_list.append(item_features)
322
+
323
+ weights = np.array([x['interaction_count'] for x in item_interactions], dtype=np.float32)
324
+ weights = weights / np.sum(weights)
325
+ item_feature = np.sum(np.stack(item_feature_list, axis=0) * weights[:, np.newaxis], axis=0).reshape(1, -1)
326
+
327
+ max_count = max((user_item.get('interaction_count', 1) for user_item in user_interactions), default=1)
328
+ for user_item in user_interactions:
329
+ user_item_id = user_item['item_id']
330
+ if user_item_id not in self.clip_features:
331
+ raise ValueError(f"Didn't find clip feature of item with id: {user_item_id}")
332
+
333
+ user_item_features = get_combined_features(self.clip_features[user_item_id]).reshape(1, -1)
334
+ similarity = cosine_similarity(user_item_features, item_feature).item()
335
+ interaction_count = user_item['interaction_count']
336
+ count_factor = (interaction_count / max_count) * beta + 1
337
+ user_item['score'] = float(similarity) * count_factor
338
+
339
+ user_interactions.sort(key=lambda x: x.get('score', 0), reverse=True)
340
+ return user_interactions
341
+
342
+
343
+ data_manager = InteractionDataManager(users_df, outfits_df, items_df)
344
+ mcp = FastMCP('image-retrieval-server')
345
+
346
+
347
+ @mcp.tool()
348
+ async def summary_user_history(user_id: str, target_category: str, list_of_items: List[str]) -> str:
349
+ """Summarize the user's buying history for a specific fashion category, given user_id, target_category and list_of_items.
350
+ After collecting the user's buying history, we summarize the descriptions of these historical items with an LLM
351
+ and return the user's preference for target_category in a sentence.
352
+
353
+ Args:
354
+ user_id (str): User id. Will be provided through prompt
355
+ target_category (str): We care about user's buying history of this specific category.
356
+ list_of_items: List of item ids for history filtering. Will be provided through prompt
357
+ """
358
+ # We need to find the most appropriate item to become the target item
359
+ # It should have enough relationship with user and other items
360
+ # Specifically, item_interaction larger than 3, history larger than 10
361
+ item_interaction_result = data_manager.get_item_category_interactions(
362
+ target_category, list_of_items, query_type='item'
363
+ )
364
+ user_interaction_result = data_manager.get_item_category_interactions(
365
+ target_category, [user_id], query_type='user'
366
+ )
367
+
368
+ def get_description(item_id: str) -> str:
369
+ return data_manager.items_df.loc[item_id].gen_description
370
+
371
+ descriptions_for_summary = []
372
+ if len(item_interaction_result) == 0:
373
+ descriptions_for_summary = [get_description(x['item_id']) for x in user_interaction_result]
374
+ else:
375
+ if len(user_interaction_result) > 0:  # '>= 0' was always true; only re-rank when the user has history
376
+ user_interaction_result = data_manager.rank_by_similarity(
377
+ item_interaction_result,
378
+ user_interaction_result
379
+ )
380
+ descriptions_for_summary = [get_description(x['item_id']) for x in user_interaction_result[:5]]
381
+
382
+ if descriptions_for_summary:
383
+ user_message = f"Summarize the user's preference for {target_category} based on the following descriptions of fashion items the user bought previously:"
384
+ for x in descriptions_for_summary:
385
+ user_message += f"\n{x}"
386
+ # Get summary using OpenAI API call
387
+ response = await openai.chat.completions.create(
388
+ model="gpt-4o-mini",
389
+ messages=[
390
+ {"role": "system", "content": "You are a user preference summary assistant. Your response is limited to one sentence, starting with 'I prefer ...'"},
391
+ {"role": "user", "content": user_message}
392
+ ],
393
+ max_tokens=1000,
394
+ )
395
+ return response.choices[0].message.content
396
+ else:
397
+ return ""
398
+
399
+
400
+ # Leftover smoke test, kept commented out (it looked up this user's historical interactions for the
401
+ # category given a partial outfit). Note: summary_user_history is an async MCP tool, so calling it at
402
+ # module import time like this would only create an un-awaited coroutine.
403
+ # user_id = "115"
+ # partial_outfit = ["25479e5dacebbfaed18a7dc4830bd5cd19114486", "becc7b46236e9abb6f6760e7a1569b06bbc236c1",
404
+ #                   "180c32b5c8c164f3c632f3e73d6002ccfa6fea57"]
405
+ # target_category = "Skirts"
+ # summary_user_history(user_id, target_category, partial_outfit)
406
+
407
+
408
+ async def compute_text_embedding(text: str) -> np.ndarray:
409
+ inputs = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True)
410
+ with torch.no_grad():
411
+ text_embedding = clip_model.get_text_features(**inputs).numpy()
412
+ return text_embedding / np.linalg.norm(text_embedding, axis=1, keepdims=True)
413
+
414
+
415
+ async def find_most_similar_image(text_embedding: np.ndarray) -> Dict[str, Any]:
416
+ similarities = np.dot(data_manager.image_embeddings, text_embedding.T).flatten()
417
+ most_similar_idx = np.argmax(similarities)
418
+ most_similar_item_id = data_manager.item_ids[most_similar_idx]
419
+ return {
420
+ "image_path": image_paths[most_similar_item_id],
421
+ "similarity": float(similarities[most_similar_idx])
422
+ }
423
+
424
+
425
+ @mcp.tool()
426
+ async def retrieve_image(text: str) -> Dict[str, Any]:
427
+ """Search for the most similar fashion image based on a text description.
428
+
429
+ Args:
430
+ text (str): Text description of the fashion item to search.
431
+ """
432
+ print(f"Searching for {text}")
433
+ text_embedding = await compute_text_embedding(text)
434
+ return await find_most_similar_image(text_embedding)
435
+
436
+
437
+ mcp_server = mcp._mcp_server  # access the underlying low-level Server object
438
+ sse_transport = SseServerTransport("/messages/")
439
+
440
+
441
+ async def handle_sse(request):
442
+ print("Handling SSE connection")
443
+ async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
444
+ read_stream, write_stream = streams
445
+ await mcp_server.run(
446
+ read_stream,
447
+ write_stream,
448
+ mcp_server.create_initialization_options(),
449
+ )
450
+
451
+ # Define the routes
452
+ routes = [
453
+ Route("/sse", endpoint=handle_sse),
454
+ Mount("/messages/", app=sse_transport.handle_post_message),
455
+ ]
456
+
457
+ # Create the Starlette application
458
+ starlette_app = Starlette(routes=routes)
459
+
460
+
461
+ if __name__ == "__main__":
462
+ print("Starting Image Retrieval server with HTTP and SSE...")
463
+ uvicorn.run(starlette_app, host="0.0.0.0", port=8001)  # port 8001 avoids clashing with the FashionVLM service
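For reference, a rough usage sketch of the interaction lookup defined above (not part of the server). It assumes the FashionRec parquet/pickle files under FASHION_DATA_ROOT are available, that importing the module (which loads CLIP and the cached matrices) is acceptable, and that user id "115" and the "Skirts" subcategory are just placeholders.

from mcp_servers.product_user_database import data_manager

# Skirts that user "115" has interacted with, most frequent first
hits = data_manager.get_item_category_interactions("Skirts", ["115"], query_type="user", top_k=5)
for hit in hits:
    print(hit["item_id"], hit["interaction_count"])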
mcp_servers/virtual_try_on.py ADDED
@@ -0,0 +1,94 @@
1
+ import os
2
+ import pathlib
3
+ import uuid
4
+
5
+ from dotenv import load_dotenv
6
+ import PIL
7
+ from google import genai
8
+ from google.genai import types
9
+ from mcp.server.fastmcp import FastMCP
10
+
11
+
12
+ load_dotenv()
13
+
14
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
15
+ GEN_IMG_DIR = os.getenv("GEN_IMG_DIR")
16
+ # os.environ["HTTP_PROXY"] = "http://127.0.0.1:10809"
17
+ # os.environ["HTTPS_PROXY"] = "http://127.0.0.1:10809"
18
+
19
+ client = genai.Client(api_key=GEMINI_API_KEY)
20
+ mcp = FastMCP("virtual_try_on")
21
+
22
+
23
+ async def save_image(response, path):
24
+ for part in response.candidates[0].content.parts:
25
+ if part.text is not None:
26
+ continue
27
+ elif part.inline_data is not None:
28
+ mime = part.inline_data.mime_type
29
+ data = part.inline_data.data
30
+ pathlib.Path(path).write_bytes(data)
31
+
32
+
33
+ @mcp.tool()
34
+ async def try_on(image_path: str) -> str:
35
+ """Generate a virtual try-on image based on image path and return the saved file path.
36
+
37
+ Args:
38
+ image_path (str): Path to the input image file for try-on generation
39
+
40
+ Returns:
41
+ str: File path of the generated image
42
+ """
43
+ try:
44
+ print(image_path)
45
+ response = client.models.generate_content(
46
+ model="models/gemini-2.0-flash-exp",
47
+ contents=[
48
+ "You are a virtual try on tool. Put all the clothes uploaded on a real person and create a picture. Only clothings should be put on, excluding shoes or bags or accessories.",
49
+ PIL.Image.open(os.path.abspath(image_path))
50
+ ],
51
+ config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
52
+ )
53
+ gen_img_filename = f'{GEN_IMG_DIR}/{uuid.uuid4().hex}.png'
54
+ await save_image(response, gen_img_filename)
55
+ return os.path.abspath(gen_img_filename)
56
+ except Exception as e:
57
+ print(e)
58
+ return image_path
59
+
60
+
61
+ def main():
62
+ print("Started MCP server 'virtual_try_on'...")
63
+ mcp.run(transport='stdio')
64
+
65
+
66
+ if __name__ == "__main__":
67
+ # image_path = "/mnt/d/PostDoc/fifth paper/code/FashionVLM/datasets/FashionRec/data/basic_recommendation/train/temp/0000000_target.jpg"
68
+ # print(image_path)
69
+ #
70
+ # def save_image(response, path):
71
+ # for part in response.candidates[0].content.parts:
72
+ # if part.inline_data is not None:
73
+ # data = part.inline_data.data
74
+ # pathlib.Path(path).write_bytes(data)
75
+ #
76
+ #
77
+ # client = genai.Client(api_key="<GEMINI_API_KEY>")  # hard-coded key redacted; load it from the environment instead
78
+ #
79
+ # response = client.models.generate_content(
80
+ # model="models/gemini-2.0-flash-exp",
81
+ # contents=[
82
+ # "You are a virtual try-on tool. Put this whole outfit on a model and output one full-body image.",
83
+ # PIL.Image.open(image_path)
84
+ # ],
85
+ # config=types.GenerateContentConfig(response_modalities=['Text', 'Image'])
86
+ # )
87
+ #
88
+ # for part in response.candidates[0].content.parts:
89
+ # if part.text is not None:
90
+ # print(part.text)
91
+ #
92
+ # save_image(response, 'edited_image3.png')
93
+
94
+ main()
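As a quick local check (a sketch, not part of the server), the try_on coroutine above could be driven directly with asyncio. This assumes the FastMCP tool decorator returns the original coroutine function, that GEMINI_API_KEY and GEN_IMG_DIR are set, and that the input path below (hypothetical) points to a real image.

import asyncio
from mcp_servers.virtual_try_on import try_on

# Returns the absolute path of the generated image, or the input path if generation failed
result_path = asyncio.run(try_on("generated_images/sample_outfit_grid.png"))
print(result_path)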
requirements.txt ADDED
@@ -0,0 +1,228 @@
1
+ accelerate==0.21.0
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ albumentations==0.3.2
5
+ annotated-types==0.7.0
6
+ antlr4-python3-runtime==4.9.3
7
+ anykeystore==0.2
8
+ asn1crypto==1.5.1
9
+ asttokens==2.4.1
10
+ async-timeout==4.0.3
11
+ attrs==21.2.0
12
+ bidict==0.23.1
13
+ blessed==1.20.0
14
+ boto3==1.34.113
15
+ botocore==1.34.113
16
+ braceexpand==0.1.7
17
+ cachetools==5.3.3
18
+ certifi==2024.2.2
19
+ cffi==1.16.0
20
+ chardet==5.2.0
21
+ charset-normalizer==3.3.2
22
+ click==8.1.7
23
+ clip==0.2.0
24
+ clip-openai==1.0.post20230121
25
+ cmake==3.29.3
26
+ cramjam==2.8.3
27
+ crcmod==1.7
28
+ cryptacular==1.6.2
29
+ cryptography==39.0.2
30
+ cycler==0.12.1
31
+ datasets==2.2.1
32
+ diffusers==0.30.1
33
+ decorator==5.1.1
34
+ decord==0.6.0
35
+ deepspeed==0.14.2
36
+ defusedxml==0.7.1
37
+ Deprecated==1.2.14
38
+ descartes==1.1.0
39
+ dill==0.3.8
40
+ distlib==0.3.8
41
+ distro-info==1.0
42
+ dnspython==2.6.1
43
+ docker-pycreds==0.4.0
44
+ docstring_parser==0.16
45
+ ecdsa==0.19.0
46
+ einops==0.6.0
47
+ exceptiongroup==1.2.1
48
+ executing==2.0.1
49
+ fairscale==0.4.13
50
+ fastparquet==2024.5.0
51
+ ffmpegcv==0.3.13
52
+ filelock==3.14.0
53
+ fire==0.6.0
54
+ fonttools==4.51.0
55
+ frozenlist==1.4.1
56
+ fsspec==2023.6.0
57
+ ftfy==6.2.0
58
+ gitdb==4.0.11
59
+ GitPython==3.1.43
60
+ gpustat==1.1.1
61
+ greenlet==3.0.3
62
+ grpcio==1.64.0
63
+ h11==0.14.0
64
+ hjson==3.1.0
65
+ huggingface-hub==0.23.2
66
+ hupper==1.12.1
67
+ idna==3.7
68
+ imageio==2.34.1
69
+ imgaug==0.2.6
70
+ iniconfig==2.0.0
71
+ ipaddress==1.0.23
72
+ ipdb==0.13.13
73
+ ipython==8.18.1
74
+ jaxtyping==0.2.28
75
+ jedi==0.19.1
76
+ Jinja2==3.1.4
77
+ jmespath==1.0.1
78
+ joblib==1.4.2
79
+ jsonargparse==4.14.1
80
+ jsonlines==4.0.0
81
+ kiwisolver==1.4.5
82
+ kornia==0.7.2
83
+ kornia_rs==0.1.3
84
+ lazy_loader==0.4
85
+ lightning==2.2.3
86
+ lightning-utilities==0.11.2
87
+ lit==18.1.6
88
+ MarkupSafe==2.1.5
89
+ matplotlib==3.5.3
90
+ matplotlib-inline==0.1.7
91
+ miscreant==0.3.0
92
+ mpmath==1.3.0
93
+ msgpack==1.0.8
94
+ multidict==6.0.5
95
+ multiprocess==0.70.16
96
+ natsort==8.4.0
97
+ networkx==3.2.1
98
+ ninja==1.11.1.1
99
+ numpy==1.24.4
100
+ nuscenes-devkit==1.1.11
101
+ oauthlib==3.2.2
102
+ omegaconf==2.3.0
103
+ open-clip-torch==2.24.0
104
+ openai-clip
105
+ opencv-python==4.9.0.80
106
+ opencv-python-headless==3.4.18.65
107
+ packaging==22.0
108
+ pandas==1.5.3
109
+ parquet==1.3.1
110
+ parso==0.8.4
111
+ PasteDeploy==3.1.0
112
+ pathlib2==2.3.7.post1
113
+ pathtools==0.1.2
114
+ pbkdf2==1.3
115
+ pexpect==4.9.0
116
+ pillow==10.3.0
117
+ plaster==1.1.2
118
+ plaster-pastedeploy==1.0.1
119
+ platformdirs==4.2.2
120
+ plotly==5.22.0
121
+ pluggy==1.5.0
122
+ ply==3.11
123
+ promise==2.3
124
+ prompt-toolkit==3.0.43
125
+ protobuf==3.20.3
126
+ psutil==5.9.8
127
+ ptyprocess==0.7.0
128
+ pure-eval==0.2.2
129
+ py==1.11.0
130
+ py-cpuinfo==9.0.0
131
+ py-spy==0.3.14
132
+ pyarrow==11.0.0
133
+ pyarrow-hotfix==0.6
134
+ pyasn1==0.6.0
135
+ pycocotools==2.0.7
136
+ pycparser==2.22
137
+ pycryptodomex==3.20.0
138
+ pycurl==7.43.0.6
139
+ pydantic==1.10.15
140
+ pydantic_core==2.18.3
141
+ Pygments==2.18.0
142
+ PyJWT==2.8.0
143
+ pynvml==11.5.0
144
+ pyope==0.2.2
145
+ pyOpenSSL==23.2.0
146
+ pyparsing==3.1.2
147
+ pyquaternion==0.9.9
148
+ pyramid==2.0.2
149
+ pyramid-mailer==0.15.1
150
+ pytest==6.2.5
151
+ python-consul==1.1.0
152
+ python-dateutil==2.9.0.post0
153
+ python-engineio==4.9.1
154
+ python-etcd==0.4.5
155
+ python-jose==3.3.0
156
+ python-socketio==5.11.2
157
+ python3-openid==3.2.0
158
+ pytorch-extension==0.2
159
+ pytorch-lightning==2.2.3
160
+ pytz==2024.1
161
+ PyYAML==6.0.1
162
+ regex==2024.5.15
163
+ repoze.sendmail==4.4.1
164
+ requests==2.31.0
165
+ requests-oauthlib==2.0.0
166
+ rsa==4.9
167
+ s3transfer==0.10.1
168
+ safetensors==0.4.3
169
+ schedule==1.2.2
170
+ scikit-image==0.22.0
171
+ scikit-learn==1.5.0
172
+ scipy==1.13.1
173
+ sentencepiece==0.2.0
174
+ sentry-sdk==2.3.1
175
+ setproctitle==1.3.3
176
+ Shapely==1.8.5.post1
177
+ shortuuid==1.0.13
178
+ simple-websocket==1.0.0
179
+ six==1.16.0
180
+ smmap==5.0.1
181
+ SQLAlchemy==2.0.30
182
+ stack-data==0.6.3
183
+ sympy==1.12
184
+ taming-transformers-rom1504==0.0.6
185
+ tenacity==8.3.0
186
+ tensorboardX==2.6.2.2
187
+ termcolor==2.4.0
188
+ threadpoolctl==3.5.0
189
+ thriftpy2==0.5.0
190
+ tifffile==2024.5.22
191
+ timm==1.0.3
192
+ tokenizers==0.19.1
193
+ toml==0.10.2
194
+ tomli==2.0.1
195
+ torch==2.2.1
196
+ torch-fidelity==0.3.0
197
+ torchmetrics==1.4.0.post0
198
+ torchvision==0.17.1
199
+ tox==3.28.0
200
+ tqdm==4.66.4
201
+ traitlets==5.14.3
202
+ transaction==4.0
203
+ transformers==4.41.1
204
+ translationstring==1.4
205
+ triton==2.2.0
206
+ typeguard==2.13.3
207
+ typing_extensions==4.12.0
208
+ tzdata==2024.1
209
+ urllib3==1.26.18
210
+ velruse==1.1.1
211
+ venusian==3.1.0
212
+ virtualenv==20.26.2
213
+ wandb==0.17.2
214
+ watchdog==4.0.1
215
+ wcwidth==0.2.13
216
+ webdataset==0.2.86
217
+ WebOb==1.8.7
218
+ websocket-client==1.8.0
219
+ wrapt==1.16.0
220
+ wsproto==1.2.0
221
+ WTForms==3.1.2
222
+ wtforms-recaptcha==0.3.2
223
+ xformers==0.0.25
224
+ xxhash==3.4.1
225
+ yarl==1.9.4
226
+ zope.deprecation==5.0
227
+ zope.interface==6.4.post2
228
+ zope.sqlalchemy==3.1
system_message.py ADDED
@@ -0,0 +1,9 @@
1
+ SYSTEM_MESSAGE = """
2
+ You are a fashion assistant. Use the weather tool to get weather alerts.
3
+
4
+ Use the retrieve_image tool when the user asks you to find a product in the database from a description.
5
+ Use the image_generate tool to generate a fashion item image from a description. The description must be in English!
6
+ Use the fashion_recommend_without_image tool to generate recommendations when no image paths are included.
7
+ Use the fashion_recommend tool for queries that include uploaded images; recommend according to the query and the uploaded images.
8
+ Use the try_on tool when the user asks to try on uploaded images. The clothing image paths will be provided in the prompt.
9
+ """
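For context, SYSTEM_MESSAGE is intended to seed the chat history before tool calling; a minimal sketch follows (openai_client, tool_schemas and the user text are hypothetical placeholders, not defined in this repo).

from system_message import SYSTEM_MESSAGE

messages = [
    {"role": "system", "content": SYSTEM_MESSAGE},
    {"role": "user", "content": "Find me a red summer dress from the database"},
]
# response = await openai_client.chat.completions.create(
#     model="gpt-4o-mini", messages=messages, tools=tool_schemas)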
utils.py ADDED
@@ -0,0 +1,74 @@
1
+ from PIL import Image
2
+ from typing import List
3
+
4
+
5
+ def create_image_grid(image_paths: List[str], output_path: str, grid_size: int = 2) -> None:
6
+ """
7
+ Merge up to four images into a 2x2 grid. Missing cells are left blank, transparent backgrounds are filled with white, and images are resized with their aspect ratio preserved.
8
+
9
+ Args:
10
+ - image_paths: list of input image paths
11
+ - output_path: path of the output grid image
12
+ - grid_size: grid size (default 2x2)
13
+ """
14
+ images = []
15
+ target_size = (256, 256)  # target size of each grid cell
16
+
17
+ for path in image_paths[:4]:
18
+ try:
19
+ # Load the image, preserving transparency
20
+ img = Image.open(path).convert('RGBA')
21
+
22
+ # If the image has an alpha channel, composite it over a white background
23
+ if img.mode == 'RGBA':
24
+ background = Image.new('RGBA', img.size, (255, 255, 255, 255))  # white background
25
+ img = Image.alpha_composite(background, img)
26
+
27
+ # Convert to RGB
28
+ img = img.convert('RGB')
29
+ original_width, original_height = img.size
30
+ aspect_ratio = original_width / original_height
31
+ if original_width >= original_height:
32
+ # width is the longer side: scale width to 256
33
+ new_width = 256
34
+ new_height = int(256 / aspect_ratio)
35
+ else:
36
+ # height is the longer side: scale height to 256
37
+ new_height = 256
38
+ new_width = int(256 * aspect_ratio)
39
+
40
+ # Aspect-preserving resize
41
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
42
+ # img.thumbnail(target_size, Image.Resampling.LANCZOS)  # alternative high-quality downscaling
43
+
44
+ # Create a 256x256 blank canvas (white background)
45
+ canvas = Image.new('RGB', target_size, (255, 255, 255))
46
+
47
+ # Compute the centered paste offset
48
+ offset_x = (target_size[0] - img.size[0]) // 2
49
+ offset_y = (target_size[1] - img.size[1]) // 2
50
+
51
+ # Paste the resized image centered on the canvas
52
+ canvas.paste(img, (offset_x, offset_y))
53
+
54
+ images.append(canvas)
55
+ except Exception as e:
56
+ print(f"Error loading image {path}: {e}")
57
+ images.append(None)
58
+
59
+ # Pad with None if there are fewer than 4 images
60
+ while len(images) < 4:
61
+ images.append(None)
62
+
63
+ # Create the 512x512 blank grid canvas
64
+ grid_image = Image.new('RGB', (512, 512), (255, 255, 255))  # white background
65
+
66
+ # Paste the tiles in a 2x2 layout
67
+ for idx, img in enumerate(images):
68
+ if img is not None:
69
+ x = (idx % 2) * 256
70
+ y = (idx // 2) * 256
71
+ grid_image.paste(img, (x, y))
72
+
73
+ # Save the grid as JPEG
74
+ grid_image.save(output_path, quality=95)  # quality=95 to avoid heavy compression artifacts
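A minimal usage sketch of create_image_grid (the paths below are hypothetical; with fewer than four inputs the remaining cells stay white):

from utils import create_image_grid

create_image_grid(
    ["top.jpg", "skirt.png", "shoes.jpg"],
    "generated_images/outfit_grid.jpg",
)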