# Text embedding and text-to-image retrieval helpers (Luotuo-BERT based).
| import collections | |
| import os | |
| import pickle | |
| from argparse import Namespace | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torch import cosine_similarity | |
| from transformers import AutoTokenizer, AutoModel | |
def download_models():
    """Load the Luotuo-BERT model; the package downloads the weights automatically."""
    return AutoModel.from_pretrained(
        "silk-road/luotuo-bert",
        trust_remote_code=True,
        model_args=Namespace(
            do_mlm=None,
            pooler_type="cls",
            temp=0.05,
            mlp_only_train=False,
            init_embeddings_model=None,
        ),
    )
class Text:
    """Embeds character dialogue lines with Luotuo-BERT and retrieves the
    most similar cached text or its associated image for a query.

    Embeddings and text->image mappings are pickled to disk so they are
    computed once and reused across runs.
    """

    def __init__(self, text_dir, model, num_steps, text_image_pkl_path=None,
                 dict_text_pkl_path=None, pkl_path=None, dict_path=None,
                 image_path=None, maps_path=None):
        """
        Args:
            text_dir: directory of dialogue files, one utterance per line.
            model: pre-loaded Luotuo-BERT model reused by ``get_embedding``.
            num_steps: max text length in characters; longer texts are truncated.
            text_image_pkl_path: pickle path for the text -> image-name mapping.
            dict_text_pkl_path: pickle path for embeddings of the mapping keys.
            pkl_path: pickle path for the text -> embedding cache.
            dict_path: plain-text file alternating text line / image-name line.
            image_path: directory containing the ``.jpg`` images.
            maps_path: pickle path for the id/title/text records.
        """
        self.dict_text_pkl_path = dict_text_pkl_path
        self.text_image_pkl_path = text_image_pkl_path
        self.text_dir = text_dir
        self.model = model
        self.num_steps = num_steps
        self.pkl_path = pkl_path
        self.dict_path = dict_path
        self.image_path = image_path
        self.maps_path = maps_path

    def get_embedding(self, texts):
        """Return sentence embeddings for a string or a list of strings.

        Returns a single 1-D tensor for a single input, else a 2-D batch tensor.
        """
        tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
        # Fix: reuse the model injected in __init__ instead of re-downloading
        # it on every call (the original called download_models() each time).
        model = self.model if self.model is not None else download_models()
        texts = texts if isinstance(texts, list) else [texts]
        # Truncate to num_steps characters without mutating the caller's list.
        texts = [t[:self.num_steps] for t in texts]
        inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True,
                               return_dict=True, sent_emb=True).pooler_output
        return embeddings[0] if len(texts) == 1 else embeddings

    def read_text(self, save_embeddings=False, save_maps=False):
        """Extract dialogue texts from ``text_dir``; optionally cache embeddings/maps.

        Returns:
            (text_embeddings, data): dict text -> embedding (filled only when
            ``save_embeddings``) and a list of {"titles", "id", "text"} records
            (filled only when ``save_maps``).
        """
        text_embeddings = collections.defaultdict()
        seen_texts = []
        data = []
        texts = []
        entry_id = 0  # renamed from `id` to avoid shadowing the builtin
        for file_name in os.listdir(self.text_dir):
            with open(os.path.join(self.text_dir, file_name), 'r') as fr:
                for line in fr:
                    # Prefer the full-width colon used in Chinese scripts.
                    ch = ':' if ':' in line else ':'
                    if '旁白' in line:
                        # Narration line: keep everything after the colon.
                        text = line.strip().split(ch)[1].strip()
                    else:
                        # Dialogue line: keep the text inside the 「」 quotes.
                        text = ''.join(list(line.strip().split(ch)[1])[1:-1])
                    if text in seen_texts:
                        # Skip duplicates so embeddings and maps stay aligned.
                        continue
                    seen_texts.append(text)
                    if save_maps:
                        category = collections.defaultdict(str)
                        category["titles"] = file_name.split('.')[0]
                        category["id"] = str(entry_id)
                        category["text"] = text
                        entry_id += 1
                        data.append(dict(category))
                    texts.append(text)
        if save_embeddings and texts:
            # Fix: embed the whole batch once and reuse each row; the original
            # discarded the batch result and re-called get_embedding per text.
            embeddings = self.get_embedding(texts)
            if len(texts) == 1:
                text_embeddings[texts[0]] = embeddings
            else:
                for text, embed in zip(texts, embeddings):
                    text_embeddings[text] = embed
            self.store(self.pkl_path, text_embeddings)
        if save_maps:
            self.store(self.maps_path, data)
        return text_embeddings, data

    def load(self, load_pkl=False, load_maps=False, load_dict_text=False,
             load_text_image=False):
        """Unpickle the first requested cache; print and return None if no path matches."""
        if self.pkl_path and load_pkl:
            path = self.pkl_path
        elif self.maps_path and load_maps:
            path = self.maps_path
        elif self.dict_text_pkl_path and load_dict_text:
            path = self.dict_text_pkl_path
        elif self.text_image_pkl_path and load_text_image:
            path = self.text_image_pkl_path
        else:
            print("No pkl_path")
            return None
        with open(path, 'rb') as f:
            return pickle.load(f)

    def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
        """Cosine similarity between ``texts[0]`` (the query) and the candidates.

        With ``get_image``/``get_texts`` the candidate embeddings come from the
        corresponding pickle cache; otherwise ``texts[1:]`` are embedded here.
        Returns a 1-D tensor with one similarity per candidate.
        """
        if get_image:
            pkl = self.load(load_dict_text=True)
        elif get_texts:
            pkl = self.load(load_pkl=True)
        else:
            pkl = {}
            # 1536 is Luotuo-BERT's sentence-embedding width.
            embeddings = self.get_embedding(texts[1:]).reshape(-1, 1536)
            # Fix off-by-one: pair each candidate with its own embedding.
            # The original zipped from texts[0], shifting every pair by one
            # and dropping the last candidate's embedding.
            for text, embed in zip(texts[1:], embeddings):
                pkl[text] = embed
        query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
        texts_embeddings = np.array(
            [value.numpy().reshape(-1, 1536) for value in pkl.values()]
        ).squeeze(1)
        return cosine_similarity(query_embedding, torch.from_numpy(texts_embeddings))

    def store(self, path, data):
        """Pickle ``data`` to ``path``, overwriting any existing file."""
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def text_to_image(self, text, save_dict_text=False):
        """Return the PIL image whose mapped text is most similar to ``text``.

        When ``save_dict_text`` is set, first (re)build the text -> image-name
        mapping and its key embeddings from ``dict_path``.
        """
        if save_dict_text:
            text_image = {}
            with open(self.dict_path, 'r') as f:
                lines = f.readlines()
            # The dict file alternates: text line, then image-name line.
            for sub_text, image_name in zip(lines[::2], lines[1::2]):
                text_image[sub_text.strip()] = image_name.strip()
            self.store(self.text_image_pkl_path, text_image)
            keys_embeddings = {}
            embeddings = self.get_embedding(list(text_image.keys()))
            for key, embed in zip(text_image.keys(), embeddings):
                keys_embeddings[key] = embed
            self.store(self.dict_text_pkl_path, keys_embeddings)
        if self.dict_path and self.image_path:
            # Load the cached text -> image-name mapping.
            text_image = self.load(load_text_image=True)
            keys = list(text_image.keys())
            keys.insert(0, text)
            query_similarity = self.get_cosine_similarity(keys, get_image=True)
            # .item() for a plain int index, consistent with text_to_text.
            key_index = query_similarity.argmax(dim=0).item()
            best_key = list(text_image.keys())[key_index]
            image = text_image[best_key] + '.jpg'
            if image in os.listdir(self.image_path):
                return Image.open(os.path.join(self.image_path, image))
            print("Image doesn't exist")
            return None
        print("No path")
        return None

    def text_to_text(self, text):
        """Return the cached text most similar to ``text``."""
        pkl = self.load(load_pkl=True)
        candidates = list(pkl.keys())
        candidates.insert(0, text)
        texts_similarity = self.get_cosine_similarity(candidates, get_texts=True)
        key_index = texts_similarity.argmax(dim=0).item()
        return list(pkl.keys())[key_index]
| # if __name__ == '__main__': | |
| # pkl_path = './pkl/texts.pkl' | |
| # maps_path = './pkl/maps.pkl' | |
| # text_image_pkl_path='./pkl/text_image.pkl' | |
| # dict_path = "../characters/haruhi/text_image_dict.txt" | |
| # dict_text_pkl_path = './pkl/dict_text.pkl' | |
| # image_path = "../characters/haruhi/images" | |
| # text_dir = "../characters/haruhi/texts" | |
| # model = download_models() | |
| # text = Text(text_dir, text_image_pkl_path=text_image_pkl_path, maps_path=maps_path, | |
| # dict_text_pkl_path=dict_text_pkl_path, model=model, num_steps=50, pkl_path=pkl_path, | |
| # dict_path=dict_path, image_path=image_path) | |
| # text.read_text(save_maps=True, save_embeddings=True) | |
| # data = text.load(load_pkl=True) | |
| # sub_text = "你好!" | |
| # image = text.text_to_image(sub_text) | |
| # print(image) | |
| # sub_texts = ["hello", "你好"] | |
| # print(text.get_cosine_similarity(sub_texts)) | |
| # value = text.text_to_text(sub_text) | |
| # print(value) | |