To use the model, download the weights and save the following script in the same directory as the weights file:
import torch
import torch.nn as nn

# Hyperparameters (these must match the values used in train.py)
VOCAB_SIZE = 256   # character-level vocabulary: one token per character code (0-255)
EMBED_DIM = 64
NUM_HEADS = 4
NUM_LAYERS = 2
SEQ_LEN = 32       # maximum context length
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MeowGPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.pos_emb = nn.Embedding(SEQ_LEN, EMBED_DIM)
        layer = nn.TransformerEncoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS, dim_feedforward=EMBED_DIM * 4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=NUM_LAYERS)
        self.head = nn.Linear(EMBED_DIM, VOCAB_SIZE)

    def forward(self, x):
        B, T = x.shape
        positions = torch.arange(T, device=x.device)
        x = self.token_emb(x) + self.pos_emb(positions)
        # Causal mask: each position may only attend to itself and earlier positions.
        mask = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1)
        x = self.transformer(x, mask=mask)
        return self.head(x)
def load_bot():
    model = MeowGPT().to(DEVICE)
    try:
        model.load_state_dict(torch.load("meow_model.bin", map_location=DEVICE))
        print("Loaded meow_model.bin successfully!")
    except FileNotFoundError:
        print("Error: meow_model.bin not found. Run train.py first!")
        raise SystemExit(1)
    model.eval()  # inference mode: disables dropout etc.
    return model
def chat():
    model = load_bot()
    print("\n--- MeowGPT is listening (type 'exit' to stop) ---")
    while True:
        text = input("You: ")
        if text.lower() in ['exit', 'quit']:
            break
        # Prepare input: encode each character as its code point,
        # dropping anything outside the 0-255 vocabulary.
        input_ids = [ord(c) for c in text if ord(c) < VOCAB_SIZE]
        if not input_ids:
            continue
        x = torch.tensor([input_ids], device=DEVICE)
        generated_text = ""
        for _ in range(10):  # max prediction length
            if x.size(1) > SEQ_LEN:
                x = x[:, -SEQ_LEN:]  # keep only the last SEQ_LEN tokens
            with torch.no_grad():
                logits = model(x)
            # Greedy decoding: pick the highest-probability token.
            next_token_id = logits[0, -1].argmax().item()
            if next_token_id == 0:  # treat token 0 as end-of-sequence
                break
            generated_text += chr(next_token_id)
            # Autoregressive step: append the prediction and feed it back in.
            x = torch.cat([x, torch.tensor([[next_token_id]], device=DEVICE)], dim=1)
        print(f"Bot: {generated_text}")
if __name__ == "__main__":
    chat()
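
The script expects meow_model.bin to contain a state_dict, because load_bot() restores it with model.load_state_dict(torch.load(...)). If you are writing your own train.py, the matching save call would look like this (a sketch; the actual training script may save its checkpoint differently):

torch.save(model.state_dict(), "meow_model.bin")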
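
Greedy decoding (the argmax() call in chat()) always produces the same reply for the same input. If you want more varied output, one common alternative is temperature sampling. The helper below is only a sketch, not part of the original script; sample_next and the default temperature value are my own choices:

import torch

def sample_next(logits, temperature=0.8):
    # Rescale the logits and sample from the resulting softmax distribution.
    # Lower temperature behaves more like greedy decoding; higher is more random.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

Inside the generation loop you would then replace the argmax line with next_token_id = sample_next(logits[0, -1]).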