Spaces:
Runtime error
Runtime error
import concurrent
import concurrent.futures
import io
import logging
import re

import cairosvg
import kagglehub
import torch
from lxml import etree
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

svg_constraints = kagglehub.package_import('metric/svg-constraints')
class NaiveModel:
    """Text-to-SVG generator backed by an Unsloth-loaded chat language model.

    ``predict`` prompts the model with a one-shot example, extracts the last
    ``<svg>...</svg>`` fragment from the decoded output, sanitizes it with
    ``enforce_constraints`` and checks that cairosvg can render it. Any
    failure or timeout yields ``self.default_svg`` instead of raising.
    """

    # Grammar check for a path's "d" attribute. Compiled once at class
    # creation instead of once per <path> element.
    _PATH_DATA_RE = re.compile(
        r'^'  # Start of string
        r'(?:'  # Non-capturing group for each command + numbers block
        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands
        r'\s*'  # Optional whitespace after command
        r'(?:'  # Non-capturing group for optional numbers
        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
        r')?'  # Numbers are optional (e.g. for Z command)
        r'\s*'  # Optional whitespace after numbers/command block
        r')+'  # One or more command blocks
        r'\s*'  # Optional trailing whitespace
        r'$'  # End of string
    )

    def __init__(self, model_name="unsloth/phi-4-unsloth-bnb-4bit", max_seq_length=2048, device="cuda"):
        """Load the model and tokenizer and set up prompt, fallback and limits.

        Parameters
        ----------
        model_name : str
            Model identifier passed to ``FastLanguageModel.from_pretrained``.
        max_seq_length : int
            Maximum sequence length passed to the model loader.
        device : str
            Device the tokenized prompt is moved to before generation.
        """
        self.device = device
        self.max_seq_length = max_seq_length
        self.load_in_4bit = True

        # Load the Unsloth Phi-4 model
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit
        )

        # Set up chat template
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="phi-4",
        )

        # Prepare model for inference
        FastLanguageModel.for_inference(self.model)

        # One-shot prompt; the single "{}" placeholder receives the description.
        self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
<constraints>
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
</constraints>
Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.
<description>"A red circle with a blue square inside"</description>
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
<circle cx="50" cy="50" r="40" fill="red"/>
<rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
```
<description>"{}"</description>
"""
        # Returned whenever generation, sanitizing, rendering or the timeout fails.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        self.timeout_seconds = 90

    def predict(self, description: str, max_new_tokens: int = 512) -> str:
        """Generate an SVG for ``description`` within ``self.timeout_seconds``.

        Parameters
        ----------
        description : str
            Text description substituted into the prompt template.
        max_new_tokens : int
            Upper bound on the number of tokens the model may generate.

        Returns
        -------
        str
            The sanitized SVG, or ``self.default_svg`` on timeout or error.
        """
        # The executor is deliberately NOT used as a context manager: leaving a
        # ``with`` block calls ``shutdown(wait=True)``, which would block until
        # the worker finishes and silently defeat the timeout.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(self._generate_svg, description, max_new_tokens)
        try:
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
            return self.default_svg
        except Exception as e:
            logging.error("An unexpected error occurred: %s", e)
            return self.default_svg
        finally:
            # Threads cannot be killed, so a timed-out generation keeps running
            # in the background; we only stop waiting for it.
            executor.shutdown(wait=False)

    def _generate_svg(self, description: str, max_new_tokens: int) -> str:
        """Run one generation attempt; return ``self.default_svg`` on any error."""
        try:
            # Format the prompt
            prompt = self.prompt_template.format(description)

            # Create messages in the format expected by the chat template
            messages = [
                {"role": "user", "content": prompt},
            ]

            # Tokenize the messages
            inputs = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.device)

            # Generate the output
            outputs = self.model.generate(
                input_ids=inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                temperature=1.0,
                min_p=0.1,
                do_sample=True,
            )

            # Decode the output
            output_decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logging.debug('Output decoded from model: %s', output_decoded)

            # NOTE(review): the decoded text presumably still contains the
            # prompt, so when the model emits no complete SVG the last match
            # may be the prompt's own example - confirm this is acceptable.
            matches = re.findall(r"<svg.*?</svg>", output_decoded, re.DOTALL | re.IGNORECASE)
            if not matches:
                return self.default_svg
            svg = matches[-1]

            logging.debug('Unprocessed SVG: %s', svg)
            svg = self.enforce_constraints(svg)
            logging.debug('Processed SVG: %s', svg)

            # Ensure the generated code can be converted by cairosvg
            cairosvg.svg2png(bytestring=svg.encode('utf-8'))
            return svg
        except Exception as e:
            logging.error('Exception during SVG generation: %s', e)
            return self.default_svg

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if constraints
            cannot be satisfied.
        """
        logging.info('Sanitizing SVG...')

        try:
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            logging.error('SVG string: %s', svg_string)
            return self.default_svg

        allowed = self.constraints.allowed_elements
        elements_to_remove = []
        for element in root.iter():
            # Processing instructions and similar nodes have a non-string tag
            # that etree.QName cannot parse; they are never allowed content.
            if not isinstance(element.tag, str):
                elements_to_remove.append(element)
                continue

            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in allowed:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes (collected first, deleted afterwards)
            attrs_to_remove = []
            for attr in element.attrib:
                attr_name = etree.QName(attr).localname
                if (
                    attr_name not in allowed[tag_name]
                    and attr_name not in allowed['common']
                ):
                    attrs_to_remove.append(attr)
            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes; iterate over a snapshot
            # because entries are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element
                if not self._PATH_DATA_RE.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal; the root has no parent and stays.
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            return etree.tostring(root, encoding='unicode')
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg
if __name__ == "__main__":
    # Manual smoke run: generate one SVG and rasterize it to output.png.
    generator = NaiveModel()
    svg_markup = generator.predict("a purple forest at dusk")

    try:
        # Render in memory first, then persist the PNG bytes to disk.
        rendered_png = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        with open("output.png", "wb") as png_file:
            png_file.write(rendered_png)
        print("SVG saved as output.png")
    except Exception as err:
        print(f"Error converting SVG to PNG: {err}")