Spaces:
Build error
Build error
| import gradio as gr | |
| from transformers import BertTokenizerFast, BertForSequenceClassification,GPT2LMHeadModel,BartForConditionalGeneration | |
| import torch | |
| import math | |
class CHSentenceSmoothScorer():
    """Score the fluency of sentences with a BART language model.

    The score of a sentence is the geometric mean of the probabilities the
    model assigns to the sentence's own token ids, i.e.
    ``exp(mean(log p(token)))``.  Higher means the model finds the sentence
    more predictable.
    """

    def __init__(self, model_name: str = "fnlp/bart-base-chinese",
                 max_length: int = 50) -> None:
        """Load the tokenizer and the BART model.

        Args:
            model_name: Checkpoint name or path handed to ``from_pretrained``.
            max_length: Maximum number of token ids kept per sentence;
                longer sentences are truncated.
        """
        super().__init__()
        self.max_length = max_length
        self.tokenizer = BertTokenizerFast.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name)

    def __call__(self, sentences):
        """Return one fluency score per sentence.

        Args:
            sentences: Sequence of sentence strings to score.

        Returns:
            A list of floats, one per input sentence, in input order.
            An empty input yields an empty list.
        """
        if not sentences:
            return []

        encoded = self.tokenizer(
            sentences,
            return_tensors='pt',
            padding=True,
            max_length=self.max_length,
            truncation=True,
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        # Inference only: skip autograd bookkeeping, and pass the mask so
        # padded positions do not influence the real tokens of a batch.
        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits

        # log_softmax stays finite where log(softmax(x)) would hit log(0).
        log_probs = torch.log_softmax(logits, dim=-1)

        # Log-probability of each sentence's own token at its own position.
        # NOTE(review): relies on logits[i][j] scoring input_ids[i][j], as the
        # original indexing did -- confirm against the model's forward docs.
        token_log_probs = log_probs.gather(
            -1, input_ids.unsqueeze(-1)).squeeze(-1)

        # Average over real tokens only, using the mask rather than assuming
        # which side the padding is on.
        real_tokens = attention_mask.bool()
        token_log_probs = token_log_probs.masked_fill(~real_tokens, 0.0)
        mean_log_prob = token_log_probs.sum(dim=-1) / real_tokens.sum(dim=-1)

        return torch.exp(mean_log_prob).tolist()
# Local checkpoint directory shared by the sentence-check classifier and
# its tokenizer.
_SENT_CHECK_DIR = './ch-sent-check-model'

tokenizer = BertTokenizerFast.from_pretrained(_SENT_CHECK_DIR)
model = BertForSequenceClassification.from_pretrained(_SENT_CHECK_DIR)

# Fluency scorer built with its default settings.
smooth_scorer = CHSentenceSmoothScorer()
def judge(sentence):
    """Classify a sentence as correct/incorrect and score its fluency.

    Args:
        sentence: The sentence text to check.

    Returns:
        A ``(pred_text, smooth_score)`` tuple.  ``pred_text`` is
        ``'Incorrect'`` when the classifier's top class is index 0 and
        ``'Correct'`` otherwise, followed by the class-1 probability as a
        percentage.  ``smooth_score`` is the value returned by
        ``smooth_scorer`` for this sentence, as a percentage rounded to two
        decimals.
    """
    # truncation=True keeps over-long input within the tokenizer's
    # configured maximum length instead of passing it through unbounded.
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
        ).logits

    prob = torch.softmax(logits, dim=-1)
    pred = torch.argmax(prob, dim=-1).item()
    pred_text = 'Incorrect' if pred == 0 else 'Correct'

    # The reported score is always the probability of class 1, whichever
    # class was predicted.
    correct_prob = prob[0][1].item()
    pred_text = pred_text + f", score: {round(correct_prob*100,2)}"

    smooth_score = round(smooth_scorer([sentence])[0]*100, 2)
    return pred_text, smooth_score
# Input box; the label asks the user for a Chinese sentence to check.
_sentence_box = gr.Textbox(label="請輸入一段中文句子來檢測正確性", lines=1)

# Output boxes, in the order judge() returns its values:
# correctness verdict first, fluency score second.
_correctness_box = gr.Textbox(label="正確性檢查", lines=1)
_fluency_box = gr.Textbox(label="流暢性檢查", lines=1)

# Two example sentences that differ by a single character.
_example_sentences = [
    '請注意用字的鄭確性',
    '請注意用字的正確性',
]

iface = gr.Interface(
    fn=judge,
    inputs=_sentence_box,
    outputs=[_correctness_box, _fluency_box],
    examples=_example_sentences,
)

# Start the web UI.
iface.launch()