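"""Florence-2-based OCR for specimen label images.

Two-stage pipeline: Florence-2 performs the OCR, then a 4-bit quantized
Mistral-7B-Instruct model re-inserts the word spacing that the raw OCR
output often lacks.
"""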
import random, os
import copy
import warnings

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

try:
    from vouchervision.utils_LLM import SystemLoadMonitor
except ImportError:
    from utils_LLM import SystemLoadMonitor

warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")


class FlorenceOCR:
    # def __init__(self, logger, model_id='microsoft/Florence-2-base'):
    def __init__(self, logger, model_id='microsoft/Florence-2-large'):
        self.MAX_TOKENS = 1024
        self.logger = logger
        self.model_id = model_id
        self.monitor = SystemLoadMonitor(logger)

        self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        # self.model_id_clean = "mistralai/Mistral-7B-v0.3"
        self.model_id_clean = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
        self.tokenizer_clean = AutoTokenizer.from_pretrained(self.model_id_clean)
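
        # Florence-2's raw <OCR> output often runs words together; the 4-bit
        # quantized Mistral model below re-inserts the spacing (see the prompt
        # in ocr_florence).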
        # Configure BitsAndBytes 4-bit quantization for the cleanup model
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
        )
        self.model_clean = AutoModelForCausalLM.from_pretrained(
            self.model_id_clean,
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
        )

    def ocr_florence(self, image, task_prompt='<OCR>', text_input=None):
        self.monitor.start_monitoring_usage()

        # Open the image if a path is provided
        if isinstance(image, str):
            image = Image.open(image)

        if text_input is None:
            prompt = task_prompt
        else:
            prompt = task_prompt + text_input

        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        # Move input_ids and pixel_values to the same device as the model
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=self.MAX_TOKENS,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        # post_process_generation strips the special task tokens and returns a
        # dict keyed by the task prompt, e.g. {'<OCR>': 'recognized text'}
        parsed_answer_dict = self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )
        parsed_answer_text = parsed_answer_dict[task_prompt]

        # Prepare input for the second (cleanup) model
        inputs_clean = self.tokenizer_clean(
            f"Insert spaces into this text to make all the words valid. This text contains scientific names of plants, locations, habitat, coordinate words: {parsed_answer_text}",
            return_tensors="pt"
        )
        inputs_clean = {key: value.to(self.model_clean.device) for key, value in inputs_clean.items()}

        outputs_clean = self.model_clean.generate(**inputs_clean, max_new_tokens=self.MAX_TOKENS)
        text_with_spaces = self.tokenizer_clean.decode(outputs_clean[0], skip_special_tokens=True)

        # The decoded text echoes the prompt; keep only what follows the original OCR text
        response_start = text_with_spaces.find(parsed_answer_text)
        if response_start != -1:
            text_with_spaces = text_with_spaces[response_start + len(parsed_answer_text):].strip()

        print(text_with_spaces)

        self.monitor.stop_inference_timer()  # Starts tool timer too
        usage_report = self.monitor.stop_monitoring_report_usage()

        return text_with_spaces, parsed_answer_text, parsed_answer_dict, usage_report
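
    # A minimal sketch (hypothetical addition, not part of the original
    # pipeline): Florence-2 also supports region-aware OCR via the
    # '<OCR_WITH_REGION>' task prompt, which returns quadrilateral text boxes
    # alongside the recognized strings, e.g.
    # {'<OCR_WITH_REGION>': {'quad_boxes': [...], 'labels': [...]}}.
    def ocr_florence_with_regions(self, image):
        if isinstance(image, str):
            image = Image.open(image)
        task_prompt = '<OCR_WITH_REGION>'
        inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=self.MAX_TOKENS,
            do_sample=False,
            num_beams=3,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        return self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )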


def main():
    # img_path = '/home/brlab/Downloads/gem_2024_06_26__02-26-02/Cropped_Images/By_Class/label/1.jpg'
    img_path = 'D:/D_Desktop/BR_1839468565_Ochnaceae_Campylospermum_reticulatum_label.jpg'
    image = Image.open(img_path)

    # ocr = FlorenceOCR(logger=None, model_id='microsoft/Florence-2-base')
    ocr = FlorenceOCR(logger=None, model_id='microsoft/Florence-2-large')
    text_with_spaces, parsed_text, parsed_dict, usage_report = ocr.ocr_florence(image, task_prompt='<OCR>', text_input=None)
    print(text_with_spaces)


if __name__ == '__main__':
    main()
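
# Note: ocr_florence also accepts an image file path directly, e.g.
#   ocr.ocr_florence(img_path, task_prompt='<OCR>')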