# NOTE: the next three lines are page-scrape residue, not part of the script.
# Spaces:
# Sleeping
# Sleeping
| import torch | |
| import json | |
| from language.responder import Responder | |
# ===== TOKENIZER =====
class SimpleTokenizer:
    """Whitespace tokenizer backed by a JSON token-to-id vocabulary.

    The vocabulary file is a JSON object whose keys are tokens and whose
    values are the ids returned by ``encode``. Tokens missing from the
    vocabulary encode to ``UNK_ID``; ids missing from the vocabulary decode
    to ``UNK_TOKEN``.
    """

    # Id substituted by ``encode`` for tokens that are not in the vocabulary.
    UNK_ID = 1
    # Text substituted by ``decode`` for ids that are not in the vocabulary.
    UNK_TOKEN = "<unk>"

    def __init__(self, vocab_path):
        """Load the vocabulary and build the reverse (id -> token) map.

        Args:
            vocab_path: Path to the JSON vocabulary file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON text is UTF-8; relying on the platform default encoding would
        # corrupt (or fail on) non-ASCII tokens on non-UTF-8 locales.
        with open(vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def encode(self, text):
        """Split *text* on whitespace and return the list of token ids."""
        return [self.vocab.get(t, self.UNK_ID) for t in text.split()]

    def decode(self, ids):
        """Return the tokens for *ids* joined by single spaces."""
        return " ".join(self.inv_vocab.get(i, self.UNK_TOKEN) for i in ids)
# ===== LOAD MODEL =====
def load_responder(device, artifacts_dir="artifacts", embed_dim=192, hidden_dim=512):
    """Build the tokenizer and the trained ``Responder`` model for inference.

    Reads three files from *artifacts_dir*: ``vocab.json`` (tokenizer
    vocabulary), ``responder.pt`` (model state dict) and
    ``replay_buffer.pt`` (used only to read the model's ``target_dim``).

    Args:
        device: Device the model is moved to and its weights are mapped to.
        artifacts_dir: Directory holding the three artifact files.
        embed_dim: ``embed_dim`` passed to ``Responder``.
        hidden_dim: ``hidden_dim`` passed to ``Responder``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.

    Raises:
        ValueError: If the replay buffer has no first entry to read
            ``target_dim`` from.
    """
    vocab_path = f"{artifacts_dir}/vocab.json"
    model_path = f"{artifacts_dir}/responder.pt"
    brain_path = f"{artifacts_dir}/replay_buffer.pt"

    tokenizer = SimpleTokenizer(vocab_path)

    # NOTE(security): torch.load unpickles the file; only load artifacts
    # that come from a trusted source.
    # Only a shape is read from the buffer, so keep it on the CPU instead of
    # copying the whole training buffer onto the inference device.
    brain_data = torch.load(brain_path, map_location="cpu")
    try:
        first_entry = brain_data[0]
    except (IndexError, KeyError) as err:
        raise ValueError(
            f"replay buffer {brain_path!r} is empty; cannot infer target_dim"
        ) from err
    # target_dim is the last dimension of the second element of the first
    # entry (presumably the training target -- confirm against the trainer).
    target_dim = first_entry[1].shape[-1]
    # Release the buffer before the model weights are loaded.
    del brain_data, first_entry
    print(f"[INFO] TARGET_DIM (from training): {target_dim}")

    model = Responder(
        vocab_size=len(tokenizer.vocab),
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        target_dim=target_dim,
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    return model, tokenizer
# ===== GENERATE =====
def generate_response(model, tokenizer, text, device, *, max_len=64,
                      intent="general", emotion="neutral", brain_dim=512):
    """Generate one reply from the model for the user's *text*.

    A fresh random vector of shape ``(1, brain_dim)`` is used as the
    ``brain_input``; *text* is forwarded as ``memory_text``.

    Args:
        model: Object exposing ``generate_from_brain(...)``.
        tokenizer: Tokenizer forwarded unchanged to the model.
        text: User input, forwarded as ``memory_text``.
        device: Device the random brain vector is moved to.
        max_len: ``max_len`` forwarded to the model.
        intent: ``intent`` forwarded to the model.
        emotion: ``emotion`` forwarded to the model.
        brain_dim: Width of the random brain vector. The original author
            notes it must match the model's hidden_dim (512).

    Returns:
        Whatever ``model.generate_from_brain`` returns.
    """
    with torch.no_grad():
        # Sampled on the CPU and then moved, so the values drawn for a given
        # seed do not depend on the target device.
        brain_input = torch.randn(1, brain_dim).to(device)
        return model.generate_from_brain(
            brain_input=brain_input,
            tokenizer=tokenizer,
            max_len=max_len,
            intent=intent,
            emotion=emotion,
            memory_text=text,
        )
# ===== MAIN LOOP =====
def main():
    """Run an interactive console loop that prints a model reply per line.

    Picks CUDA when available (otherwise CPU), loads the model and
    tokenizer, then reads lines from stdin until the user types ``exit`` or
    ``quit``, or until stdin is closed or interrupted.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    model, tokenizer = load_responder(device)

    print("\n=== RESPONDER TEST MODE ===")
    print("Type 'exit' to quit\n")

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin (piped input, Ctrl-D) or Ctrl-C ends the session
            # cleanly instead of dumping a traceback.
            print()
            break

        # Surrounding whitespace must not defeat the exit command.
        if user_input.strip().lower() in ("exit", "quit"):
            break

        response = generate_response(model, tokenizer, user_input, device)
        print(f"AI: {response}\n")


if __name__ == "__main__":
    main()