cleaned up

- README.md (+53 -39)
- __pycache__/ssllm_hf.cpython-310.pyc (+0 -0)
- generate.py (+3 -14)
- ssllm_hf.py (+2 -192)
README.md
CHANGED

````diff
@@ -80,52 +80,66 @@ SSLLM is a 218M parameter decoder-only transformer language model created for te
 from ssllm_hf import SSLLMForCausalLM, SSLLMConfig
 import tiktoken
 import torch
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
 
-#
-config = SSLLMConfig.from_pretrained('ssllm_hf')
-model = SSLLMForCausalLM
+# Initialize model with config
+config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
+model = SSLLMForCausalLM(config)
 
-#
-prompt = "The future of artificial intelligence will"
-input_ids = torch.tensor([tokenizer.encode(prompt)])
-
-with torch.no_grad():
-    outputs = model.generate(
-        input_ids,
-        max_new_tokens=128,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-        repetition_penalty=1.2,
-        no_repeat_ngram_size=4,
-        pad_token_id=100257,
-        eos_token_id=100257,
-    )
-
-generated_text = tokenizer.decode(outputs[0].tolist())
-print(generated_text)
-```
-
-## Performance
+# Download and load model weights
+model_path = hf_hub_download(repo_id='sausheong/ssllm_hf', filename='model.safetensors')
+state_dict = load_file(model_path)
+model.load_state_dict(state_dict, strict=False)
+
+# Setup device and eval mode
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device).eval()
+
+# Initialize tokenizer
+tokenizer = tiktoken.get_encoding('cl100k_base')
+
+def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=40):
+    # Encode the prompt
+    input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
+    attention_mask = torch.ones_like(input_ids)
+
+    # Generate with the model
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=100257,
+            eos_token_id=100257,
+        )
+
+    # Decode only the new tokens
+    new_tokens = outputs[0][input_ids.shape[1]:].tolist()
+    generated = tokenizer.decode(new_tokens)
+
+    print(f"\n{prompt}{generated}")
+    print(f"\nTokens generated: {len(new_tokens)}")
+
+if __name__ == "__main__":
+    prompt = "In a small village nestled between mountains,"
+
+    # Test different generation settings
+    print(f"PROMPT: {prompt}")
+    generate_text(prompt)
+```
 
+## Example Outputs
 
 **Prompt:** "In a small village nestled between mountains,"
 
-**Output:** "In a small village nestled between mountains, lived two
+**Output:** "In a small village nestled between mountains, lived two curious friends named Sam and Alex. They were always curious and loved learning new things. One day, while exploring the woods near the riverbank, they stumbled upon a mysterious object. It was a tiny, glowing object with a glowing light.
+
+Sam explained that it had a special kind of light that could change how the light behaves. He told them that the light was made up of different colors and patterns, making it an even better way to see clearly. This made Sam and Alex curious."
 
 ## Limitations
@@ -135,7 +149,7 @@ The model produces coherent, contextually relevant text with:
 - **Tokenizer:** Requires tiktoken library (not standard HuggingFace tokenizer)
 - **Special Tokens:** Limited special token vocabulary
 
-##
+## Considerations
 
 - Model outputs should be reviewed for potential biases
 - Not suitable for generating harmful or inappropriate content
@@ -155,7 +169,7 @@ The model produces coherent, contextually relevant text with:
 
 - **Framework:** PyTorch
 - **HuggingFace Transformers:** Compatible with generation utilities
-- **vLLM:** Requires GPT-2 format conversion
+- **vLLM:** No (Requires GPT-2 format conversion)
 - **ONNX:** Not currently supported
 - **TensorFlow:** Not supported
````
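The new README snippet hardcodes 100257 for both `pad_token_id` and `eos_token_id` and relies on a vocabulary size of 100277. These values can be checked directly against tiktoken; the snippet below is an editor's sketch for verification, not part of the commit:

```python
# Sanity check (sketch, not from the repo): confirm the hardcoded special-token
# IDs against the cl100k_base encoding the README uses.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
print(enc.n_vocab)    # 100277 -- the model's vocab_size
print(enc.eot_token)  # 100257 -- the <|endoftext|> id reused as eos/pad above
print(enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))  # [100257]
```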
__pycache__/ssllm_hf.cpython-310.pyc
ADDED

Binary file (6.25 kB).
generate.py
CHANGED

````diff
@@ -21,10 +21,6 @@ model = model.to(device).eval()
 tokenizer = tiktoken.get_encoding('cl100k_base')
 
 def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=40):
-    print(f"\n{'='*60}")
-    print(f"Temperature: {temperature}, Top-p: {top_p}, Top-k: {top_k}")
-    print(f"{'='*60}")
-
     # Encode the prompt
     input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
     attention_mask = torch.ones_like(input_ids)
@@ -45,21 +41,14 @@ def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=
 
     # Decode only the new tokens
     new_tokens = outputs[0][input_ids.shape[1]:].tolist()
-    try:
-        # Try normal decoding first
-        generated = tokenizer.decode(new_tokens)
-    except:
-        # Fallback to per-token decoding if there are any invalid tokens
-        generated = ''.join(tokenizer.decode([t]) for t in new_tokens if t < tokenizer.n_vocab)
+    generated = tokenizer.decode(new_tokens)
 
     print(f"\n{prompt}{generated}")
     print(f"\nTokens generated: {len(new_tokens)}")
 
 if __name__ == "__main__":
-    prompt = "
+    prompt = "In a small village nestled between mountains,"
 
     # Test different generation settings
     print(f"PROMPT: {prompt}")
-    generate_text(prompt
-    generate_text(prompt, temperature=0.7, top_p=0.9, top_k=40) # Balanced
-    generate_text(prompt, temperature=0.3, top_p=0.8, top_k=20) # Focused
+    generate_text(prompt)
````
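The surviving decode step keeps only `outputs[0][input_ids.shape[1]:]` because HuggingFace `generate()` returns the prompt tokens followed by the continuation in one sequence. A tiny illustration with made-up token IDs, not taken from the repo:

```python
# Illustration only (token IDs are invented): generate() echoes the prompt,
# so slicing at the prompt length isolates the newly generated tokens.
import torch

input_ids = torch.tensor([[100, 200, 300]])          # 3 prompt tokens
outputs = torch.tensor([[100, 200, 300, 400, 500]])  # shape of a generate() result
new_tokens = outputs[0][input_ids.shape[1]:].tolist()
print(new_tokens)  # [400, 500]
```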
ssllm_hf.py
CHANGED

````diff
@@ -1,16 +1,11 @@
-#!/usr/bin/env python3
 """
-
-that exactly matches SSLLM architecture but is compatible with HuggingFace.
+A custom model for causal language modeling, compatible with HuggingFace.
 """
 
-import os
-import json
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-import tiktoken
 
 class SSLLMConfig(PretrainedConfig):
     """Configuration class for SSLLM model compatible with HuggingFace."""
@@ -203,189 +198,4 @@ class SSLLMForCausalLM(PreTrainedModel, GenerationMixin):
 
     def set_input_embeddings(self, new_embeddings):
         """Set input embeddings."""
-        self.token_embed = new_embeddings
-
-def load_ssllm_checkpoint(checkpoint_path):
-    """Load the SSLLM checkpoint and extract model state and config."""
-    print(f"Loading SSLLM checkpoint from: {checkpoint_path}")
-    checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
-    if 'model_state_dict' in checkpoint:
-        state_dict = checkpoint['model_state_dict']
-        config = checkpoint.get('config', {})
-    else:
-        state_dict = checkpoint
-        config = {}
-
-    print(f"Loaded checkpoint with {len(state_dict)} parameters")
-
-    # Print parameter shapes for verification
-    total_params = sum(p.numel() for p in state_dict.values())
-    print(f"Total parameters in checkpoint: {total_params:,}")
-
-    return state_dict, config
-
-def convert_ssllm_to_hf(ssllm_state_dict, ssllm_config):
-    """Convert SSLLM state dict to HuggingFace format with exact parameter preservation."""
-
-    # Extract configuration
-    vocab_size = ssllm_config.get('vocab_size', 100277)
-    d_model = ssllm_config.get('d_model', 768)
-    num_heads = ssllm_config.get('num_heads', 12)
-    num_layers = ssllm_config.get('num_layers', 10)
-    d_ff = ssllm_config.get('d_ff', 2560)
-    max_seq_len = ssllm_config.get('max_seq_len', 1024)
-
-    print(f"Model config: vocab_size={vocab_size}, d_model={d_model}, num_heads={num_heads}")
-    print(f"  num_layers={num_layers}, d_ff={d_ff}, max_seq_len={max_seq_len}")
-
-    # Create SSLLM configuration
-    config = SSLLMConfig(
-        vocab_size=vocab_size,
-        d_model=d_model,
-        num_heads=num_heads,
-        num_layers=num_layers,
-        d_ff=d_ff,
-        max_seq_len=max_seq_len,
-        dropout_rate=0.1,
-        attention_dropout=0.1,
-        stochastic_depth_rate=0.1,
-        bos_token_id=100256,
-        eos_token_id=100257,
-        pad_token_id=100257,
-    )
-
-    # Create SSLLM model
-    model = SSLLMForCausalLM(config)
-
-    print(f"Created SSLLM model with {sum(p.numel() for p in model.parameters())} parameters")
-
-    # Load state dict directly (should be exact match)
-    missing_keys, unexpected_keys = model.load_state_dict(ssllm_state_dict, strict=False)
-
-    if missing_keys:
-        print(f"Missing keys: {missing_keys}")
-    if unexpected_keys:
-        print(f"Unexpected keys: {unexpected_keys}")
-
-    # Verify parameter count
-    model_params = sum(p.numel() for p in model.parameters())
-    checkpoint_params = sum(p.numel() for p in ssllm_state_dict.values())
-
-    print(f"Parameter verification:")
-    print(f"  Model parameters: {model_params:,}")
-    print(f"  Checkpoint parameters: {checkpoint_params:,}")
-    print(f"  Match: {'✅' if model_params == checkpoint_params else '❌'}")
-
-    return model, config
-
-def save_hf_model(model, config, output_dir):
-    """Save the converted model in HuggingFace format."""
-    os.makedirs(output_dir, exist_ok=True)
-
-    # Save model and config
-    model.save_pretrained(output_dir)
-
-    # Create tokenizer config for cl100k_base
-    tokenizer_config = {
-        "tokenizer_class": "tiktoken",
-        "model_name": "cl100k_base",
-        "vocab_size": 100277,
-        "bos_token": "",
-        "eos_token": "",
-        "pad_token": "",
-        "unk_token": "",
-        "mask_token": "",
-        "additional_special_tokens": []
-    }
-
-    with open(os.path.join(output_dir, 'tokenizer_config.json'), 'w') as f:
-        json.dump(tokenizer_config, f, indent=2)
-
-    # Create generation config
-    generation_config = {
-        "bos_token_id": 100256,
-        "eos_token_id": 100257,
-        "pad_token_id": 100257,
-        "max_length": 1024,
-        "do_sample": True,
-        "temperature": 0.7,
-        "top_p": 0.9,
-        "repetition_penalty": 1.1,
-        "no_repeat_ngram_size": 3
-    }
-
-    with open(os.path.join(output_dir, 'generation_config.json'), 'w') as f:
-        json.dump(generation_config, f, indent=2)
-
-    # Create tokenizer info
-    with open(os.path.join(output_dir, 'tokenizer_info.txt'), 'w') as f:
-        f.write("Tokenizer: cl100k_base (tiktoken)\n")
-        f.write("Vocabulary size: 100277\n")
-        f.write("BOS token ID: 100256\n")
-        f.write("EOS token ID: 100257\n")
-        f.write("PAD token ID: 100257\n")
-
-    print(f"Model saved to: {output_dir}")
-    print("Files created:")
-    print("  - pytorch_model.bin (model weights)")
-    print("  - config.json (model configuration)")
-    print("  - tokenizer_config.json (tokenizer configuration)")
-    print("  - generation_config.json (generation parameters)")
-    print("  - tokenizer_info.txt (tokenizer metadata)")
-
-def main():
-    """Main conversion function."""
-    import argparse
-
-    parser = argparse.ArgumentParser(description='Convert SSLLM checkpoint to HuggingFace format')
-    parser.add_argument('--input', type=str, default='ssllm.pth',
-                        help='Path to SSLLM checkpoint file (default: ssllm.pth)')
-    parser.add_argument('--output', type=str, default='ssllm_hf',
-                        help='Output directory for HuggingFace model (default: ssllm_hf)')
-
-    args = parser.parse_args()
-
-    if not os.path.exists(args.input):
-        print(f"Error: Input checkpoint file '{args.input}' not found")
-        return
-
-    print("=" * 60)
-    print("SSLLM TO HUGGINGFACE CONVERSION")
-    print("=" * 60)
-
-    # Load SSLLM checkpoint
-    ssllm_state_dict, ssllm_config = load_ssllm_checkpoint(args.input)
-
-    # Convert to HuggingFace format with exact parameter preservation
-    model, config = convert_ssllm_to_hf(ssllm_state_dict, ssllm_config)
-
-    # Save in HuggingFace format
-    save_hf_model(model, config, args.output)
-
-    print("=" * 60)
-    print("CONVERSION COMPLETED SUCCESSFULLY!")
-    print("=" * 60)
-    print(f"Your model is now available at: {args.output}")
-    print("\nTo use with HuggingFace transformers:")
-    print("```python")
-    print("from transformers import AutoModel, AutoConfig")
-    print("import tiktoken")
-    print("")
-    print("# Load model")
-    print(f"model = AutoModel.from_pretrained('{args.output}', trust_remote_code=True)")
-    print("")
-    print("# Load tokenizer (tiktoken)")
-    print("tokenizer = tiktoken.get_encoding('cl100k_base')")
-    print("")
-    print("# Generate text")
-    print("input_text = 'Once upon a time'")
-    print("input_ids = torch.tensor([tokenizer.encode(input_text)])")
-    print("with torch.no_grad():")
-    print("    outputs = model.generate(input_ids, max_length=100, do_sample=True, temperature=0.7)")
-    print("generated_text = tokenizer.decode(outputs[0].tolist())")
-    print("print(generated_text)")
-    print("```")
-
-if __name__ == "__main__":
-    main()
+        self.token_embed = new_embeddings
````
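With the conversion utilities deleted, the parameter verification that `convert_ssllm_to_hf()` printed is gone, but the same key-mismatch report is still available at load time: `load_state_dict(strict=False)` returns the missing and unexpected keys. The sketch below reuses the loading steps from the updated README and is an editor's illustration, not part of the commit:

```python
# Sketch (not from this commit): recover the key-mismatch report that the
# deleted conversion script used to print, using the README's loading steps.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from ssllm_hf import SSLLMForCausalLM, SSLLMConfig

config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
model = SSLLMForCausalLM(config)
path = hf_hub_download(repo_id='sausheong/ssllm_hf', filename='model.safetensors')
missing, unexpected = model.load_state_dict(load_file(path), strict=False)
print(f"Missing keys: {missing}")
print(f"Unexpected keys: {unexpected}")
```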