| from transformers import PreTrainedModel, AutoModel, AutoConfig |
| from .configuration_vatrpp import VATrPPConfig |
| import os |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| from .data.dataset import FolderDataset |
| from .models.model import VATr |
| from .models.util.vision import detect_text_bounds |
| from torchvision.transforms.functional import to_pil_image |
| from huggingface_hub import hf_hub_download |
|
|
|
|
def get_long_tail_chars(path: str = "files/longtail.txt") -> list:
    """Return the long-tail characters listed one per line in *path*.

    Trailing whitespace (including the newline) is stripped from every line
    and lines that end up empty are dropped.

    Args:
        path: Text file to read. Defaults to ``files/longtail.txt``, resolved
            relative to the current working directory.

    Returns:
        The non-empty, right-stripped lines of the file, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.rstrip() for line in f)
        # Filtering (instead of ``list.remove('')``) works whether the file
        # has zero, one, or several blank lines.
        return [char for char in stripped if char]
|
|
|
|
class VATrPP(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing the VATr generator.

    The wrapped ``VATr`` instance is stored in ``self.model`` and is switched
    to eval mode on construction. All generation entry points run under
    ``torch.no_grad()``.
    """

    config_class = VATrPPConfig

    def __init__(self, config: VATrPPConfig) -> None:
        """Resolve the auxiliary text files and build the wrapped model.

        Args:
            config: Model configuration. ``english_words_path`` and
                ``mytext_path`` are overwritten in place with local file
                paths before the config is handed to ``VATr``.
        """
        super().__init__(config)

        repo_id = "blowing-up-groundhogs/vatrpp"

        # This constructor rewrites ``english_words_path`` to a local path, so
        # a config that has already been through it (or that points at an
        # existing local file) must not be sent to the Hub again as a repo
        # filename.
        if not os.path.isfile(config.english_words_path):
            config.english_words_path = hf_hub_download(
                repo_id=repo_id, filename=config.english_words_path
            )
        config.mytext_path = hf_hub_download(repo_id=repo_id, filename='mytext.txt')

        self.model = VATr(config)
        self.model.eval()

    def set_style_folder(self, style_folder, num_examples=15):
        """Create ``self.style_dataset`` from a folder of style samples.

        If ``word_lengths.txt`` exists in *style_folder* it is parsed as
        ``word,length`` lines into a ``{word: int(length)}`` dict that is
        forwarded to ``FolderDataset``; otherwise ``None`` is forwarded.

        Args:
            style_folder: Directory handed to ``FolderDataset``.
            num_examples: Forwarded to ``FolderDataset``.

        Raises:
            ValueError: If a non-empty line has no comma or a non-integer
                length.
        """
        lengths_file = os.path.join(style_folder, "word_lengths.txt")

        word_lengths = None
        if os.path.exists(lengths_file):
            word_lengths = {}
            with open(lengths_file, 'r') as f:
                for line in f:
                    entry = line.rstrip()
                    if not entry:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    # Split on the LAST comma so that words which themselves
                    # contain a comma are still parsed correctly.
                    word, length = entry.rsplit(",", 1)
                    word_lengths[word] = int(length)

        self.style_dataset = FolderDataset(
            style_folder, num_examples=num_examples, word_lengths=word_lengths
        )

    @torch.no_grad()
    def generate(self, gen_text, style_imgs, align_words: bool = False, at_once: bool = False):
        """Render *gen_text* in the style of *style_imgs* as a PIL image.

        Args:
            gen_text: Text to render.
            style_imgs: Style tensor; a leading dimension is added with
                ``unsqueeze(0)`` and the result is moved to the model device.
            align_words: Forwarded to ``create_fake_sentence``.
            at_once: Forwarded to ``create_fake_sentence``.

        Returns:
            The result of ``to_pil_image`` applied to the generated array.
        """
        style_images = style_imgs.unsqueeze(0).to(self.model.args.device)

        fake = self.create_fake_sentence(style_images, gen_text, align_words, at_once)
        return to_pil_image(fake)

    @torch.no_grad()
    def create_fake_sentence(self, style_images, text, align_words=False, at_once=False):
        """Generate a ``uint8`` image array for *text*.

        Characters that are not in ``self.model.args.alphabet`` are removed.
        Unless *at_once* is set, the text is split on whitespace, every word
        is generated separately and the word images are joined horizontally.

        Args:
            style_images: Style batch passed to ``self.model._generate_fakes``.
            text: Text to render.
            align_words: When generating word by word, join the words with
                ``stitch_words`` instead of plain concatenation.
            at_once: Generate the whole (filtered) text as a single item and
                return the first generated image.

        Returns:
            The generated image multiplied by 255 and cast to ``np.uint8``.
        """
        gap_width = 16

        alphabet = self.model.args.alphabet
        text = "".join([c for c in text if c in alphabet])
        text = [text] if at_once else text.split()

        text_encode, len_text, _ = self.model.netconverter.encode(text)
        text_encode = text_encode.to(self.model.args.device).unsqueeze(0)

        fake = self.model._generate_fakes(style_images, text_encode, len_text)
        if at_once:
            fake = fake[0]
        elif align_words:
            fake = self.stitch_words(fake, show_lines=False)
        else:
            # Spacer of ones between consecutive words; the fixed height of 32
            # must match the generated word images for the concatenation to
            # succeed.
            gap = np.ones((32, gap_width))
            pieces = []
            for img in fake:
                pieces.append(img)
                pieces.append(gap)
            # Every word was followed by a spacer; cut the trailing one.
            fake = np.concatenate(pieces, axis=1)[:, :-gap_width]

        return (fake * 255).astype(np.uint8)

    @torch.no_grad()
    def generate_batch(self, style_imgs, text):
        """
        Given a batch of style images and text, generate images using the model.

        Both inputs are moved to ``self.model.args.device`` before being passed
        to ``self.model.netG``; the first element of its return value is
        returned.
        """
        device = self.model.args.device
        text_encode, _, _ = self.model.netconverter.encode(text)
        fakes, _ = self.model.netG(style_imgs.to(device), text_encode.to(device))
        return fakes

    @staticmethod
    def stitch_words(words: list, show_lines: bool = False, scale_words: bool = False):
        """Join word images side by side with their bottom lines aligned.

        For every word the pair returned by ``detect_text_bounds`` is taken as
        (bottom line, top line). Each word is shifted down so that its bottom
        line lands on the row of the largest bottom line, and the words are
        pasted left to right onto a canvas of ones, each followed by a
        16-column gap.

        Args:
            words: Word images (2-D arrays). The list itself is not modified.
            show_lines: Draw the detected bottom and top lines with
                ``cv2.line`` before stitching (debugging aid). NOTE: the lines
                are drawn onto the given image arrays.
            scale_words: Resize every word so that its bottom-to-top distance
                equals the mean distance over all words.

        Returns:
            A float array holding the stitched words.

        Raises:
            ValueError: If *words* is empty.
        """
        if len(words) == 0:
            raise ValueError("stitch_words requires at least one word image")

        gap_width = 16

        # Shallow copy: entries may be rebound below and the caller's list
        # should not change as a side effect.
        words = list(words)

        bottom_lines = []
        top_lines = []
        for i, word in enumerate(words):
            b, t = detect_text_bounds(word)
            bottom_lines.append(b)
            top_lines.append(t)
            if show_lines:
                word = cv2.line(word, (0, b), (word.shape[1], b), (0, 0, 1.0))
                words[i] = cv2.line(word, (0, t), (word.shape[1], t), (1.0, 0, 0))

        bottom_lines = np.array(bottom_lines, dtype=float)

        if scale_words:
            # NOTE(review): a word whose bottom and top line coincide yields a
            # zero gap and therefore a non-finite scale - confirm whether
            # detect_text_bounds can return that.
            gaps = bottom_lines - np.array(top_lines, dtype=float)
            scales = np.mean(gaps) / gaps

            bottom_lines *= scales
            words = [cv2.resize(word, None, fx=scale, fy=scale) for word, scale in zip(words, scales)]

        # Vertical shift per word that brings its bottom line to the lowest one.
        highest = np.max(bottom_lines)
        offsets = highest - bottom_lines
        height = np.max(offsets + [word.shape[0] for word in words])

        # NOTE: one gap per word is reserved, so the canvas ends with a
        # trailing gap (unlike the unaligned path in create_fake_sentence).
        # Kept as is to preserve the output shape.
        total_width = gap_width * len(words) + sum(word.shape[1] for word in words)
        result = np.ones((int(height), total_width))

        x_pos = 0
        for offset, word in zip(offsets, words):
            top = int(offset)
            result[top:top + word.shape[0], x_pos:x_pos + word.shape[1]] = word
            x_pos += word.shape[1] + gap_width

        return result
|
|
|
|
# Register the custom classes with the transformers Auto* factories at import
# time: the "vatrpp" model type is associated with VATrPPConfig, and
# VATrPPConfig is associated with the VATrPP model class.
| AutoConfig.register("vatrpp", VATrPPConfig) |
| AutoModel.register(VATrPPConfig, VATrPP) |