Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python | |
| # Gradio app for Dhivehi typo correction | |
| import difflib | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
| import gradio as gr | |
| import spaces | |
# Model choices offered by the app: display name -> model id that load_model()
# passes to `from_pretrained`.
MODEL_OPTIONS_TYPO = {
    "A3 Model": "alakxender/t5-dhivehi-typo-corrector-asr",
    "XS Model": "alakxender/dhivehi-quick-spell-check-t5",
}
# Load a tokenizer/model pair by its display name in MODEL_OPTIONS_TYPO
def load_model(model_choice):
    """Load the tokenizer and seq2seq model registered under *model_choice*.

    Progress and failures are reported with ``print``. Any exception raised
    while loading (including an unknown *model_choice*) is caught.

    Args:
        model_choice: A key of ``MODEL_OPTIONS_TYPO``.

    Returns:
        ``(model, tokenizer, device)`` where *device* is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise, or
        ``(None, None, None)`` if loading failed.
    """
    print("Loading model and tokenizer...")
    try:
        checkpoint = MODEL_OPTIONS_TYPO[model_choice]

        tok = AutoTokenizer.from_pretrained(checkpoint)
        # Fall back to EOS as the padding token when the tokenizer defines none.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        net = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        target = "cuda" if torch.cuda.is_available() else "cpu"
        net = net.to(target)
        print(f"Model loaded successfully on {target}")
    except Exception as err:
        print(f"Error loading model: {err}")
        return None, None, None
    return net, tok, target
# Function to correct typos (returns the single best beam)
def correct_typo(text, model, tokenizer, device, *, max_length=128, num_beams=4):
    """Run the seq2seq typo-correction model on *text*.

    The input is prefixed with ``"fix: "``, tokenized (truncated to
    *max_length* tokens), passed through beam search, and the top sequence is
    decoded without special tokens.

    Args:
        text: Raw user input.
        model: Loaded seq2seq model exposing ``generate``.
        tokenizer: Tokenizer matching *model*.
        device: Device the tokenized inputs are moved to (e.g. ``"cuda"``).
        max_length: Token limit for both input truncation and generation.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The corrected text, or a string starting with ``"Error: "`` if
        tokenization, generation or decoding raised.

    Raises:
        gr.Error: If *text* is None, empty or whitespace-only, or longer than
            1024 characters after stripping.
    """
    # Gradio can pass None for an untouched textbox; treat it like empty input.
    if text is None or not text.strip():
        raise gr.Error("Please enter some text🔥!", duration=5)
    if len(text.strip()) > 1024:
        raise gr.Error("Shorter the better🔥!", duration=5)
    try:
        # NOTE: the 1024-character guard above does not bound the token count;
        # anything beyond max_length tokens is silently truncated here.
        inputs = tokenizer(
            "fix: " + text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Report the failure in the output instead of failing the request.
        return f"Error: {e}"
# Load the default ("A3 Model") checkpoint once at import time.
# load_model() returns (None, None, None) on failure, so checking `model` suffices.
model, tokenizer, device = load_model("A3 Model")
if model is None:
    print("Failed to load model. Please check your model and tokenizer paths.")
# Function to highlight differences between original and corrected text
def highlight_differences(original, corrected):
    """Render a word-level diff of *original* vs *corrected* as HTML.

    Words (whitespace-split) are aligned with ``difflib.SequenceMatcher``:

    * unchanged words -> plain ``<span>``
    * replaced words  -> ``old→new`` (yellow ``#fff3cd`` / green ``#d4edda``)
    * deleted words   -> red ``#f8d7da`` background
    * inserted words  -> green ``#d4edda`` background

    The inline ``background-color`` values are matched by the app's CSS
    attribute selectors, so they must stay exactly as written. Word text is
    HTML-escaped because it comes from user input and model output.

    Args:
        original: Text before correction.
        corrected: Text after correction.

    Returns:
        A ``<div class="dhivehi-diff">`` containing the space-joined spans.
    """
    from html import escape  # local import keeps the file's import block unchanged

    old_words = original.split()
    new_words = corrected.split()

    def span(word, color=None):
        # One escaped word, optionally with an inline background colour.
        if color is None:
            return f"<span>{escape(word)}</span>"
        return f'<span style="background-color: {color}">{escape(word)}</span>'

    parts = []
    # Use opcodes rather than Differ: Differ inserts "? " hint lines between
    # similar words, which broke the old "- " / "+ " pairing for typo fixes.
    matcher = difflib.SequenceMatcher(None, old_words, new_words, autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            parts.extend(span(w) for w in old_words[i1:i2])
        elif tag == "delete":
            parts.extend(span(w, "#f8d7da") for w in old_words[i1:i2])
        elif tag == "insert":
            parts.extend(span(w, "#d4edda") for w in new_words[j1:j2])
        else:  # "replace": pair words positionally; leftovers are deletions/insertions
            removed = old_words[i1:i2]
            added = new_words[j1:j2]
            for old, new in zip(removed, added):
                parts.append(f"{span(old, '#fff3cd')}→{span(new, '#d4edda')}")
            parts.extend(span(w, "#f8d7da") for w in removed[len(added):])
            parts.extend(span(w, "#d4edda") for w in added[len(removed):])
    return f'<div class="dhivehi-diff">{" ".join(parts)}</div>'
# Function to process the input for Gradio
# Default checkpoint; must match the load_model(...) call made at import time.
_DEFAULT_MODEL_CHOICE = "A3 Model"
# Loaded (model, tokenizer, device) triples keyed by MODEL_OPTIONS_TYPO name,
# so a checkpoint is not reloaded on every request.
_model_cache = {}


def process_input(text, model_choice):
    """Correct *text* with the model selected in the UI and build an HTML diff.

    Args:
        text: The user's input text.
        model_choice: Display name of the model (a key of
            ``MODEL_OPTIONS_TYPO``); ``None`` falls back to the default model.

    Returns:
        ``(corrected, highlighted)``: the output of ``correct_typo`` and the
        HTML produced by ``highlight_differences`` for it.

    Raises:
        gr.Error: If the selected model cannot be loaded, or if
            ``correct_typo`` rejects the input.
    """
    # NOTE(review): `spaces` is imported but no `@spaces.GPU` decorator is
    # visible in this chunk; on ZeroGPU, GPU inference must run inside a
    # decorated function -- confirm where (or whether) it is applied.
    if model_choice is None:
        model_choice = _DEFAULT_MODEL_CHOICE
    # Reuse the checkpoint loaded at import time instead of loading it twice.
    if model is not None:
        _model_cache.setdefault(_DEFAULT_MODEL_CHOICE, (model, tokenizer, device))
    cached = _model_cache.get(model_choice)
    if cached is None:
        # The original discarded load_model()'s result, so a failed startup
        # load was never recovered and the model choice was ignored.
        cached = load_model(model_choice)
        if cached[0] is None:
            raise gr.Error(f"Could not load model: {model_choice}", duration=5)
        _model_cache[model_choice] = cached
    chosen_model, chosen_tokenizer, chosen_device = cached
    corrected = correct_typo(text, chosen_model, chosen_tokenizer, chosen_device)
    highlighted = highlight_differences(text, corrected)
    return corrected, highlighted
# Define CSS for Dhivehi font styling (right-to-left text, Thaana fonts) and for
# the diff spans emitted by highlight_differences(). The span rules match on the
# inline `background-color` values those spans carry. The arrow between an old
# and a new word is a bare text node, so it inherits the .dhivehi-diff colour;
# no separate rule is needed (`:contains()` is not a valid CSS selector anyway).
css = """
.textbox1 textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
}
.dhivehi-text {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
    text-align: right !important;
    padding: 10px !important;
    background: transparent !important; /* Make background transparent */
    border-radius: 4px !important;
    color: #ffffff !important; /* White text for dark background */
}
/* Style for the highlighted differences */
.dhivehi-diff {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
    text-align: right !important;
    padding: 15px !important;
    background: transparent !important; /* Make background transparent */
    border: 1px solid rgba(255, 255, 255, 0.1) !important; /* Subtle border */
    border-radius: 4px !important;
    margin-top: 10px !important;
    color: #ffffff !important; /* White text for dark background */
}
/* Ensure the highlighted spans have good contrast */
.dhivehi-diff span {
    padding: 2px 5px !important;
    border-radius: 3px !important;
    margin: 0 2px !important;
}
/* Original text (yellow background) */
.dhivehi-diff span[style*="background-color: #fff3cd"] {
    background-color: rgba(255, 243, 205, 0.2) !important;
    color: #ffd700 !important; /* Golden yellow for visibility */
    border: 1px solid rgba(255, 243, 205, 0.3) !important;
}
/* Corrected text (green background) */
.dhivehi-diff span[style*="background-color: #d4edda"] {
    background-color: rgba(212, 237, 218, 0.2) !important;
    color: #98ff98 !important; /* Light green for visibility */
    border: 1px solid rgba(212, 237, 218, 0.3) !important;
}
/* Removed text (red background) */
.dhivehi-diff span[style*="background-color: #f8d7da"] {
    background-color: rgba(248, 215, 218, 0.2) !important;
    color: #ff6b6b !important; /* Light red for visibility */
    border: 1px solid rgba(248, 215, 218, 0.3) !important;
}
"""