Upload 3 files
- pubchem_experiment/data_preprocess.py +197 -0
- pubchem_experiment/make_predictions.py +172 -0
- pubchem_experiment/metrics.py +163 -0
pubchem_experiment/data_preprocess.py
ADDED
@@ -0,0 +1,197 @@
import json
import sys

import pandas as pd
import tqdm
import swifter  # noqa: F401 -- registers the .swifter accessor on pandas objects
from rdkit import Chem

# Disable RDKit informational and warning messages
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

PUBCHEM_DIR = ''      # todo: pubchem_path + 'pubchem24/'
FSMOL_UID_PATH = ''   # todo: fsmol_path + '/fsmol/fsmol_train_accession_keys.json'
PROT_CLASS_PATH = ''  # todo: chembl_path + 'chembl33/uniprot_pclass_mapping.csv'
MHNFS_PATH = ''       # todo: mhnfs_path + '/mhnfs'

sys.path.append(MHNFS_PATH)
from src.data_preprocessing.utils import Standardizer


class PubChemFilter:

    def __init__(self, pubchem_dir, fsmol_uid_path, prot_class_path, mhnfs_path, debug=False):
        self.pubchem_dir = pubchem_dir
        self.fsmol_uid_path = fsmol_uid_path
        self.prot_class_path = prot_class_path
        self.mhnfs_path = mhnfs_path
        self.debug = debug

    def load_and_filter_assays(self):
        """
        Load PubChem assay data from file and filter it:
        1. Drop all assays without protein accession keys
        2. Drop all assays linked to multiple accession keys
        3. Drop all assays with accession keys in the FS-Mol training data

        Returns:
            df_assays (pd.DataFrame)
        """
        print('Load assays...')
        df_assays = pd.read_table(
            f'{self.pubchem_dir}/bioassays.tsv.gz',
            usecols=['AID', 'UniProts IDs']
        ).rename(columns={'UniProts IDs': 'UID'})

        # Load FS-Mol training data accession keys
        with open(self.fsmol_uid_path, 'r') as f:
            fs_train_targets = json.load(f).values()
        fs_train_targets = list(set(key for sublist in fs_train_targets for key in sublist))

        print('Filter assays...')
        df_assays = df_assays.dropna(subset=['UID'])
        df_assays = df_assays[~df_assays['UID'].str.contains(r'\|')]  # multiple accession keys
        df_assays = df_assays[~df_assays['UID'].str.contains('|'.join(fs_train_targets))]
        self.df_assays = df_assays

    def load_and_filter_bioactivities(self, chunk_size=10_000_000):
        """
        Load bioactivity data in chunks and filter out datapoints with
        1. assay not in aids
        2. outcome not 'Active'/'Inactive'
        """
        print('Load bioactivities...')
        aids = self.df_assays.AID.tolist()
        filtered_chunks = []
        for chunk in pd.read_csv(f'{self.pubchem_dir}/bioactivities.tsv.gz', sep='\t',
                                 chunksize=chunk_size, usecols=['AID', 'CID', 'Activity Outcome']):
            filtered_chunk = chunk[chunk['AID'].isin(aids)]
            filtered_chunk = filtered_chunk[filtered_chunk['Activity Outcome'].isin(['Inactive', 'Active'])]
            filtered_chunks.append(filtered_chunk)
            if self.debug:
                break  # For debugging
        df_bio = pd.concat(filtered_chunks)
        df_bio = df_bio[df_bio.CID.notna()]
        df_bio['Activity'] = df_bio['Activity Outcome'].swifter.apply(lambda x: 1 if x == 'Active' else 0)
        self.df_bio = df_bio.drop('Activity Outcome', axis=1).astype(int)

    def merge_assay_and_activity_data(self):
        print('Merge...')
        self.df = self.df_bio.merge(self.df_assays, on='AID', how='left')
        convert_dict = {col: 'int32' if col != 'UID' else 'str' for col in self.df.columns}
        self.df = self.df.astype(convert_dict)
        del self.df_assays, self.df_bio

    def drop_hts_assays(self):
        print('Drop HTS assays...')
        aid_counts = self.df.groupby('AID').size()
        filtered_aids = aid_counts[aid_counts <= 100_000].index
        self.df = self.df[self.df['AID'].isin(filtered_aids)]

    def drop_targets_with_limited_data(self, na_min=50, ni_min=50):
        print('Drop targets with not enough datapoints...')
        unique_uids = self.df['UID'].sort_values().unique()  # sorted unique targets
        activity_counts = self.df.groupby('UID')['Activity'].value_counts().unstack().fillna(0)  # rows: sorted targets, columns: 0 = n_inactives, 1 = n_actives
        mask = (activity_counts[1] >= na_min) & (activity_counts[0] >= ni_min)  # both counts above their minimum
        self.df = self.df[self.df['UID'].isin(unique_uids[mask])]

    def drop_conflicting_bioactivity_measures(self, target_col='UID', compound_col='CID'):
        """
        Check that each target-compound pair is associated with a unique activity value,
        i.e. every measurement is either active or inactive. If not, drop the pair.
        """
        def process_group(group):
            if group['Activity'].nunique() == 1:
                return group.head(1)
            return None

        print('Drop conflicting datapoints...')
        # Split unique UID-CID pairs from duplicated ones
        df_uniques = self.df.drop_duplicates(subset=[target_col, compound_col], keep=False)
        df_duplicates = self.df[~self.df.index.isin(df_uniques.index)]

        # Keep one row per duplicated pair if all its measurements agree
        groups = df_duplicates.groupby([target_col, compound_col])
        rows = []
        for _, group in tqdm.tqdm(groups):
            rows.append(process_group(group))
        df_rows = pd.concat([row for row in rows if row is not None])
        self.df = pd.concat([df_uniques, df_rows])

    def add_smiles(self, chunk_size=10_000_000):
        print('Retrieve SMILES...')
        cids = self.df.CID.astype(int).unique()
        filtered_chunks = []
        for chunk in pd.read_table(f'{self.pubchem_dir}/smiles.tsv.gz',
                                   chunksize=chunk_size, names=['CID', 'SMILES']):
            filtered_chunk = chunk[chunk['CID'].isin(cids)]
            filtered_chunks.append(filtered_chunk)
            if self.debug:
                break
        df_smiles = pd.concat(filtered_chunks)

        sm = Standardizer(metal_disconnect=True, canon_taut=True)

        def cleanup(smiles):
            mol = Chem.MolFromSmiles(smiles)
            try:
                standardized_mol, _ = sm.standardize_mol(mol)
                return Chem.MolToSmiles(standardized_mol)
            except Exception:
                print(smiles)
                return None

        df_smiles['SMILES'] = df_smiles['SMILES'].swifter.apply(cleanup)
        df_smiles.dropna(inplace=True)

        self.df = self.df.merge(df_smiles, on='CID', how='left').dropna(subset=['SMILES'])

    def print_stats(self):
        nassays = self.df['AID'].nunique()
        ntargets = self.df['UID'].nunique()
        ncompounds = self.df['CID'].nunique()
        nactivities = self.df.shape[0]
        print(f'{ntargets: >5,} targets | {nassays: >6,} assays | {ncompounds: >9,} compounds | {nactivities: >10,} activity data points')

    def save(self, fname='data/pubchem24_preprocessed.csv.gz'):
        print(f'Save to {fname}...')
        self.df.to_csv(fname, index=False)

    def load(self, fname):
        print(f'Load from {fname}...')
        self.df = pd.read_csv(fname)

    def add_protein_classifications(self):
        """
        Retrieve protein classifications and merge them onto the dataset.
        """
        print('Retrieve protein classifications...')
        protein_class = pd.read_csv(self.prot_class_path)
        # protein_class['UID'] = protein_class['target_id'].swifter.apply(lambda x: x.split('_')[0])
        self.df = self.df.merge(protein_class[['UID', 'Organism', 'L1', 'L2']], on='UID', how='left')


if __name__ == '__main__':
    # Create an instance of the PubChemFilter class
    pubchem_filter = PubChemFilter(PUBCHEM_DIR, FSMOL_UID_PATH, PROT_CLASS_PATH, MHNFS_PATH, debug=False)

    # Call methods of the class as needed
    pubchem_filter.load_and_filter_assays()
    pubchem_filter.load_and_filter_bioactivities()
    pubchem_filter.merge_assay_and_activity_data()
    pubchem_filter.print_stats()
    pubchem_filter.drop_hts_assays()
    pubchem_filter.print_stats()
    pubchem_filter.drop_targets_with_limited_data()
    pubchem_filter.print_stats()
    pubchem_filter.drop_conflicting_bioactivity_measures()
    pubchem_filter.print_stats()
    pubchem_filter.drop_targets_with_limited_data()
    pubchem_filter.print_stats()
    pubchem_filter.add_smiles()
    pubchem_filter.print_stats()
    pubchem_filter.drop_conflicting_bioactivity_measures(compound_col='SMILES')
    pubchem_filter.print_stats()
    pubchem_filter.drop_targets_with_limited_data()
    pubchem_filter.print_stats()
    pubchem_filter.add_protein_classifications()
    pubchem_filter.save(fname='data/pubchem24/preprocessed.csv.gz')
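Note on drop_conflicting_bioactivity_measures: the per-group Python loop can be slow when millions of target-compound pairs are duplicated. A minimal vectorized sketch of the same filter, assuming only a dataframe with 'UID', 'CID' and binary 'Activity' columns (this helper is illustrative, not part of the committed code):

    import pandas as pd

    def drop_conflicting(df, target_col='UID', compound_col='CID'):
        # Number of distinct activity values per target-compound pair, broadcast to each row
        nuniq = df.groupby([target_col, compound_col])['Activity'].transform('nunique')
        # Keep pairs whose measurements all agree, then collapse each pair to one row
        return df[nuniq == 1].drop_duplicates(subset=[target_col, compound_col])

This yields the same result as the unique/duplicate split above: conflicting pairs are dropped entirely, and consistent duplicates are reduced to a single row.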
pubchem_experiment/make_predictions.py
ADDED
@@ -0,0 +1,172 @@
import os
import sys
import warnings

import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from sklearn.ensemble import RandomForestClassifier
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")


def generate_morgan_fingerprints(smiles_list, radius=4, n_bits=4048):
    """
    Generate Morgan fingerprints for a list of SMILES.
    """
    mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
    mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]
    fps = []
    for smiles, mol in zip(smiles_list, mols):
        if mol is None:
            print(smiles)
            fps.append(None)
        else:
            fps.append(mfpgen.GetFingerprintAsNumPy(mol))
    return fps


def rf(df, train_smiles, test_smiles):
    """
    Train and test the RF baseline model.

    Parameters:
        df : pd.DataFrame with 'SMILES' and 'Activity' columns
        train_smiles : list of training set SMILES
        test_smiles : list of test set SMILES
    Returns:
        preds : predicted probabilities for the test set
    """
    train_df = df[df['SMILES'].isin(train_smiles)]
    test_df = df[df['SMILES'].isin(test_smiles)]

    # Generate Morgan fingerprints for training and test sets
    X_train = generate_morgan_fingerprints(train_df['SMILES'])
    X_test = generate_morgan_fingerprints(test_df['SMILES'])

    # Extract labels
    y_train = train_df['Activity'].values

    # Train a Random Forest classifier
    clf = RandomForestClassifier(n_estimators=200, random_state=82)
    clf.fit(X_train, y_train)

    # Make predictions on the test set
    try:
        preds = clf.predict_proba(X_test)[:, 1]
    except Exception as e:
        print(e)
        print(test_df)
        print(X_test)
        raise

    return preds


def fh(smiles_list):
    df = pd.read_csv('data/fh_predictions.csv')
    preds = df[df['SMILES'].isin(smiles_list)]['Prediction'].tolist()
    return preds


def drop_assays_with_limited_data(df, na_min=50, ni_min=100):
    print('Drop assays with not enough datapoints...')
    unique_aids = df['AID'].sort_values().unique()  # sorted unique assays
    activity_counts = df.groupby('AID')['Activity'].value_counts().unstack().fillna(0)  # rows: sorted assays, columns: 0 = n_inactives, 1 = n_actives
    mask = (activity_counts[1] >= na_min) & (activity_counts[0] >= ni_min)  # both counts above their minimum
    df = df[df['AID'].isin(unique_aids[mask])]
    return df


def run(
    n_actives: int,
    n_inactives: int,
    model: str = 'MHNfs',
    task: str = 'UID',
    input_file: str = '',  # todo add path
    output_dir: str = '',  # todo add path
    n_repeats: int = 3,
    seed: int = 42
):
    # Load data
    data = pd.read_csv(input_file)

    if task == 'AID':
        data = drop_assays_with_limited_data(data, 30, 30)

    # Output dir
    output_dir = os.path.join(output_dir, model, task, f'{n_actives}+{n_inactives}x{n_repeats}')
    print(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Tasks
    tasks = data[task].value_counts(ascending=True).index.tolist()

    if model == 'MHNfs':
        # ActivityPredictor is imported in the __main__ block once mhnfs_path is on sys.path
        predictor = ActivityPredictor()

    # Iterate over tasks
    for t in tqdm(tasks):

        # Output file
        output_file = os.path.join(output_dir, f'{t}.csv')
        if os.path.exists(output_file):
            continue

        # Data for task
        df = data[data[task] == t]

        # Iterate over replicates
        results = []
        for i in range(n_repeats):
            # Select support sets and test molecules
            actives = df.loc[df['Activity'] == 1, 'SMILES'].sample(n=n_actives, random_state=seed + i).tolist()
            inactives = df.loc[df['Activity'] == 0, 'SMILES'].sample(n=n_inactives, random_state=seed + i).tolist()
            test_smiles = df[~df.SMILES.isin(actives + inactives)].SMILES.tolist()

            if model == 'RF':
                preds = rf(df, actives + inactives, test_smiles)
            else:
                if len(test_smiles) > 10_000:
                    # MHNfs breaks for over 20_000 datapoints -> use chunks to make predictions
                    chunk_size = 10_000
                    chunks = [test_smiles[j:j + chunk_size] for j in range(0, len(test_smiles), chunk_size)]
                    preds = []
                    for chunk in chunks:
                        preds.extend(predictor.predict(chunk, actives, inactives))
                else:
                    preds = predictor.predict(test_smiles, actives, inactives)

            d = {
                'SMILES': test_smiles,
                'Label': df.loc[df.SMILES.isin(test_smiles), 'Activity'].tolist(),
                'Prediction': preds,
                'Fold': [i] * len(test_smiles)
            }
            results.append(pd.DataFrame(d))

        results = pd.concat(results)
        results.to_csv(output_file, index=False)


if __name__ == '__main__':

    mhnfs_path = ''  # todo: mhnfs_path + '/mhnfs'
    benchmark_path = ''  # todo: project_path

    sys.path.append(mhnfs_path)
    from src.prediction_pipeline import ActivityPredictor

    support_sets = [(1, 7), (2, 6), (4, 4)]
    models = ['RF', 'MHNfs']
    tasks = ['AID', 'UID']

    input_file = ''  # todo: preprocessed_data path + '/pubchem24_preprocessed_2.csv.gz'

    for support_set in support_sets:
        for model in models:
            for task in tasks:
                run(*support_set, task=task, model=model, input_file=input_file)
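Note on the chunked MHNfs predictions in run(): the inline chunking can be factored into a small reusable helper. A sketch under the assumption that predict_fn maps a list of SMILES to a list of scores (the helper and its name are illustrative, not part of the committed code):

    def predict_in_chunks(predict_fn, smiles, chunk_size=10_000):
        # Score the query molecules in fixed-size slices to stay under the model's input limit
        preds = []
        for start in range(0, len(smiles), chunk_size):
            preds.extend(predict_fn(smiles[start:start + chunk_size]))
        return preds

    # e.g. inside run():
    # preds = predict_in_chunks(lambda s: predictor.predict(s, actives, inactives), test_smiles)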
pubchem_experiment/metrics.py
ADDED
@@ -0,0 +1,163 @@
import os
import math

import numpy as np
import pandas as pd

from tqdm.auto import tqdm
from rdkit.ML.Scoring.Scoring import CalcBEDROC
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, average_precision_score, \
    matthews_corrcoef, precision_score, recall_score, f1_score, confusion_matrix

MAIN_DIR = ''  # todo add project dir


def specificity_score(true_labels, predicted_labels):
    tn, fp, _, _ = confusion_matrix(true_labels, predicted_labels).ravel()
    specificity = tn / (tn + fp)
    return specificity


def balanced_mcc_score(sensitivity, specificity, prevalence):
    """Returns the Matthews correlation coefficient at the given
    sensitivity, specificity and prevalence.

    Parameters
    ----------
    sensitivity : float
        The sensitivity of the model
    specificity : float
        The specificity of the model
    prevalence : float
        The prevalence of the test set

    Returns
    -------
    float
        Matthews correlation coefficient as a float
    """
    numerator = sensitivity + specificity - 1
    denominator_first_term = sensitivity + (1 - specificity) * (1 - prevalence) / prevalence
    denominator_second_term = specificity + (1 - sensitivity) * prevalence / (1 - prevalence)
    denominator = math.sqrt(denominator_first_term * denominator_second_term)

    # Avoid 0/0 at the degenerate corners
    if sensitivity == 1 and specificity == 0:
        denominator = 1.
    if sensitivity == 0 and specificity == 1:
        denominator = 1.

    return numerator / denominator


def ef_top_per(true_labels, predictions, prevalence, top_frac=0.01):
    """Enrichment factor: active rate among the top `top_frac` ranked
    compounds divided by the active rate of the whole set."""
    labels = np.asarray(true_labels)
    scores = np.asarray(predictions)
    n = max(int(len(scores) * top_frac), 1)
    top = np.argsort(scores)[::-1][:n]  # indices of the top-scoring compounds
    return (labels[top].sum() / n) / prevalence


def compute_metrics(df):
    """
    Compute a set of classification metrics for a single set of predictions.

    Args:
        df : dataframe with true labels in the 'Label' column and probabilistic predictions in the 'Prediction' column

    Returns:
        df_metrics : dataframe with metric names in the 'Metric' column and values in the 'Value' column
    """
    true_labels = df['Label']
    prevalence = sum(true_labels) / len(true_labels)
    predictions = df['Prediction']

    acc = accuracy_score(true_labels, predictions.round())
    bacc = balanced_accuracy_score(true_labels, predictions.round())
    precision = precision_score(true_labels, predictions.round(), zero_division=0.0)
    recall = recall_score(true_labels, predictions.round())
    specificity = specificity_score(true_labels, predictions.round())
    mcc = matthews_corrcoef(true_labels, predictions.round())
    bmcc = balanced_mcc_score(recall, specificity, prevalence)
    f1 = f1_score(true_labels, predictions.round())

    auc = roc_auc_score(true_labels, predictions)
    ap = average_precision_score(true_labels, predictions)
    dap = ap - prevalence
    scores = df.sort_values(by='Prediction', ascending=False)[['Label', 'Prediction']].values
    bedroc = CalcBEDROC(scores, 0, 20)  # labels in column 0, alpha = 20
    ef = ef_top_per(true_labels, predictions, prevalence, 0.01)

    metrics_dict = {'ACC': acc, 'BACC': bacc, 'MCC': mcc, 'BMCC': bmcc, 'Precision': precision,
                    'Recall': recall, 'F1-score': f1, 'AUC': auc, 'dAP': dap, 'BEDROC': bedroc, 'EF-1%': ef}
    df_metrics = pd.DataFrame(metrics_dict.items(), columns=['Metric', 'Value'])

    return df_metrics


def get_metrics(
        tasks: list[str] = ['AID', 'UID'],
        models: list[str] = ['MHNfs', 'RF'],
        settings: list[str] = ['1+1x3', '1+3x3', '1+7x3', '2+2x3', '2+6x3', '2+14x3',
                               '4+4x3', '4+12x3', '4+28x3', '8+8x3', '8+24x3', '8+56x3'],
        overwrite: bool = False):
    """
    Compute classification metrics for each task-model-setting combination.
    """
    file = f'{MAIN_DIR}/results_used.csv.gz'

    if overwrite:
        df = pd.DataFrame()
    else:
        df = pd.read_csv(file)

    path_preprocessed = ''  # todo
    df_pubchem = pd.read_csv(path_preprocessed)

    for task in tasks:
        for model in models:
            for setting in settings:
                pred_dir = f'{MAIN_DIR}/predictions/{model}/{task}/{setting}'
                try:
                    targets = [x[:-4] for x in os.listdir(pred_dir)]  # strip the '.csv' extension
                    pubchem_targets = df_pubchem[task].astype(str).unique().tolist()

                    for target in tqdm(targets, desc=f'{task} - {model} - {setting}'):

                        if target not in pubchem_targets:
                            continue

                        # Skip already computed targets
                        if not overwrite and any((df['Model'] == model) & (df['Setting'] == setting) & (df['Task'] == task) & (df['TID'] == target)):
                            continue

                        # Load predictions
                        df_task = pd.read_csv(f'{pred_dir}/{target}.csv')

                        # Retrieve organism and L1 protein classification
                        try:
                            org = df_pubchem.loc[df_pubchem[task] == target, 'Organism'].values[0]
                            l1 = df_pubchem.loc[df_pubchem[task] == target, 'L1'].values[0]
                        except IndexError:
                            # AIDs are stored as integers in the preprocessed table
                            org = df_pubchem.loc[df_pubchem[task] == int(target), 'Organism'].values[0]
                            l1 = df_pubchem.loc[df_pubchem[task] == int(target), 'L1'].values[0]
                        if pd.isna(l1):
                            print(target, l1)

                        # Compute metrics for each fold
                        for fold in df_task.Fold.unique():
                            metrics = compute_metrics(df_task[df_task.Fold == fold]).assign(
                                Model=model, Task=task, TID=target, Organism=org, L1=l1, Setting=setting, Fold=fold,
                            )
                            df = pd.concat([df, metrics], ignore_index=True)
                except Exception as e:
                    print(e)
                    raise

    df.to_csv(file, index=False)


if __name__ == '__main__':
    # get_metrics()
    get_metrics(settings=['1+7x3', '2+6x3', '4+4x3', '2+14x3', '4+12x3', '8+8x3'], overwrite=True)
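For reference, balanced_mcc_score implements the Matthews correlation coefficient rewritten in terms of sensitivity Se, specificity Sp and prevalence pi; this LaTeX restatement mirrors the code above:

    \mathrm{MCC} = \frac{Se + Sp - 1}{\sqrt{\left(Se + \frac{(1 - Sp)(1 - \pi)}{\pi}\right)\left(Sp + \frac{(1 - Se)\,\pi}{1 - \pi}\right)}}

A minimal usage sketch for compute_metrics, with illustrative values rather than data from the experiments:

    import pandas as pd

    df = pd.DataFrame({
        'Label':      [1, 0, 1, 0, 0, 1, 0, 0],
        'Prediction': [0.9, 0.2, 0.7, 0.4, 0.1, 0.6, 0.55, 0.3],
    })
    print(compute_metrics(df))  # one row per metric: ACC, BACC, MCC, ..., EF-1%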