import os
import gc
import gradio as gr
import torch
import random
import logging
import openai
from openai import OpenAI
from vecalign.plan2align import translate_text, external_find_best_translation
from trl import AutoModelForCausalLMWithValueHead
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import subprocess
import pkg_resources
import sys
login(token=os.environ.get("LA_NAME"))
os.environ["LASER"] = "laser"

import transformers
print(transformers.__version__)
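# Helper that pins exact dependency versions at startup: if a package is
# missing or its installed version differs from the required one, it is
# (re)installed via pip.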
def check_and_install(package, required_version):
    try:
        dist = pkg_resources.get_distribution(package)
        installed_version = dist.version
        if installed_version != required_version:
            print(f"[{package}] installed version {installed_version} does not match required {required_version}; reinstalling...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package}=={required_version}", "--force-reinstall"])
        else:
            print(f"[{package}] already at required version {required_version}")
    except pkg_resources.DistributionNotFound:
        print(f"[{package}] not found; installing {required_version}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package}=={required_version}"])
packages = {
    "pip": "24.0",
    "fairseq": "0.12.2",
    "torch": "2.6.0",
    "transformers": "4.51.3"
}

for package, version in packages.items():
    check_and_install(package, version)
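# Note: transformers was already imported above, so a forced reinstall in this
# process only takes effect after the Space restarts; sys.modules keeps the
# originally loaded version for the current run.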
from transformers import AutoTokenizer, AutoModelForCausalLM
models = ["en_core_web_sm", "ru_core_news_sm", "de_core_news_sm",
          "ja_core_news_sm", "ko_core_news_sm", "es_core_news_sm",
          "zh_core_web_sm"]
for model in models:
    try:
        spacy.load(model)
    except OSError:
        from spacy.cli import download
        download(model)
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
# ---------- Model setup ----------
# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load models once
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
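# Single-turn generation helper: formats (system, user) messages with the
# model's chat template, samples with temperature 0.7 / top-p 0.9, and decodes
# only the tokens produced after the prompt.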
def generate_translation(system_prompt, prompt):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    outputs = model.generate(
        inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Decode only the newly generated tokens, skipping the prompt.
    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return translation
def check_token_length(text, max_tokens=1024):
    # Count tokens with the model tokenizer rather than raw characters.
    return len(tokenizer.encode(text)) <= max_tokens
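# Per-browser-session id, used to isolate plan2align state between concurrent
# users of the Space.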
import uuid

def get_user_session(state=None):
    if not isinstance(state, dict):
        state = {}
    if not state.get("session_id"):
        state["session_id"] = uuid.uuid4().hex
    return state["session_id"]
# ---------- Translation functions ----------
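# Candidate generation for MPC: each round produces three stylistic variants
# (literal, formal, creative); evaluate_candidates below picks the best one.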
def mpc_initial_translate(source_sentence, src_language, tgt_language):
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Translate this from {src_language} to {tgt_language} and only output the result."
        prompt += f"\n### {src_language}:\n {source_sentence}"
        prompt += f"\n### {tgt_language}:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    print("mpc_initial_translate")
    print(translations)
    return translations
def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
    system_prompts = [
        "You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
        "You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
        "You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging."
    ]
    translations = []
    for prompt_style in system_prompts:
        prompt = (f"Source ({src_language}): {source_sentence}\n"
                  f"Current Translation ({tgt_language}): {current_translation}\n"
                  f"Please provide an improved translation into {tgt_language} and only output the result:")
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    print("mpc_improved_translate")
    print(translations)
    return translations
def basic_translate(source_sentence, src_language, tgt_language):
    system_prompts = ["You are a helpful translator and only output the result."]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Translate this from {src_language} to {tgt_language}."
        prompt += f"\n### {src_language}:\n {source_sentence}"
        prompt += f"\n### {tgt_language}:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    return translations
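# Plan2Align wrapper: delegates to vecalign.plan2align.translate_text, then
# re-scores the result with the same evaluator as the other methods so the
# reported scores are comparable.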
def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language,
                              task_language, max_iterations_value, threshold_value,
                              good_ref_contexts_num_value, reward_model_type):
    result = translate_text(
        text=text,
        model=model,
        tokenizer=tokenizer,
        device=device,
        src_language=src_language,
        task_language=task_language,
        max_iterations_value=max_iterations_value,
        threshold_value=threshold_value,
        good_ref_contexts_num_value=good_ref_contexts_num_value,
        reward_model_type=reward_model_type,
        session_id=session_id
    )
    _, score = evaluate_candidates(text, [result], task_language, session_id)
    return result, score
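# Scores candidate translations via plan2align's reward model and returns the
# (candidate, score) pair it ranks best for the target language.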
def evaluate_candidates(source, candidates, language, session_id):
    evals = [(source, candidates)]
    best_translations = external_find_best_translation(evals, language, session_id)
    best_candidate, best_score = best_translations[0]
    return best_candidate, best_score
def original_translation(text, src_language, target_language, session_id):
    cand_list = basic_translate(text, src_language, target_language)
    if not cand_list:
        return "", 0
    best, score = evaluate_candidates(text, cand_list, target_language, session_id)
    return best, score
def best_of_n_translation(text, src_language, target_language, n, session_id):
    if not check_token_length(text, 2048):
        return "Warning: Input text exceeds 2048 tokens.", None
    candidates = []
    for i in range(n):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])
    best, score = evaluate_candidates(text, candidates, target_language, session_id)
    print("best_of_n evaluate_candidates results:")
    print(best, score)
    return best, score
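# MPC loop: round 0 generates initial candidates; every later round asks the
# model to improve the current best translation, which is then re-scored.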
def mpc_translation(text, src_language, target_language, iterations, session_id):
    if not check_token_length(text, 2048):
        return "Warning: Input text exceeds 2048 tokens.", None
    current_trans = ""
    best_score = None
    for i in range(iterations):
        if i == 0:
            cand_list = mpc_initial_translate(text, src_language, target_language)
        else:
            cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
        best, score = evaluate_candidates(text, cand_list, target_language, session_id)
        print("mpc evaluate_candidates results:")
        print(best, score)
        current_trans = best
        best_score = score
    return current_trans, best_score
# ---------- Gradio function ----------
def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, state=None):
    """
    Translate the input text into the target language, producing up to four results:
    1. Original translation
    2. Plan2Align translation
    3. Best-of-N translation
    4. MPC translation
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    session_id = get_user_session(state)
    orig_output = ""
    plan2align_output = ""
    best_of_n_output = ""
    mpc_output = ""
    if "Original" in translation_methods:
        orig, best_score = original_translation(text, src_language, target_language, session_id)
        orig_output = f"{orig}\n\nScore: {best_score:.2f}"
    if "Plan2Align" in translation_methods:
        plan2align_trans, best_score = plan2align_translate_text(
            text, session_id, model, tokenizer, device, src_language, target_language,
            max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
        )
        plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
    if "Best-of-N" in translation_methods:
        best_candidate, best_score = best_of_n_translation(text, src_language, target_language,
                                                           max_iterations_value, session_id)
        # A None score signals the token-limit warning; show the message as-is.
        best_of_n_output = best_candidate if best_score is None else f"{best_candidate}\n\nScore: {best_score:.2f}"
    if "MPC" in translation_methods:
        mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
                                                   max_iterations_value, session_id)
        mpc_output = mpc_candidate if mpc_score is None else f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
    return orig_output, plan2align_output, best_of_n_output, mpc_output
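# Minimal programmatic usage sketch (commented out; assumes the models above
# loaded successfully, and the source sentence is only an illustration):
# orig, p2a, bon, mpc = process_text(
#     "台灣夜市文化豐富多彩。", "Chinese", "English",
#     max_iterations_value=2, threshold_value=0.7,
#     good_ref_contexts_num_value=1,
#     translation_methods=["Original", "Best-of-N"],
# )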
# ---------- Gradio ----------
target_languages = ["Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Chinese", "English", "Russian", "German", "Japanese", "Korean"]

with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    state = gr.State({})
    gr.Markdown("# Translation Demo: Multiple Translation Methods")
    gr.Markdown("Select the translation methods to run (you may pick several or all):")
    with gr.Row():
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Source Text",
                placeholder="Enter text...",
                lines=5
            )
            src_language_input = gr.Dropdown(
                choices=src_languages,
                value="Chinese",
                label="Source Language"
            )
            task_language_input = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Task (Target) Language"
            )
            max_iterations_input = gr.Number(label="Max Iterations", value=6)
            threshold_input = gr.Number(label="Threshold", value=0.7)
            good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
            translation_methods_input = gr.CheckboxGroup(
                choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
                value=["Original", "Plan2Align"],
                label="Translation Methods"
            )
            translate_button = gr.Button("Translate")
        with gr.Column(scale=2):
            original_output = gr.Textbox(
                label="Original Translation",
                lines=5,
                interactive=False
            )
            plan2align_output = gr.Textbox(
                label="Plan2Align Translation",
                lines=5,
                interactive=False
            )
            best_of_n_output = gr.Textbox(
                label="Best-of-N Translation",
                lines=5,
                interactive=False
            )
            mpc_output = gr.Textbox(
                label="MPC Translation",
                lines=5,
                interactive=False
            )
    translate_button.click(
        fn=process_text,
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            state
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
    )
    gr.Examples(
        examples=[
            ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"]],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Chinese", "Russian", 2, 0.7, 1, ["Original", "Plan2Align"]],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Chinese", "German", 2, 0.7, 1, ["Original", "Plan2Align"]],
            ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"]],
            ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"]]
        ],
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
        fn=process_text
    )
| gr.Markdown("## How It Works") | |
| gr.Markdown(""" | |
| 1. **Original Translation:** 利用固定提示生成候選,直接取首個候選作為原始翻譯。 | |
| 2. **Plan2Align Translation:** 採用 context alignment 和 self-rewriting 策略進行翻譯,適合長文翻譯。 | |
| 3. **Best-of-N Translation:** 重複生成多次候選,評分選出最佳翻譯,適合短文翻譯。 | |
| 4. **MPC Translation:** 以迭代改善策略,每輪生成候選後評分,並將最佳翻譯作為下一輪輸入,適合短文翻譯。 | |
| 若輸入文本超過 1024 tokens,Best-of-N 與 MPC 方法會回傳警告訊息。 | |
| """) | |
if __name__ == "__main__":
    demo.launch(share=True)