import os
import re
import sys
import time
import json
from itertools import cycle

import torch
import gradio as gr
import spaces
from urllib.parse import unquote
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields
from examples import examples as input_examples
from nuextract_logging import log_event

MAX_INPUT_SIZE = 10_000
MAX_NEW_TOKENS = 4_000
MAX_WINDOW_SIZE = 4_000
markdown_description = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
    <img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle; width: 200px; height: 50px;">
    <br>
    <ul>
        <li>NuMind is a startup developing custom information extraction solutions.</li>
        <li>NuExtract is a zero-shot model. See the blog posts for more info (<a href="https://numind.ai/blog/nuextract-a-foundation-model-for-structured-extraction">NuExtract</a>, <a href="https://numind.ai/blog/nuextract-1-5---multilingual-infinite-context-still-small-and-better-than-gpt-4o">NuExtract-v1.5</a>).</li>
        <li>We have started to deploy NuMind Enterprise to customize, serve, and monitor NuExtract privately. If that interests you, let's chat 😊.</li>
        <li><strong>Website</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li>
    </ul>

    <h1>NuExtract-v1.5</h1>
    <p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction.
    It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian).
    To use the model, provide an input text and a JSON template describing the information you need to extract.</p>
    <ul>
        <li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li>
    </ul>

    <i>⚠️ In this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i>
    <br>
    <i>⚠️ The model is trained to assume a valid JSON template. Attempts to use invalid JSON could lead to unpredictable results.</i>
</body>
</html>
"""
def highlight_words(input_text, json_output):
    """Wrap each extracted leaf value found in the input text in a colored <span>, one color per JSON path."""
    colors = cycle(["#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a", "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead"])
    color_map = {}
    highlighted_text = input_text

    leaves = extract_leaves(json_output)
    for path, value in leaves:
        path_key = tuple(path)
        if path_key not in color_map:
            color_map[path_key] = next(colors)
        color = color_map[path_key]

        # Escape the value, then let spaces in it match any whitespace run.
        # Note: re.escape no longer escapes spaces on Python 3.7+, so handle both forms.
        escaped_value = re.escape(value).replace(r'\ ', r'\s+').replace(' ', r'\s+')
        pattern = rf"(?<=[ \n\t]){escaped_value}(?=[ \n\t\.\,\?\:\;])"
        replacement = f"<span style='background-color: {color};'>{unquote(value)}</span>"
        highlighted_text = re.sub(pattern, replacement, highlighted_text, flags=re.IGNORECASE)

    return highlighted_text

def predict_chunk(text, template, current, model, tokenizer):
    """Run the model on a single text chunk, continuing from the current partial extraction."""
    current = clean_json_text(current)

    input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")

    output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS)[0], skip_special_tokens=True)
    return clean_json_text(output.split("<|output|>")[1])
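
# For illustration (derived from the f-string above; the values are made up and
# assume clean_json_text leaves this minimal template unchanged): on the first chunk,
# with the template '{"name": ""}' and the text "John works at Acme.", the assembled
# prompt is:
#
#   <|input|>
#   ### Template:
#   {"name": ""}
#   ### Current:
#   {"name": ""}
#   ### Text:
#   John works at Acme.
#
#   <|output|>{
#
# The trailing "{" primes the model to begin generating the JSON object directly.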

def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
    """Run extraction over a long document chunk by chunk, yielding (progress, current prediction, highlighted HTML) after each chunk."""
    # Split text into chunks of n tokens
    tokens = tokenizer.tokenize(text)
    chunks = split_document(text, window_size, overlap, tokenizer)

    # Iterate over text chunks
    prev = template
    full_pred = ""
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i}...")
        pred = predict_chunk(chunk, template, prev, model, tokenizer)

        # Handle broken output
        pred = handle_broken_output(pred, prev)

        # Create highlighted text
        try:
            highlighted_pred = highlight_words(text, json.loads(pred))
        except Exception:
            highlighted_pred = text

        # Attempt JSON parsing
        template_dict = None
        pred_dict = None
        try:
            template_dict = json.loads(template)
        except json.JSONDecodeError:
            pass
        try:
            pred_dict = json.loads(pred)
        except json.JSONDecodeError:
            pass

        # Sync empty fields
        if template_dict and pred_dict:
            synced_pred = sync_empty_fields(pred_dict, template_dict)
            synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)
        elif pred_dict:
            synced_pred = json.dumps(pred_dict, indent=4, ensure_ascii=False)
        else:
            synced_pred = pred

        # Return progress, current prediction, and updated HTML
        yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred

        # Carry this chunk's prediction forward as context for the next chunk
        prev = pred
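
# Illustrative usage outside Gradio (not executed here; `template` and `text` are
# placeholders): the function above is a generator, so a caller drains it and keeps
# the last prediction as the final merged extraction, e.g.
#
#   model = load_model()  # defined below
#   final_pred = None
#   for progress, prediction, html in sliding_window_prediction(template, text, model, tokenizer):
#       print(progress)
#       final_pred = prediction
#   print(final_pred)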

######
# Model is loaded here but will be moved to CUDA only when needed with ZeroGPU
model_name = "numind/NuExtract-v1.5"
auth_token = os.environ.get("HF_TOKEN") or False

# Load tokenizer in advance but not the model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)

# We define a function to load the model when needed
def load_model():
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 trust_remote_code=True,
                                                 torch_dtype=torch.bfloat16,
                                                 device_map="auto",
                                                 use_auth_token=auth_token)
    model.eval()
    return model

@spaces.GPU  # request a ZeroGPU GPU for the duration of this call (otherwise the `spaces` import above goes unused and .to("cuda") fails)
def gradio_interface_function(template, text, is_example):
    """Gradio handler: validate input length, load the model, and stream extraction results."""
    try:
        if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
            yield "", "Input text is too long for this Space. Download the model to use it without restrictions.", ""
            return  # End the function since there was an error

        # Load the model when needed
        model = load_model()

        # Initialize the sliding window prediction process
        prediction_generator = sliding_window_prediction(template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE)

        # Iterate over the generator to return values at each step
        for progress, full_pred, html_content in prediction_generator:
            # yield gr.update(value=chunk_info), gr.update(value=progress), gr.update(value=full_pred), gr.update(value=html_content)
            yield progress, full_pred, html_content

        # Conditionally log event if not an example and logging is configured
        if not is_example:
            try:
                log_event(text, template, full_pred)
            except Exception as e:
                print(f"Warning: Could not log event: {e}", file=sys.stderr)

    except Exception as e:
        error_message = f"Error processing request: {str(e)}"
        print(error_message, file=sys.stderr)
        yield "", error_message, ""

# Set up the Gradio interface
iface = gr.Interface(
    description=markdown_description,
    fn=gradio_interface_function,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter Template here...", label="Template"),
        gr.Textbox(lines=2, placeholder="Enter input Text here...", label="Input Text"),
        gr.Checkbox(label="Is Example?", visible=False),
    ],
    outputs=[
        gr.Textbox(label="Progress"),
        gr.Textbox(label="Model Output"),
        gr.HTML(label="Model Output with Highlighted Words"),
    ],
    examples=input_examples,
    # live=True  # Enable real-time updates
)

iface.launch(debug=True)