# AfshinMA's picture
# Update app.py
# 855c27c verified
import functools
import os

import gradio as gr
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
def clean_text(text: str) -> str:
    """Strip or substitute punctuation that should not reach the summarizer.

    '?' becomes '.', a newline becomes a single space, and the characters
    ``! ( ) : / \\`` are dropped entirely.
    """
    # One translation pass: the first two arguments map char -> char,
    # the third lists the characters to delete.
    table = str.maketrans('?\n', '. ', '!():/\\')
    return text.translate(table)
def chunk_text(article: str, max_chunk_size: int = 500) -> list[str]:
    """Split *article* into chunks of whole sentences, each at most
    ``max_chunk_size`` words where possible.

    Sentences are delimited by ``'. '``. A sentence that on its own exceeds
    the limit is kept intact as a single chunk rather than being cut.

    Args:
        article: Text to split.
        max_chunk_size: Word budget per chunk.

    Returns:
        The chunks in order; an empty list when *article* has no words.
    """
    fragments = article.split('. ')
    last_index = len(fragments) - 1

    chunks: list[str] = []
    current: list[str] = []   # sentences of the chunk being built
    current_words = 0         # word count of `current` only (per-chunk budget)

    for index, fragment in enumerate(fragments):
        sentence = fragment.strip()
        if not sentence:
            continue  # blank fragment (empty input, repeated separators)

        # split() consumed the '.' of every fragment but the last; the last
        # one only needs a period if the text did not already end with one.
        if index < last_index or not sentence.endswith('.'):
            sentence += '.'

        word_count = len(sentence.split())
        # Close the current chunk before it would go over budget.
        if current and current_words + word_count > max_chunk_size:
            chunks.append(' '.join(current))
            current = []
            current_words = 0

        current.append(sentence)
        current_words += word_count

    if current:
        chunks.append(' '.join(current))
    return chunks
@functools.lru_cache(maxsize=1)
def _load_summarizer():
    """Load the BART tokenizer and model once; later calls reuse the pair."""
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    return tokenizer, model


def summarize_text(text: str, max_chunk_size: int = 400, max_length: int = 130) -> str:
    """Summarize *text* chunk by chunk with the facebook/bart-large-cnn model.

    The text is cleaned with ``clean_text``, split with ``chunk_text``, and
    every chunk is summarized separately; the per-chunk summaries are joined
    with newlines.

    Args:
        text: Text to summarize.
        max_chunk_size: Word budget handed to ``chunk_text``.
        max_length: ``max_length`` passed to ``model.generate`` for each chunk.

    Returns:
        The combined summary, or a prompt message when *text* is blank.
    """
    if not text.strip():  # Handle empty input
        return "Please provide some text to summarize."

    # Callers (e.g. a UI slider) may hand over float values; generation
    # lengths must be integers.
    max_chunk_size = int(max_chunk_size)
    max_length = int(max_length)
    # Keep the lower bound from exceeding the requested upper bound.
    min_length = min(30, max_length)

    chunks = chunk_text(clean_text(text), max_chunk_size)

    # Loading the model is expensive; the helper caches it across calls.
    tokenizer, model = _load_summarizer()

    pieces = []
    for chunk in chunks:
        if not chunk:
            continue
        # NOTE(review): the "summarize: " prefix is kept from the original;
        # confirm this checkpoint actually benefits from a task prefix.
        inputs = tokenizer.encode(
            "summarize: " + chunk,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
        )
        summary_ids = model.generate(
            inputs,
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        pieces.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

    return "\n".join(pieces).strip()
def load_texts(file_paths: list[str]) -> list[str]:
    """Read each file as UTF-8 text and return the contents in input order.

    A path that does not exist is reported on stdout and contributes an
    empty string, so the result always has one entry per input path.
    """
    contents = []
    for path in file_paths:
        body = ""  # fallback used when the file is missing
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                body = handle.read()
        except FileNotFoundError:
            print(f"File not found: {path}")
        contents.append(body)
    return contents
def main():
    """Build the Gradio summarizer interface and launch it.

    Example inputs are read from ``texts/sample1.txt`` and
    ``texts/sample2.txt`` located next to this file.
    """
    # Resolve the example files relative to this script, not the CWD.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    sample_files = [
        os.path.join(base_dir, r'./texts/sample1.txt'),
        os.path.join(base_dir, r'./texts/sample2.txt'),
    ]
    samples = load_texts(sample_files)

    # Input widgets, in the positional order summarize_text expects.
    text_input = gr.TextArea(
        label='Input Text',
        lines=3,
        max_lines=7,
        placeholder="Enter text here...",
        max_length=5000,
    )
    chunk_slider = gr.Slider(
        50, 500, step=10, value=400,
        label="Max Chunk Size", info="Choose between 50 and 500",
    )
    length_slider = gr.Slider(
        30, 150, step=10, value=130,
        label="Max Length of Summary", info="Choose between 30 and 150",
    )

    app = gr.Interface(
        title="Text Summarizer",
        fn=summarize_text,
        inputs=[text_input, chunk_slider, length_slider],
        outputs=gr.Textbox(label="Summary"),
        examples=samples,
        theme="default",
        css=".footer{display:none !important}",
    )
    app.launch(share=True, debug=True)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()