# =============================
# codet5_summarizer.py (Updated)
# =============================
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

# Display names shown in the UI mapped to Hugging Face model IDs.
MODEL_OPTIONS = {
    "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
    "CodeT5 Base": "Salesforce/codet5-base",
    "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
    "Gemini (describeai)": "describeai/gemini",
    "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
}
class CodeT5Summarizer:
    def __init__(self, model_name=None):
        model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        hf_token = os.getenv("HF_TOKEN")
        if hf_token is None:
            raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

        # Prefer an encoder-decoder model (CodeT5 variants); fall back to a causal LM
        # for decoder-only checkpoints such as Mistral.
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token).to(self.device)
        except ValueError:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(self.device)

        self.is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)
    def preprocess_code(self, code):
        """Drop blank lines and comment-only lines, and collapse repeated spaces."""
        code = re.sub(r'\n\s*\n', '\n', code)
        lines = code.split('\n')
        clean = []
        in_docstring = False
        for line in lines:
            if '"""' in line or "'''" in line:
                in_docstring = not in_docstring
            if in_docstring or not line.strip().startswith('#'):
                clean.append(line)
        # Collapse runs of spaces after the first non-space character only, so that
        # leading indentation (used by extract_functions/extract_classes) is preserved.
        return re.sub(r'(?<=\S) +', ' ', '\n'.join(clean))
    def extract_functions(self, code):
        """Return (name, source) pairs for top-level functions and class methods."""
        function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
        function_matches = re.finditer(function_pattern, code, re.DOTALL)
        functions = []
        for match in function_matches:
            start_pos = match.start()
            function_name = match.group(1)
            lines = code[start_pos:].split('\n')
            # Find the first non-empty body line to measure its indentation.
            body_start = 1
            while body_start < len(lines) and not lines[body_start].strip():
                body_start += 1
            if body_start < len(lines):
                body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
                function_body = [lines[0]]
                i = 1
                while i < len(lines):
                    line = lines[i]
                    # Stop when a non-comment line dedents below the body indentation.
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                        break
                    function_body.append(line)
                    i += 1
                function_code = '\n'.join(function_body)
                functions.append((function_name, function_code))
        # Class method detection
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_code = code[start_pos:]
            method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
            for method_match in method_matches:
                if method_match.start() > 200:  # Only methods near the top of the class
                    break
                method_name = method_match.group(1)
                method_start = method_match.start()
                method_lines = class_code[method_start:].split('\n')
                body_start = 1
                while body_start < len(method_lines) and not method_lines[body_start].strip():
                    body_start += 1
                if body_start < len(method_lines):
                    body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
                    method_body = [method_lines[0]]
                    i = 1
                    while i < len(method_lines):
                        line = method_lines[i]
                        if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                            break
                        method_body.append(line)
                        i += 1
                    method_code = '\n'.join(method_body)
                    functions.append((f"{class_name}.{method_name}", method_code))
        return functions
    def extract_classes(self, code):
        """Return (name, source) pairs for class definitions, bounded by indentation."""
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        classes = []
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_lines = code[start_pos:].split('\n')
            body_start = 1
            while body_start < len(class_lines) and not class_lines[body_start].strip():
                body_start += 1
            if body_start < len(class_lines):
                body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip())
                class_body = [class_lines[0]]
                i = 1
                while i < len(class_lines):
                    line = class_lines[i]
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent:
                        break
                    class_body.append(line)
                    i += 1
                class_code = '\n'.join(class_body)
                classes.append((class_name, class_code))
        return classes
    def summarize(self, code, max_length=512):
        """Generate a natural-language summary for a single code snippet."""
        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            if self.is_encoder_decoder:
                output = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_length,
                    num_beams=4,
                    early_stopping=True,
                )
                return self.tokenizer.decode(output[0], skip_special_tokens=True)
            else:
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]
                output = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    do_sample=False,
                    num_beams=4,
                    early_stopping=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
                # Decoder-only models echo the prompt; decode only the newly generated tokens.
                return self.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
    def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
        """Summarize the whole file, then each extracted function and class."""
        preprocessed_code = self.preprocess_code(code)
        results = {
            "file_summary": None,
            "function_summaries": {},
            "class_summaries": {},
        }
        try:
            results["file_summary"] = self.summarize(preprocessed_code)
        except Exception as e:
            results["file_summary"] = f"Error generating file summary: {str(e)}"
        if summarize_functions:
            for function_name, function_code in self.extract_functions(preprocessed_code):
                try:
                    results["function_summaries"][function_name] = self.summarize(function_code)
                except Exception as e:
                    results["function_summaries"][function_name] = f"Error: {str(e)}"
        if summarize_classes:
            for class_name, class_code in self.extract_classes(preprocessed_code):
                try:
                    results["class_summaries"][class_name] = self.summarize(class_code)
                except Exception as e:
                    results["class_summaries"][class_name] = f"Error: {str(e)}"
        return results
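

# Example usage (a minimal sketch, not part of the original Space): assumes HF_TOKEN
# is exported in the environment and the default CodeT5 checkpoint can be downloaded.
# The sample snippet below is purely illustrative.
if __name__ == "__main__":
    sample_code = (
        "def add(a, b):\n"
        "    return a + b\n"
    )
    summarizer = CodeT5Summarizer()  # defaults to Salesforce/codet5-base-multi-sum
    summaries = summarizer.summarize_code(sample_code)
    print("File summary:", summaries["file_summary"])
    for name, summary in summaries["function_summaries"].items():
        print(f"{name}: {summary}")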