|
|
import torch |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
|
|
|
def custom_collate(data):
    """Merge a list of per-sample dicts into a single batch dict.

    Each element of ``data`` is read through the keys ``'label'``,
    ``'img'``, ``'idx'`` and ``'raw_label'``.

    Args:
        data: Sequence of sample dicts. ``'img'`` must expose ``.tolist()``
            and ``'label'`` must be accepted by ``pad_sequence`` (i.e. a
            tensor) and support ``len()``.

    Returns:
        A dict with:
            ``'idx'``: tensor built from the per-sample ``'idx'`` values.
            ``'img'``: tensor built from the nested-list form of every image.
            ``'label'``: labels padded to a common length, batch dimension
                first.
            ``'target_lengths'``: tensor holding the unpadded length of each
                label.
            ``'raw_label'``: plain list of the per-sample ``'raw_label'``
                values, left unconverted.
    """
    # NOTE(review): presumably used as a DataLoader ``collate_fn`` and
    # ``target_lengths`` looks intended for a CTC-style loss -- confirm.
    sample_ids = []
    images = []
    label_seqs = []
    label_sizes = []
    raw_labels = []

    # Gather every field in one pass over the batch.
    for sample in data:
        label_seq = sample['label']
        label_sizes.append(len(label_seq))
        label_seqs.append(label_seq)
        # Images go through nested Python lists before being re-tensorised.
        images.append(sample['img'].tolist())
        sample_ids.append(sample['idx'])
        raw_labels.append(sample['raw_label'])

    length_tensor = torch.tensor(label_sizes)
    # Variable-length labels are padded so they fit in one tensor.
    padded_labels = pad_sequence(label_seqs, batch_first=True)
    image_tensor = torch.tensor(images)
    id_tensor = torch.tensor(sample_ids)

    batch = {}
    batch['idx'] = id_tensor
    batch['img'] = image_tensor
    batch['label'] = padded_labels
    batch['target_lengths'] = length_tensor
    batch['raw_label'] = raw_labels
    return batch
|
|
|
|
|
def create_char_dicts(list_strings):
    """Build forward and reverse character/index lookup tables.

    Characters are numbered consecutively from 1 in order of first
    appearance while scanning the strings left to right; index 0 is never
    assigned.

    Args:
        list_strings: Iterable of strings whose characters form the
            vocabulary.

    Returns:
        A ``(text_to_seq, seq_to_text)`` tuple: the first dict maps each
        character to its integer index, the second maps each index back to
        its character.
    """
    char_to_index = {}
    index_to_char = {}

    for string in list_strings:
        for symbol in string:
            if symbol in char_to_index:
                continue  # already numbered on an earlier occurrence
            # Indices are 1-based, so the next free one is the size + 1.
            next_index = len(char_to_index) + 1
            char_to_index[symbol] = next_index
            index_to_char[next_index] = symbol

    return char_to_index, index_to_char
|
|
|
|
|
def sample_text_to_seq(list_strings, mydict):
    """Translate every element of ``list_strings`` through ``mydict``.

    Args:
        list_strings: Iterable of symbols to look up (iterating a string
            yields its characters).
        mydict: Mapping from symbol to encoded value.

    Returns:
        A list with one entry per input element: the mapped value, or the
        empty string ``""`` for elements missing from ``mydict``.
    """
    encoded = []
    for symbol in list_strings:
        # Unknown symbols are kept as an empty-string placeholder.
        encoded.append(mydict.get(symbol, ""))
    return encoded
|
|
|
|
|
def sample_seq_to_text(list_strings, mydict):
    """Decode a sequence of keys into one concatenated string.

    Args:
        list_strings: Iterable of keys to look up in ``mydict``.
        mydict: Mapping from key to the string fragment it stands for.

    Returns:
        The mapped fragments joined in input order; keys missing from
        ``mydict`` contribute nothing to the result.
    """
    fragments = []
    for code in list_strings:
        # Missing keys decode to the empty string, i.e. they are dropped.
        fragments.append(mydict.get(code, ""))
    return ''.join(fragments)