Seth0330 committed on
Commit
762eb29
·
verified ·
1 Parent(s): 3ade64c

Create utils_ctc.py

Browse files
Files changed (1) hide show
  1. pdrt/utils_ctc.py +43 -0
pdrt/utils_ctc.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn.utils.rnn import pad_sequence
3
+
4
# CTC collate
def custom_collate(data):
    """Collate a list of sample dicts into a single batch dict.

    Args:
        data: Sequence of per-sample dicts. Each dict is read at the keys
            ``'idx'``, ``'img'``, ``'label'`` and ``'raw_label'``.
            ``'img'`` must expose ``.tolist()`` and all images must share
            one shape; ``'label'`` entries are handed to ``pad_sequence``,
            so they are expected to be tensors (presumably 1-D class-index
            sequences -- TODO confirm against the dataset).

    Returns:
        dict with the keys:
            ``'idx'``: tensor built from the per-sample ``'idx'`` values.
            ``'img'``: tensor built from the stacked image values.
            ``'label'``: labels padded to the longest one, batch first,
                using ``pad_sequence``'s default padding value (0).
            ``'target_lengths'``: tensor of the unpadded label lengths.
            ``'raw_label'``: plain list of the ``'raw_label'`` values.
    """
    labels = [sample['label'] for sample in data]

    # Record the true lengths before padding hides them.
    target_lengths = torch.tensor([len(label) for label in labels])

    # The images are rebuilt from nested Python lists, so the resulting
    # dtype is whatever torch.tensor infers from plain Python numbers.
    # This round-trip is kept on purpose so callers see the same tensor as
    # before.
    inputs = torch.tensor([sample['img'].tolist() for sample in data])

    return {
        'idx': torch.tensor([sample['idx'] for sample in data]),
        'img': inputs,
        'label': pad_sequence(labels, batch_first=True),
        'target_lengths': target_lengths,
        'raw_label': [sample['raw_label'] for sample in data],
    }
25
+
26
def create_char_dicts(list_strings):
    """Build character <-> index lookup tables from a collection of strings.

    Indices are handed out in order of first appearance, starting at 1;
    index 0 is left unassigned because it is reserved for the blank token.

    Args:
        list_strings: Iterable of strings to scan for characters.

    Returns:
        Tuple ``(text_to_seq, seq_to_text)`` where ``text_to_seq`` maps each
        distinct character to its index and ``seq_to_text`` is the inverse
        mapping.
    """
    text_to_seq = {}
    seq_to_text = {}
    value = 1  # 0 is blank token

    for text in list_strings:
        for character in text:
            if character not in text_to_seq:
                text_to_seq[character] = value
                seq_to_text[value] = character
                value += 1
    return text_to_seq, seq_to_text
38
+
39
def sample_text_to_seq(list_strings, mydict):
    """Encode an iterable of symbols into a list of their mapped values.

    Args:
        list_strings: Iterable whose elements are looked up one by one
            (e.g. a string, iterated character by character).
        mydict: Mapping from symbol to encoded value.

    Returns:
        List of ``mydict`` values for the symbols that are present in
        ``mydict``, in input order. Symbols missing from ``mydict`` are
        skipped: emitting an empty-string placeholder for them would mix
        ``str`` entries into the encoded sequence and break any later
        numeric conversion.
    """
    return [mydict[character] for character in list_strings if character in mydict]
41
+
42
def sample_seq_to_text(list_strings, mydict):
    """Decode an iterable of indices into a string.

    Args:
        list_strings: Iterable of keys to look up in ``mydict``. Objects
            exposing ``.tolist()`` (e.g. a 1-D tensor or array) are converted
            to plain Python values first.
        mydict: Mapping from index to text fragment.

    Returns:
        The concatenation of the mapped fragments, in input order. Keys that
        are missing from ``mydict`` contribute nothing to the result.
    """
    # Iterating a tensor yields 0-d tensors, which hash by identity and so
    # never match integer dict keys; convert to Python scalars up front.
    if hasattr(list_strings, 'tolist'):
        list_strings = list_strings.tolist()
    return ''.join(mydict.get(character, "") for character in list_strings)