Spaces:
Running
Running
Abhipsha Das
committed on
add files
Browse files- data/databases/README.md +32 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/eval/__init__.py +0 -0
- src/eval/metrics.py +87 -0
- src/processing/__init__.py +0 -0
- src/processing/__pycache__/__init__.cpython-311.pyc +0 -0
- src/processing/__pycache__/extractions.cpython-311.pyc +0 -0
- src/processing/__pycache__/generate.cpython-311.pyc +0 -0
- src/processing/extractions.py +65 -0
- src/processing/generate.py +226 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- src/utils/__pycache__/utils.cpython-311.pyc +0 -0
- src/utils/utils.py +155 -0
data/databases/README.md
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- This folder contains all the SQL databases for the different processed data along with their raw data.
|
| 2 |
+
|
| 3 |
+
- The databases are named after the arXiv category and the format of the generated data.
|
| 4 |
+
|
| 5 |
+
Each file in this folder is a database containing 2 tables:
|
| 6 |
+
- **papers**
|
| 7 |
+
|
| 8 |
+
The papers data from the `raw` folder that was fed to the model.
|
| 9 |
+
|
| 10 |
+
SCHEMA:
|
| 11 |
+
- paper_id TEXT PRIMARY KEY,
|
| 12 |
+
- abstract TEXT,
|
| 13 |
+
- authors TEXT,
|
| 14 |
+
- primary_category TEXT,
|
| 15 |
+
- url TEXT,
|
| 16 |
+
- updated_on TEXT,
|
| 17 |
+
- sentence_count INTEGER
|
| 18 |
+
|
| 19 |
+
- **predictions**
|
| 20 |
+
|
| 21 |
+
The corresponding model generations stored in the `results` folder.
|
| 22 |
+
|
| 23 |
+
SCHEMA:
|
| 24 |
+
- id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 25 |
+
- paper_id TEXT,
|
| 26 |
+
- sentence_index INTEGER,
|
| 27 |
+
- tag_type TEXT,
|
| 28 |
+
- concept TEXT,
|
| 29 |
+
- FOREIGN KEY (paper_id) REFERENCES papers(paper_id)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
To query any database, run the `sqlite3` command-line tool in your terminal with the database file as its argument (e.g. `sqlite3 <database-file>`).
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (155 Bytes). View file
|
|
|
src/eval/__init__.py
ADDED
|
File without changes
|
src/eval/metrics.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def classify_predictions(gold: dict, pred: dict, union=False) -> tuple[int, int, int]:
    """Count true positives, false positives, and false negatives for one example.

    Args:
        gold: mapping of tag name -> list of gold phrases.
        pred: mapping of tag name -> list of predicted phrases.
        union: if True, disregard tag types and compare the union of all
            phrases across tags.

    Returns:
        (n_tp, n_fp, n_fn) counts as integers.  (The previous annotation
        claimed ``Dict[str, float]``, which did not match the actual return.)
    """
    n_tp = 0
    n_fp = 0
    n_fn = 0
    if union:
        # Pool phrases across all tags so a phrase counts as correct
        # regardless of which tag it was assigned to.
        gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
        pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
        n_tp = len(gold_phrases & pred_phrases)
        n_fp = len(pred_phrases - gold_phrases)
        n_fn = len(gold_phrases - pred_phrases)
        return n_tp, n_fp, n_fn

    # Per-tag comparison: a phrase only counts as correct under the same tag.
    for tag in set(gold.keys()).union(pred.keys()):
        gold_phrases = set(gold.get(tag, []))
        pred_phrases = set(pred.get(tag, []))

        n_tp += len(gold_phrases & pred_phrases)
        n_fp += len(pred_phrases - gold_phrases)
        n_fn += len(gold_phrases - pred_phrases)

    return n_tp, n_fp, n_fn
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _prf(n_tp: int, n_fp: int, n_fn: int) -> tuple[float, float, float]:
    """Precision, recall, and F1 from raw counts, each rounded to 4 decimals.

    F1 is computed from the already-rounded precision/recall, matching the
    original reporting behavior.  Zero denominators yield 0.
    """
    precision = round(n_tp / (n_tp + n_fp) if (n_tp + n_fp) > 0 else 0, 4)
    recall = round(n_tp / (n_tp + n_fn) if (n_tp + n_fn) > 0 else 0, 4)
    f1 = round(
        (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall) > 0
            else 0
        ),
        4,
    )
    return precision, recall, f1


def compute_metrics(running_time, pred_times, runtype, eval_metrics=None):
    """Assemble timing metrics, plus classification metrics for eval runs.

    Args:
        running_time: total wall-clock time in seconds.
        pred_times: list of per-sentence prediction response times.
        runtype: when "eval" (and eval_metrics is given), precision/recall/F1
            metrics are included.
        eval_metrics: 6-tuple (n_tp, n_fp, n_fn, n_tp_union, n_fp_union,
            n_fn_union) of classification counts.

    Returns:
        dict mapping metric name -> rounded value.
    """
    metrics = {}
    metrics["avg_pred_response_time_per_sentence"] = (
        round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
    )
    metrics["total_time"] = round(running_time, 4)

    if runtype == "eval" and eval_metrics is not None:
        n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics

        # The per-tag and union variants share one helper instead of the
        # previous duplicated round(...) blocks.
        precision, recall, f1 = _prf(n_tp, n_fp, n_fn)
        union_precision, union_recall, union_f1 = _prf(
            n_tp_union, n_fp_union, n_fn_union
        )

        metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "union_precision": union_precision,
                "union_recall": union_recall,
                "union_f1": union_f1,
            }
        )

    return metrics
|
src/processing/__init__.py
ADDED
|
File without changes
|
src/processing/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
src/processing/__pycache__/extractions.cpython-311.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
src/processing/__pycache__/generate.cpython-311.pyc
ADDED
|
Binary file (9.78 kB). View file
|
|
|
src/processing/extractions.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from bs4 import BeautifulSoup
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# TODO: review the functions here
|
| 10 |
+
def extract_all_tagged_phrases(text: str) -> Dict[str, List[str]]:
    """Collect, per tag name, the de-duplicated phrases enclosed by HTML-style tags.

    Phrase text is whitespace-normalized and has quotes/backslashes escaped so
    it can be embedded inside JSON strings later.
    """
    collected = defaultdict(list)

    for node in BeautifulSoup(text, "html.parser").find_all(True):
        if not node.name:
            continue
        # Join the tag's text fragments, then collapse runs of whitespace.
        phrase = re.sub(r"\s+", " ", " ".join(node.stripped_strings).strip())
        # Escape stray backslashes (those not already escaping a quote or a
        # backslash), then escape double quotes, for JSON embedding.
        phrase = re.sub(r'(?<!\\)\\(?!["\\])', r"\\\\", phrase)
        phrase = phrase.replace('"', '\\"')
        if phrase:  # skip tags with no text content
            collected[node.name].append(phrase)

    # dict.fromkeys drops duplicates while keeping first-seen order.
    return {name: list(dict.fromkeys(vals)) for name, vals in collected.items()}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def extract_prediction(schema: dict, prediction: str, kind: str = "json") -> dict:
    """Parse a model's raw text prediction into a tag -> phrases dict.

    kind="json": extracts the first {...} span from the text and JSON-decodes
    it, applying a series of regex repairs for common model formatting
    mistakes; returns {} if no JSON region is found or parsing fails twice.
    kind="readable": parses "tag: phrase, phrase" lines for tags in `schema`.

    Raises:
        ValueError: for an unrecognized `kind`.
    """
    pred = {}
    if kind == "json":
        # Greedy match: spans from the first "{" to the last "}" in the text.
        json_match = re.search(r"\{[\s\S]+\}", prediction)
        if json_match:
            json_str = json_match.group(0)
            # Collapse hyphenated terms with $/backslash markup into plain
            # "word-word".  NOTE(review): presumably targets TeX-styled model
            # output like "foo-$\bar$" — confirm intended cases.
            json_str = re.sub(r"(\w+)-\$?\\?(\w+)\$?", r"\1-\2", json_str)
            # Undo escaped double quotes emitted inside the JSON body.
            json_str = json_str.replace('\\"', '"')
            # Insert commas the model dropped between adjacent members.
            json_str = re.sub(r'}\s*"', '}, "', json_str)
            json_str = re.sub(r']\s*"', '], "', json_str)
            try:
                pred = json.loads(json_str)
            except json.JSONDecodeError as e:
                logging.warning(f"Failed to parse JSON: {json_str}")
                logging.warning(f"Error: {str(e)}")

                try:
                    # Second attempt: strip trailing commas and convert
                    # single-quoted strings to double-quoted ones.
                    json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
                    json_str = re.sub(r"(?<![\w'])'|'(?![\w'])", '"', json_str)
                    pred = json.loads(json_str)
                except json.JSONDecodeError:
                    # Give up; pred stays {} so callers see an empty result.
                    logging.error(
                        f"Failed to parse JSON even after attempted fixes: {json_str}"
                    )
    elif kind == "readable":
        # Lines of the form "<tag>: v1, v2" where <tag> is a schema key.
        # NOTE(review): schema keys are interpolated unescaped into the regex.
        match = re.findall(
            rf'^({"|".join(list(schema.keys()))}): (.+)$',
            prediction,
            flags=re.MULTILINE,
        )
        pred = {tag: values.split(", ") for tag, values in match}
    else:
        raise ValueError(f"Invalid kind: {kind}")

    return pred
|
src/processing/generate.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
# import spacy
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from config import (
|
| 9 |
+
DEFAULT_FEW_SHOT_NUM,
|
| 10 |
+
DEFAULT_FEW_SHOT_SELECTION,
|
| 11 |
+
DEFAULT_TEMPERATURE,
|
| 12 |
+
DEFAULT_TOP_P,
|
| 13 |
+
DEFAULT_KIND,
|
| 14 |
+
)
|
| 15 |
+
from typing import List, Dict, Tuple, Union
|
| 16 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
| 17 |
+
|
| 18 |
+
from .extractions import extract_all_tagged_phrases
|
| 19 |
+
|
| 20 |
+
# nlp = spacy.load("en_core_web_sm")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# TODO: run with constituency tests
|
| 24 |
+
# TODO: review instruction and system level prompt (currently they are repetitive)
|
| 25 |
+
def get_sentences(text: str) -> List[str]:
    """Split *text* into sentences on ". " boundaries.

    Deliberately naive: the spaCy-based splitter below was abandoned because
    it produced unequal sentence counts between raw and tagged abstracts.
    """
    # TODO: spacy splitting results in unequal lengths
    # doc = nlp(text)
    # sentences = [sent.text.strip() for sent in doc.sents]
    # sentences = [s for s in sentences if s]
    # return sentences

    split_sentences = text.split(". ")
    return split_sentences
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def format_instance(sentence: str, extraction: Union[str, None]) -> str:
    """Render one few-shot instance: the sentence plus its extractions block.

    With extraction=None, the "Extractions:" header is left open-ended so the
    model completes it.
    """
    if extraction is not None:
        tail = f"Extractions:\n{extraction}\n"
    else:
        tail = "Extractions:\n"
    return f"Sentence: {sentence}\n" + tail
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def generate_instructions(schema: dict, kind: str = DEFAULT_KIND) -> str:
    """Build the instruction header that presents the tagging schema.

    kind="json" embeds the schema as pretty-printed JSON; kind="readable"
    renders "tag: description" lines.  Raises ValueError otherwise.
    """
    header = (
        "The following schema is provided to tag the title and abstract of a "
        "given scientific paper as shown in the examples:\n"
    )

    if kind == "json":
        body = f"{json.dumps(schema, indent=2)}\n\n"
    elif kind == "readable":
        lines = [f"{tag}: {description}\n" for tag, description in schema.items()]
        body = "".join(lines) + "\n"
    else:
        raise ValueError(f"Invalid kind: {kind}")

    return header + body
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def generate_demonstrations(
    examples: List[dict],
    kind: str = DEFAULT_KIND,
    num_examples: int = DEFAULT_FEW_SHOT_NUM,
    selection: str = DEFAULT_FEW_SHOT_SELECTION,
) -> str:
    """Build the few-shot demonstration string from tagged example abstracts.

    For each example, pairs raw sentences with their tagged counterparts,
    picks up to `num_examples` pairs using the `selection` strategy, and
    formats each pair as a Sentence/Extractions instance.

    Raises:
        ValueError: for an unknown `selection` or `kind` (also raised by
            zip(strict=True) if the raw and tagged abstracts split into
            different sentence counts).
    """
    demonstration_parts = []
    for example in examples:
        sentences = get_sentences(example["abstract"])
        tagged_sentences = get_sentences(example["tagged_abstract"])
        # strict=True enforces that both abstracts split into the same
        # number of sentences.
        paired_sentences = list(zip(sentences, tagged_sentences, strict=True))

        if selection == "random":
            # min() guards against asking for more pairs than exist.
            selected_pairs = random.sample(
                paired_sentences, min(num_examples, len(paired_sentences))
            )
        elif selection == "first":
            selected_pairs = paired_sentences[:num_examples]
        elif selection == "last":
            selected_pairs = paired_sentences[-num_examples:]
        elif selection == "middle":
            start = max(0, (len(paired_sentences) - num_examples) // 2)
            selected_pairs = paired_sentences[start : start + num_examples]
        elif selection == "distributed":
            # Evenly spaced picks across the whole abstract.
            step = max(1, len(paired_sentences) // num_examples)
            selected_pairs = paired_sentences[::step][:num_examples]
        elif selection == "longest":
            # Length is measured on the raw (untagged) sentence.
            selected_pairs = sorted(
                paired_sentences, key=lambda x: len(x[0]), reverse=True
            )[:num_examples]
        elif selection == "shortest":
            selected_pairs = sorted(paired_sentences, key=lambda x: len(x[0]))[
                :num_examples
            ]
        else:
            raise ValueError(f"Invalid selection method: {selection}")

        for sentence, tagged_sentence in selected_pairs:
            tag_to_phrase = extract_all_tagged_phrases(tagged_sentence)
            if kind == "json":
                extractions = f"{json.dumps(tag_to_phrase, indent=2)}\n"
            elif kind == "readable":
                extractions = "".join(
                    f"{tag}: {', '.join(phrase)}\n"
                    for tag, phrase in tag_to_phrase.items()
                )
            else:
                raise ValueError(f"Invalid kind: {kind}")

            demonstration_parts.append(format_instance(sentence, extractions))

    return "".join(demonstration_parts)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def generate_prefix(instructions: str, demonstrations: str) -> str:
    """Concatenate the instruction header and demonstrations into the prompt prefix."""
    return instructions + demonstrations
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def generate_prediction(
    model,
    tokenizer,
    prefix: str,
    input: str,
    kind: str,
    system_prompt: str = f"You are an assistant who tags papers according to given schema and "
    "only returns the tagged phrases in the format as provided in the examples "
    "without repeating anything else.",
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> str:
    """Generate one tagging prediction for `input` appended to the prompt `prefix`.

    Builds a system+user chat, tokenizes it with the model's chat template,
    samples a completion, and returns only the newly generated text.

    NOTE(review): `kind` is accepted but never used in this function —
    confirm whether it was meant to influence the prompt.
    """
    prompt = prefix + input
    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {"role": "user", "content": prompt},
    ]

    # Tokenize via the model's chat template; tensors move to the model's device.
    input_ids = tokenizer.apply_chat_template(
        messages,
        # add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Stop on EOS or on the "<|eot_id|>" end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=1200,
        eos_token_id=terminators,
        # num_beams=8,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # Drop the echoed prompt: keep only tokens generated after the input.
    response = outputs[0][input_ids.shape[-1] :]
    prediction_response = tokenizer.decode(response, skip_special_tokens=True)

    return prediction_response
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def batch_generate_prediction(
    model,
    tokenizer,
    prefix: str,
    input_ids: torch.Tensor,
    kind: str,
    system_prompt: str = "You are an assistant who tags papers according to given schema and "
    "only returns the tagged phrases in the format as provided in the examples "
    "without repeating anything else.",
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    max_new_tokens: int = 1200,
    batch_size: int = 1,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> List[str]:
    """Generate predictions for a batch of pre-tokenized inputs.

    Each row of `input_ids` is decoded back to text, wrapped in a chat
    message together with `prefix`, re-tokenized via the chat template, and
    generated in mini-batches of `batch_size`.

    NOTE(review): `kind` is unused here, and the decode/re-encode round trip
    discards the caller's original tokenization — confirm both are intended.
    """
    all_predictions = []

    # Prepare system message
    system_message = {"role": "system", "content": system_prompt}

    # Walk the batch dimension in steps of batch_size.
    for i in range(0, input_ids.size(0), batch_size):
        batch_input_ids = input_ids[i : i + batch_size]

        # Decode each pre-tokenized row back to text and rebuild chat messages.
        batch_messages = [
            [
                system_message,
                {
                    "role": "user",
                    "content": prefix + tokenizer.decode(ids, skip_special_tokens=True),
                },
            ]
            for ids in batch_input_ids
        ]

        # Re-tokenize the chat-formatted batch, padded to a common length.
        batch_input_ids = tokenizer.apply_chat_template(
            batch_messages, return_tensors="pt", padding=True, truncation=True
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                # NOTE(review): if pad_token_id equals eos_token_id, this mask
                # also hides genuine EOS tokens — verify the tokenizer setup.
                attention_mask=batch_input_ids.ne(tokenizer.pad_token_id),
            )

        for output in outputs:
            # Keep only tokens generated after the (padded) prompt length.
            response = output[batch_input_ids.size(1) :]
            prediction_response = tokenizer.decode(response, skip_special_tokens=True)
            all_predictions.append(prediction_response)

        # Release cached GPU memory between mini-batches.
        torch.cuda.empty_cache()

    return all_predictions
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (161 Bytes). View file
|
|
|
src/utils/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (8.47 kB). View file
|
|
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from config import DEFAULT_RES_DIR as RES_DIR
|
| 7 |
+
|
| 8 |
+
from accelerate import (
|
| 9 |
+
infer_auto_device_map,
|
| 10 |
+
init_empty_weights,
|
| 11 |
+
Accelerator,
|
| 12 |
+
load_checkpoint_and_dispatch,
|
| 13 |
+
)
|
| 14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def save_results(
    out_dir_path,
    all_inputs,
    gold_tags,
    predicted_responses,
    predicted_tags,
    metrics,
    runtype,
    append=False,
):
    """Write prompts, raw responses, predictions, and metrics under RES_DIR/out_dir_path.

    Args:
        out_dir_path: results subdirectory, relative to RES_DIR.
        all_inputs: formatted prompts, aligned with the other sequences.
        gold_tags: ground-truth tag dicts (persisted when runtype == "eval").
        predicted_responses: raw model output strings.
        predicted_tags: parsed tag dicts extracted from the responses.
        metrics: metrics dict written to metrics.json (always overwritten).
        runtype: "eval" additionally writes ground_truth.json.
        append: if True, append to the .txt files and extend the JSON files
            in place instead of overwriting them.
    """
    mode = "a" if append else "w"

    with open(
        os.path.join(RES_DIR, out_dir_path, "prompts.txt"), mode, encoding="utf-8"
    ) as f:
        for input, gold_tag, pred_response, pred_tag in zip(
            all_inputs, gold_tags, predicted_responses, predicted_tags
        ):
            f.write(f"{input}\n")
            f.write(f"True Tag: {gold_tag}\n")
            f.write(f"Predicted Response: {pred_response}\n")
            f.write(f"Predicted Tag: {pred_tag}\n")
            f.write("#" * 50 + "\n")

    with open(
        os.path.join(RES_DIR, out_dir_path, "predicted_responses.txt"),
        mode,
        encoding="utf-8",
    ) as f:
        for response in predicted_responses:
            f.write(f"{response}\n")
            f.write("#" * 50 + "\n")

    if append:
        # Read-modify-write: extend the stored list, rewrite from the start,
        # and truncate any leftover bytes from the previous (longer) content.
        with open(os.path.join(RES_DIR, out_dir_path, "predictions.json"), "r+") as f:
            data = json.load(f)
            data["predicted_tags"].extend(predicted_tags)
            f.seek(0)
            json.dump(data, f, indent=4)
            f.truncate()
    else:
        with open(os.path.join(RES_DIR, out_dir_path, "predictions.json"), "w") as f:
            json.dump({"predicted_tags": predicted_tags}, f, indent=4)

    if runtype == "eval":
        if append:
            with open(
                os.path.join(RES_DIR, out_dir_path, "ground_truth.json"), "r+"
            ) as f:
                data = json.load(f)
                # BUG FIX: previously extended with `gold_tag`, the stale loop
                # variable left over from the prompts.txt loop above, instead
                # of the full `gold_tags` parameter.
                data["gold_tags"].extend(gold_tags)
                f.seek(0)
                json.dump(data, f, indent=4)
                f.truncate()
        else:
            with open(
                os.path.join(RES_DIR, out_dir_path, "ground_truth.json"), "w"
            ) as f:
                json.dump({"gold_tags": gold_tags}, f, indent=4)

    # Metrics are always overwritten, even in append mode.
    with open(os.path.join(RES_DIR, out_dir_path, "metrics.json"), "w") as f:
        json.dump({"metrics": metrics, "prompt_file": "prompts.txt"}, f, indent=4)

    logging.info(f"Results saved in: {os.path.join(RES_DIR, out_dir_path)}")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def save_best_config(metrics, config):
    """Persist the run's config to best_config.json if its precision beats the stored best.

    When no best_config.json exists yet, the current run becomes the best.
    The file is rewritten on every call (keeping the previous best if it wins).
    """
    best_config_path = os.path.join(RES_DIR, "best_config.json")
    candidate = {"metrics": metrics, "config": config}

    if not os.path.exists(best_config_path):
        best_config = candidate
    else:
        with open(best_config_path, "r") as f:
            best_config = json.load(f)
        # Strictly greater precision replaces the stored best; ties keep it.
        if metrics["precision"] > best_config["metrics"]["precision"]:
            best_config = candidate

    with open(best_config_path, "w") as f:
        json.dump(best_config, f, indent=4)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_sweep_config(config_path="sweep_config.json"):
    """Load and return the hyperparameter sweep configuration as a dict."""
    with open(config_path, "r") as f:
        config = json.load(f)
    return config
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# def load_model_and_tokenizer(model_id: str):
|
| 103 |
+
# accelerator = Accelerator()
|
| 104 |
+
|
| 105 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
| 106 |
+
# # device_map = infer_auto_device_map(model, max_memory=max_memory)
|
| 107 |
+
|
| 108 |
+
# if tokenizer.pad_token_id is None:
|
| 109 |
+
# tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 110 |
+
|
| 111 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
| 112 |
+
# model_id,
|
| 113 |
+
# torch_dtype=torch.bfloat16,
|
| 114 |
+
# device_map="auto",
|
| 115 |
+
# token=os.getenv("HF_TOKEN"),
|
| 116 |
+
# )
|
| 117 |
+
|
| 118 |
+
# model, tokenizer = accelerator.prepare(model, tokenizer)
|
| 119 |
+
|
| 120 |
+
# return model, tokenizer
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def clear_cuda_cache():
    """Release cached GPU memory and reset peak-memory statistics.

    No-op when CUDA is unavailable.  Uses the torch.cuda.memory.reset_max_*
    helpers exactly as the surrounding code does; NOTE(review): these are
    deprecated in newer torch releases — confirm the pinned torch version.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.memory.reset_max_memory_allocated()
    torch.cuda.memory.reset_max_memory_cached()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_model_and_tokenizer(model_id):
    """Load a causal LM and its tokenizer from the Hugging Face hub.

    Reads the HF_TOKEN environment variable for authentication, enables TF32
    matmuls, loads weights in float16, and places the model automatically
    with device_map="auto".

    Returns:
        (model, tokenizer) tuple.
    """
    hf_token = os.getenv("HF_TOKEN")

    # Memory-saving / throughput options.
    torch.cuda.empty_cache()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Left padding; a pad token is required for batched generation, so fall
    # back to EOS when the tokenizer defines none.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, padding_side="left", use_auth_token=hf_token
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Fetch the model configuration explicitly, then load the weights.
    model_config = AutoConfig.from_pretrained(model_id, use_auth_token=hf_token)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=model_config,
        torch_dtype=torch.float16,
        use_auth_token=hf_token,
        device_map="auto",
    )

    return model, tokenizer
|